# Source code for tike.ptycho.solvers.combined

import logging

import numpy as np

from tike.linalg import orthogonalize_gs
from tike.opt import conjugate_gradient, batch_indicies, get_batch
from ..position import update_positions_pd

logger = logging.getLogger(__name__)


def cgrad(
    op, comm,
    data, probe, scan, psi,
    recover_psi=True, recover_probe=True, recover_positions=False,
    cg_iter=4,
    cost=None,
    eigen_probe=None,
    eigen_weights=None,
    num_batch=1,
    subset_is_random=None,
    step_length=1,
    probe_is_orthogonal=False,
):  # yapf: disable
    """Solve the ptychography problem using conjugate gradient.

    The scan positions are split into `num_batch` batches. For every batch
    the object, the probe, and (optionally) the positions are updated in
    that order, each update being skipped when its `recover_*` flag is false.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        A ptychography operator.
    comm : :py:class:`tike.communicators.Comm`
        An object which manages communications between both GPUs and nodes.
    data, scan, probe, psi
        Per-worker sequences; `data` and `scan` are mapped over `comm.pool`
        together with the batch indices, and element 0 of `psi` and `probe`
        is used for the position update. Batches are drawn along axis -2 of
        each entry of `scan`.
    recover_psi, recover_probe : bool
        Whether to run the object / probe conjugate-gradient update.
    recover_positions : bool
        Whether to run the position update. It is only performed when
        `comm.pool.num_workers == 1`, and the updated batch positions are
        currently not written back into `scan` (see the TODO below).
    cg_iter : int
        Passed as `num_iter` to the object and probe updates.
    cost, eigen_probe, eigen_weights
        Accepted for interface compatibility; not read by this function
        (`cost` is reset to `np.inf` on entry).
    num_batch : int
        The number of batches the scan positions are divided into.
    subset_is_random
        Forwarded unchanged to `batch_indicies`.
    step_length
        Forwarded to the object and probe updates.
    probe_is_orthogonal : bool
        Forwarded to the probe update.

    Returns
    -------
    dict
        A dictionary with the keys ``'psi'``, ``'probe'``, ``'cost'`` and
        ``'scan'``. ``'cost'`` is the value returned by the last update that
        ran, or `np.inf` if none ran.

    .. seealso:: :py:mod:`tike.ptycho`
    """
    cost = np.inf
    # Options shared by both conjugate-gradient sub-problems.
    cg_options = {'num_iter': cg_iter, 'step_length': step_length}

    # Unique batch for each device
    batch_index = [
        batch_indicies(positions.shape[-2], num_batch, subset_is_random)
        for positions in scan
    ]

    for batch in range(num_batch):
        data_batch = comm.pool.map(get_batch, data, batch_index, n=batch)
        scan_batch = comm.pool.map(get_batch, scan, batch_index, n=batch)

        if recover_psi:
            psi, cost = _update_object(
                op, comm, data_batch, psi, scan_batch, probe,
                **cg_options,
            )

        if recover_probe:
            num_modes = probe[0].shape[-3]
            probe, cost = _update_probe(
                op, comm, data_batch, psi, scan_batch, probe,
                probe_is_orthogonal=probe_is_orthogonal,
                mode=list(range(num_modes)),
                **cg_options,
            )

        if recover_positions and comm.pool.num_workers == 1:
            scan_batch, cost = update_positions_pd(
                op,
                comm.pool.gather(data_batch, axis=1),
                psi[0],
                probe[0],
                comm.pool.gather(scan_batch, axis=1),
            )
            scan_batch = comm.pool.bcast(scan_batch)
            # TODO: Assign bscan into scan when positions are updated

    return {'psi': psi, 'probe': probe, 'cost': cost, 'scan': scan}
def _reduce_over_workers(comm, values, device):
    """Reduce per-worker `values` with the MPI-aware reducer when enabled."""
    reducer = comm.Allreduce_reduce if comm.use_mpi else comm.reduce
    return reducer(values, device)


def _update_probe(op, comm, data, psi, scan, probe, num_iter, step_length,
                  probe_is_orthogonal, mode):
    """Solve the probe recovery problem.

    Runs `conjugate_gradient` on `probe` using `op.cost` and `op.grad_probe`
    mapped over `comm.pool`, then optionally orthogonalizes the result.

    Returns the updated probe and the cost reported by the solver.
    """
    # NOTE: the callback parameter names below are kept as they were because
    # they are invoked by `conjugate_gradient`, which is defined elsewhere.

    def cost_function(probe):
        per_worker = comm.pool.map(op.cost, data, psi, scan, probe)
        return _reduce_over_workers(comm, per_worker, 'cpu')

    def grad(probe):
        per_worker = comm.pool.map(
            op.grad_probe, data, psi, scan, probe, mode=mode)
        return _reduce_over_workers(comm, per_worker, 'gpu')

    def dir_multi(dir):
        """Broadcast the search direction through the worker pool."""
        return comm.pool.bcast(dir)

    def update_multi(x, gamma, d):
        # Step only the selected modes of each worker's probe.
        return comm.pool.map(
            lambda current, step: current[..., mode, :, :] + gamma * step,
            x,
            d,
        )

    solver_options = {
        'cost_function': cost_function,
        'grad': grad,
        'dir_multi': dir_multi,
        'update_multi': update_multi,
        'num_iter': num_iter,
        'step_length': step_length,
    }
    probe, cost = conjugate_gradient(op.xp, x=probe, **solver_options)

    num_modes = probe[0].shape[-3]
    if num_modes > 1 and probe_is_orthogonal:
        probe = comm.pool.map(orthogonalize_gs, probe, axis=(-2, -1))

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost


def _update_object(op, comm, data, psi, scan, probe, num_iter, step_length):
    """Solve the object recovery problem.

    Runs `conjugate_gradient` on `psi` using `op.cost` and `op.grad_psi`
    mapped over `comm.pool`.

    Returns the updated object and the cost reported by the solver.
    """
    # NOTE: the callback parameter names below are kept as they were because
    # they are invoked by `conjugate_gradient`, which is defined elsewhere.

    def cost_function_multi(psi, **kwargs):
        per_worker = comm.pool.map(op.cost, data, psi, scan, probe)
        return _reduce_over_workers(comm, per_worker, 'cpu')

    def grad_multi(psi):
        per_worker = comm.pool.map(op.grad_psi, data, psi, scan, probe)
        return _reduce_over_workers(comm, per_worker, 'gpu')

    def dir_multi(dir):
        """Broadcast the search direction through the worker pool."""
        return comm.pool.bcast(dir)

    def update_multi(psi, gamma, dir):
        stepped = comm.pool.map(
            lambda current, step: current + gamma * step,
            psi,
            dir,
        )
        return list(stepped)

    solver_options = {
        'cost_function': cost_function_multi,
        'grad': grad_multi,
        'dir_multi': dir_multi,
        'update_multi': update_multi,
        'num_iter': num_iter,
        'step_length': step_length,
    }
    psi, cost = conjugate_gradient(op.xp, x=psi, **solver_options)

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost