Source code for tike.ptycho.solvers.divided

import logging

import cupy as cp

from tike.linalg import lstsq, projection, norm, orthogonalize_gs
from tike.opt import batch_indicies, get_batch, put_batch

from ..position import update_positions_pd
from ..probe import orthogonalize_eig, get_varying_probe, update_eigen_probe

logger = logging.getLogger(__name__)


[docs]def lstsq_grad( op, comm, data, probe, scan, psi, recover_psi=True, recover_probe=False, recover_positions=False, cg_iter=4, cost=None, eigen_probe=None, eigen_weights=None, num_batch=1, subset_is_random=True, probe_is_orthogonal=False, ): # yapf: disable """Solve the ptychography problem using Odstrcil et al's approach. The near- and farfield- ptychography problems are solved separately using gradient descent in the farfield and linear-least-squares in the nearfield. Parameters ---------- op : tike.operators.Ptycho A ptychography operator. comm : tike.communicators.Comm An object which manages communications between GPUs and nodes. References ---------- Michal Odstrcil, Andreas Menzel, and Manuel Guizar-Sicaros. Iterative least-squares solver for generalized maximum-likelihood ptychography. Optics Express. 2018. """ # Unique batch for each device batches = [ batch_indicies(s.shape[-2], num_batch, subset_is_random) for s in scan ] for n in range(num_batch): bdata = comm.pool.map(get_batch, data, batches, n=n) bscan = comm.pool.map(get_batch, scan, batches, n=n) if isinstance(eigen_probe, list): beigen_weights = comm.pool.map( get_batch, eigen_weights, batches, n=n, ) beigen_probe = eigen_probe else: beigen_probe = [None] * comm.pool.num_workers beigen_weights = [None] * comm.pool.num_workers unique_probe = comm.pool.map( get_varying_probe, probe, beigen_probe, beigen_weights, ) nearplane, cost = zip(*comm.pool.map( _update_wavefront, bdata, unique_probe, bscan, psi, op=op, )) if comm.use_mpi: cost = comm.Allreduce_reduce(cost, 'cpu') else: cost = comm.reduce(cost, 'cpu') ( psi, probe, beigen_probe, beigen_weights, ) = _update_nearplane( op, comm, nearplane, psi, bscan, probe, unique_probe, beigen_probe, beigen_weights, recover_psi, recover_probe, probe_is_orthogonal, ) if isinstance(eigen_probe, list): comm.pool.map( put_batch, beigen_weights, eigen_weights, batches, n=n, ) result = { 'psi': psi, 'probe': probe, 'cost': cost, 'scan': scan, } if isinstance(eigen_probe, list): result['eigen_probe'] = eigen_probe result['eigen_weights'] = eigen_weights return result
def _get_nearplane_gradients(nearplane, psi, scan_, probe, unique_probe, op, m, recover_psi, recover_probe): pad, end = op.diffraction.pad, op.diffraction.end patches = op.diffraction.patch.fwd( patches=cp.zeros(nearplane[..., m:m + 1, pad:end, pad:end].shape, dtype='complex64')[..., 0, 0, :, :], images=psi, positions=scan_, )[..., None, None, :, :] # χ (diff) is the target for the nearplane problem; the difference # between the desired nearplane and the current nearplane that we wish # to minimize. diff = nearplane[..., m:m + 1, pad:end, pad:end] - unique_probe[..., m:m + 1, :, :] * patches logger.info('%10s cost is %+12.5e', 'nearplane', norm(diff)) if recover_psi: grad_psi = cp.conj(unique_probe[..., m:m + 1, :, :]) * diff # (25b) Common object gradient. Use a weighted (normalized) sum # instead of division as described in publication to improve # numerical stability. common_grad_psi = op.diffraction.patch.adj( patches=grad_psi[..., 0, 0, :, :], images=cp.zeros(psi.shape, dtype='complex64'), positions=scan_, ) dOP = op.diffraction.patch.fwd( patches=cp.zeros(patches.shape, dtype='complex64')[..., 0, 0, :, :], images=common_grad_psi, positions=scan_, )[..., None, None, :, :] * unique_probe[..., m:m + 1, :, :] A1 = cp.sum((dOP * dOP.conj()).real + 0.5, axis=(-2, -1)) A1 += 0.5 * cp.mean(A1, axis=-3, keepdims=True) b1 = cp.sum((dOP.conj() * diff).real, axis=(-2, -1)) if recover_probe: grad_probe = cp.conj(patches) * diff # (25a) Common probe gradient. Use simple average instead of # division as described in publication because that's what # ptychoshelves does common_grad_probe = cp.mean( grad_probe, axis=-5, keepdims=True, ) dPO = common_grad_probe * patches A4 = cp.sum((dPO * dPO.conj()).real + 0.5, axis=(-2, -1)) A4 += 0.5 * cp.mean(A4, axis=-3, keepdims=True) b2 = cp.sum((dPO.conj() * diff).real, axis=(-2, -1)) # (22) Use least-squares to find the optimal step sizes simultaneously if recover_psi and recover_probe: A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1)) A3 = A2.conj() determinant = A1 * A4 - A2 * A3 x1 = -cp.conj(A2 * b2 - A4 * b1) / determinant x2 = cp.conj(A1 * b2 - A3 * b1) / determinant elif recover_psi: x1 = b1 / A1 elif recover_probe: x2 = b2 / A4 weighted_step_psi = None weighted_step_probe = None if recover_psi: step = x1[..., None, None] # (27b) Object update weighted_step_psi = cp.mean(step, keepdims=True, axis=-5) if recover_probe: step = x2[..., None, None] weighted_step_probe = cp.mean(step, axis=-5, keepdims=True) if __debug__: patches = op.diffraction.patch.fwd( patches=cp.zeros(nearplane[..., m:m + 1, pad:end, pad:end].shape, dtype='complex64')[..., 0, 0, :, :], images=psi, positions=scan_, )[..., None, None, :, :] logger.info( '%10s cost is %+12.5e', 'nearplane', norm(probe[..., m:m + 1, :, :] * patches - nearplane[..., m:m + 1, pad:end, pad:end])) return (patches, diff, grad_probe, common_grad_psi, common_grad_probe, weighted_step_psi, weighted_step_probe) def _get_residuals(grad_probe, grad_probe_mean): return grad_probe - grad_probe_mean def _update_residuals(R, eigen_probe, axis, c, m): R -= projection( R, eigen_probe[..., c:c + 1, m:m + 1, :, :], axis=axis, ) return R def _update_nearplane(op, comm, nearplane, psi, scan_, probe, unique_probe, eigen_probe, eigen_weights, recover_psi, recover_probe, probe_is_orthogonal): for m in range(probe[0].shape[-3]): ( patches, diff, grad_probe, common_grad_psi, common_grad_probe, weighted_step_psi, weighted_step_probe, ) = (list(a) for a in zip(*comm.pool.map( _get_nearplane_gradients, nearplane, psi, scan_, probe, unique_probe, op=op, m=m, recover_psi=recover_psi, recover_probe=recover_probe, ))) if recover_probe and eigen_probe[0] is not None: logger.info('Updating eigen probes') # (30) residual probe updates grad_probe_mean = comm.pool.bcast( comm.pool.reduce_mean( common_grad_probe, axis=-5, )) R = comm.pool.map(_get_residuals, grad_probe, grad_probe_mean) for c in range(eigen_probe[0].shape[-4]): a, b = update_eigen_probe( comm, R, [p[..., c:c + 1, m:m + 1, :, :] for p in eigen_probe], [w[..., c, m] for w in eigen_weights], patches, diff, β=0.01, # TODO: Adjust according to mini-batch size ) for p, w, x, y in zip(eigen_probe, eigen_weights, a, b): p[..., c:c + 1, m:m + 1, :, :] = x w[..., c, m] = y if eigen_probe[0].shape[-4] <= c + 1: # Subtract projection of R onto new probe from R R = comm.pool.map( _update_residuals, R, eigen_probe, axis=(-2, -1), c=c, m=m, ) # Update each direction if recover_psi: weighted_step_psi[0] = comm.pool.reduce_mean( weighted_step_psi, axis=-5, )[..., 0, 0, 0] common_grad_psi[0] = comm.pool.reduce_gpu(common_grad_psi) psi[0] += weighted_step_psi[0] * common_grad_psi[0] psi = comm.pool.bcast(psi[0]) if recover_probe: weighted_step_probe[0] = comm.pool.reduce_mean( weighted_step_probe, axis=-5, ) common_grad_probe[0] = comm.pool.reduce_mean( common_grad_probe, axis=-5, ) # (27a) Probe update probe[0][..., m:m + 1, :, :] += (weighted_step_probe[0] * common_grad_probe[0]) if probe[0].shape[-3] > 1 and probe_is_orthogonal: probe[0] = orthogonalize_gs(probe[0], axis=(-2, -1)) probe = comm.pool.bcast(probe[0]) return psi, probe, eigen_probe, eigen_weights def _update_wavefront(data, varying_probe, scan, psi, op): # Compute the diffraction patterns for all of the probe modes at once. # We need access to all of the modes of a position to solve the phase # problem. The Ptycho operator doesn't do this natively, so it's messy. patches = cp.zeros(data.shape, dtype='complex64') patches = op.diffraction.patch.fwd( patches=patches, images=psi, positions=scan, patch_width=varying_probe.shape[-1], ) patches = patches.reshape(*scan.shape[:-1], 1, 1, op.detector_shape, op.detector_shape) nearplane = cp.tile(patches, reps=(1, 1, 1, varying_probe.shape[-3], 1, 1)) pad, end = op.diffraction.pad, op.diffraction.end nearplane[..., pad:end, pad:end] *= varying_probe # Solve the farplane phase problem ---------------------------------------- farplane = op.propagation.fwd(nearplane, overwrite=True) intensity = cp.sum(cp.square(cp.abs(farplane)), axis=(2, 3)) cost = op.propagation.cost(data, intensity) logger.info('%10s cost is %+12.5e', 'farplane', cost) farplane -= 0.5 * op.propagation.grad(data, farplane, intensity) if __debug__: intensity = cp.sum(cp.square(cp.abs(farplane)), axis=(2, 3)) cost = op.propagation.cost(data, intensity) logger.info('%10s cost is %+12.5e', 'farplane', cost) # TODO: Only compute cost every 20 iterations or on a log sampling? farplane = op.propagation.adj(farplane, overwrite=True) return farplane, cost