Source code for tike.ptycho.solvers.rpie

import logging

import cupy as cp
import cupyx.scipy.stats
import numpy.typing as npt

import tike.communicators
import tike.linalg
import tike.operators
import tike.opt
import tike.ptycho.object
import tike.ptycho.position
import tike.ptycho.probe
import tike.ptycho.exitwave
import tike.precision
import tike.random

from .options import *
from .lstsq import _momentum_checked

logger = logging.getLogger(__name__)


def rpie(
    op: tike.operators.Ptycho,
    comm: tike.communicators.Comm,
    data: typing.List[npt.NDArray],
    batches: typing.List[typing.List[npt.NDArray[cp.intc]]],
    *,
    parameters: PtychoParameters,
    epoch: int,
) -> PtychoParameters:
    """Solve the ptychography problem using regularized ptychographical engine.

    The rPIE update direction can be shown to be equivalent to a conventional
    gradient descent direction but rescaled by the preconditioner term. i.e. If
    the rPIE step size (alpha) is 0 and the preconditioner is zero, we have the
    vanilla gradient descent direction.

    One call processes every batch once. With ``batch_method == 'compact'``
    the batches are visited in index order and the object/probe are updated a
    single time after all batches have been accumulated; otherwise the batches
    are visited in a random permutation and the object/probe are updated after
    each batch.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        A ptychography operator.
    comm : :py:class:`tike.communicators.Comm`
        An object which manages communications between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        A list of unique CuPy arrays for each device containing
        the intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records. FFT-shifted so the
        diffraction peak is at the corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        A list of list of indices along the FRAME axis of `data` for
        each device which define the batches of `data` to process
        simultaneously.
    parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains reconstruction parameters.
    epoch : int
        Forwarded to the position update, which is skipped while `epoch` is
        smaller than ``position_options.update_start``.

    Returns
    -------
    result : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains the updated reconstruction parameters.
        This is the same `parameters` object that was passed in, with its
        fields reassigned.

    References
    ----------
    Maiden, Andrew M., and John M. Rodenburg. 2009. “An Improved
    Ptychographical Phase Retrieval Algorithm for Diffractive Imaging.”
    Ultramicroscopy 109 (10): 1256–62.
    https://doi.org/10.1016/j.ultramic.2009.05.012.

    Andrew Maiden, Daniel Johnson, and Peng Li, "Further improvements to the
    ptychographical iterative engine," Optica 4, 736-745 (2017)
    https://doi.org/10.1364/OPTICA.4.000736

    .. seealso:: :py:mod:`tike.ptycho`

    """

    def per_worker_placeholders():
        # One `None` slot per worker; the gradient kernel allocates the real
        # accumulator the first time it sees `None`.
        return [None] * comm.pool.num_workers

    # Unpack the reconstruction state.
    probe = parameters.probe
    scan = parameters.scan
    psi = parameters.psi
    algorithm_options = parameters.algorithm_options
    exitwave_options = parameters.exitwave_options
    probe_options = parameters.probe_options
    position_options = parameters.position_options
    object_options = parameters.object_options
    eigen_probe = parameters.eigen_probe
    eigen_weights = parameters.eigen_weights

    recover_probe = (False if probe_options is None else
                     probe_options.recover_probe)

    # `pool.map` needs one entry per worker even when there is nothing to pass.
    beigen_probe = (per_worker_placeholders()
                    if eigen_probe is None else eigen_probe)
    beigen_weights = (per_worker_placeholders()
                      if eigen_weights is None else eigen_weights)

    compact = algorithm_options.batch_method == 'compact'
    batch_order = range if compact else tike.random.randomizer_np.permutation

    psi_update_numerator = per_worker_placeholders()
    probe_update_numerator = per_worker_placeholders()
    position_update_numerator = per_worker_placeholders()
    position_update_denominator = per_worker_placeholders()

    batch_cost: typing.List[float] = []
    for n in batch_order(algorithm_options.num_batch):

        worker_results = comm.pool.map(
            _get_nearplane_gradients,
            data,
            scan,
            psi,
            probe,
            exitwave_options.measured_pixels,
            psi_update_numerator,
            probe_update_numerator,
            position_update_numerator,
            position_update_denominator,
            beigen_probe,
            beigen_weights,
            batches,
            comm.streams,
            n=n,
            op=op,
            object_options=object_options,
            probe_options=probe_options,
            recover_probe=recover_probe,
            position_options=position_options,
            exitwave_options=exitwave_options,
        )
        # Transpose: one tuple per worker -> one list per returned quantity.
        (
            cost,
            psi_update_numerator,
            probe_update_numerator,
            position_update_numerator,
            position_update_denominator,
            beigen_weights,
        ) = [list(column) for column in zip(*worker_results)]

        batch_cost.append(comm.Allreduce_mean(cost, axis=None).get())

        if not compact:
            # Apply the update right away and start the next batch with
            # fresh accumulators.
            psi, probe = _update(
                comm,
                psi,
                probe,
                psi_update_numerator,
                probe_update_numerator,
                object_options,
                probe_options,
                recover_probe,
                algorithm_options,
            )
            psi_update_numerator = per_worker_placeholders()
            probe_update_numerator = per_worker_placeholders()

    algorithm_options.costs.append(batch_cost)

    if position_options is not None:
        position_results = comm.pool.map(
            _update_position,
            scan,
            position_options,
            position_update_numerator,
            position_update_denominator,
            max_shift=probe[0].shape[-1] * 0.1,
            alpha=algorithm_options.alpha,
            epoch=epoch,
        )
        scan, position_options = [
            list(column) for column in zip(*position_results)
        ]

    if compact:
        # Single update from the gradients accumulated over all batches; the
        # mean cost of up to the last three epochs gates the momentum step.
        recent_errors = [np.mean(x) for x in algorithm_options.costs[-3:]]
        psi, probe = _update(
            comm,
            psi,
            probe,
            psi_update_numerator,
            probe_update_numerator,
            object_options,
            probe_options,
            recover_probe,
            algorithm_options,
            errors=recent_errors,
        )

    if eigen_weights is not None:
        eigen_weights = comm.pool.map(
            _normalize_eigen_weights,
            beigen_weights,
        )

    # Write the updated state back onto the parameters object.
    parameters.probe = probe
    parameters.psi = psi
    parameters.scan = scan
    parameters.algorithm_options = algorithm_options
    parameters.probe_options = probe_options
    parameters.object_options = object_options
    parameters.position_options = position_options
    parameters.eigen_weights = eigen_weights
    return parameters
def _normalize_eigen_weights(eigen_weights):
    """Return `eigen_weights` divided by its `tike.linalg.mnorm` along axis -3.

    The norm is requested with ``keepdims=True`` so that the quotient
    broadcasts against the input.
    """
    return eigen_weights / tike.linalg.mnorm(
        eigen_weights,
        axis=(-3),
        keepdims=True,
    )


def _update(
    comm: tike.communicators.Comm,
    psi: npt.NDArray[cp.csingle],
    probe: npt.NDArray[cp.csingle],
    psi_update_numerator: npt.NDArray[cp.csingle],
    probe_update_numerator: npt.NDArray[cp.csingle],
    object_options: ObjectOptions,
    probe_options: ProbeOptions,
    recover_probe: bool,
    algorithm_options: RpieOptions,
    errors: typing.Union[None, typing.List[float]] = None,
):
    """Apply the accumulated rPIE updates to the object and the probe.

    `psi` and `probe` are indexed as per-worker sequences here: element 0 is
    updated and the result is re-broadcast with ``comm.pool.bcast``.

    Parameters
    ----------
    comm
        Communicator used to reduce the numerators across workers and to
        broadcast the updated arrays.
    psi, probe
        Current object and probe (one entry per worker).
    psi_update_numerator, probe_update_numerator
        Per-worker accumulated update numerators; reduced with
        ``comm.Allreduce_reduce_gpu`` before use.
    object_options
        When falsy, the object is left untouched.
    probe_options
        Only read when `recover_probe` is true.
    recover_probe
        When false, the probe is left untouched.
    algorithm_options
        Supplies the rPIE regularisation weight ``alpha``.
    errors
        Recent cost values. When truthy, the momentum step uses
        ``_momentum_checked``; when ``None`` or empty, ``tike.opt.adam`` is
        used instead.

    Returns
    -------
    psi, probe
        The (possibly re-broadcast) object and probe sequences.
    """
    if object_options:
        psi_update_numerator = comm.Allreduce_reduce_gpu(
            psi_update_numerator)[0]

        dpsi = psi_update_numerator
        # rPIE denominator: blend of the preconditioner and its maximum over
        # the last two axes, weighted by alpha.
        deno = (
            (1 - algorithm_options.alpha) * object_options.preconditioner[0] +
            algorithm_options.alpha * object_options.preconditioner[0].max(
                axis=(-2, -1),
                keepdims=True,
            ))
        psi[0] = psi[0] + dpsi / deno
        if object_options.use_adaptive_moment:
            if errors:
                (
                    dpsi,
                    object_options.v,
                    object_options.m,
                ) = _momentum_checked(
                    g=dpsi,
                    v=object_options.v,
                    m=object_options.m,
                    mdecay=object_options.mdecay,
                    errors=errors,
                    memory_length=3,
                )
            else:
                (
                    dpsi,
                    object_options.v,
                    object_options.m,
                ) = tike.opt.adam(
                    g=dpsi,
                    v=object_options.v,
                    m=object_options.m,
                    vdecay=object_options.vdecay,
                    mdecay=object_options.mdecay,
                )
            # The momentum term is added on top of the plain step above.
            psi[0] = psi[0] + dpsi / deno
        psi = comm.pool.bcast([psi[0]])

    if recover_probe:

        probe_update_numerator = comm.Allreduce_reduce_gpu(
            probe_update_numerator)[0]
        # b0: penalty returned by finite_probe_support for the current probe.
        b0 = tike.ptycho.probe.finite_probe_support(
            probe[0],
            p=probe_options.probe_support,
            radius=probe_options.probe_support_radius,
            degree=probe_options.probe_support_degree,
        )
        # b1: penalty that ramps linearly from 0 to additional_probe_penalty
        # along axis -3 of the probe.
        b1 = probe_options.additional_probe_penalty * cp.linspace(
            0, 1, probe[0].shape[-3], dtype='float32')[..., None, None]
        dprobe = (probe_update_numerator - (b1 + b0) * probe[0])
        # Same blended denominator as for the object, plus both penalties.
        deno = (
            (1 - algorithm_options.alpha) * probe_options.preconditioner[0] +
            algorithm_options.alpha * probe_options.preconditioner[0].max(
                axis=(-2, -1),
                keepdims=True,
            ) + b0 + b1)
        probe[0] = probe[0] + dprobe / deno
        if probe_options.use_adaptive_moment:
            # ptychoshelves only applies momentum to the main probe
            mode = 0
            if errors:
                (
                    dprobe[0, 0, mode, :, :],
                    probe_options.v,
                    probe_options.m,
                ) = _momentum_checked(
                    g=(dprobe)[0, 0, mode, :, :],
                    v=probe_options.v,
                    m=probe_options.m,
                    mdecay=probe_options.mdecay,
                    errors=errors,
                    memory_length=3,
                )
            else:
                (
                    dprobe[0, 0, mode, :, :],
                    probe_options.v,
                    probe_options.m,
                ) = tike.opt.adam(
                    g=(dprobe)[0, 0, mode, :, :],
                    v=probe_options.v,
                    m=probe_options.m,
                    vdecay=probe_options.vdecay,
                    mdecay=probe_options.mdecay,
                )
            # NOTE(review): only the `mode` slice of dprobe was replaced by
            # the momentum term, but the whole of dprobe is added here, so the
            # other modes receive their plain step a second time. Confirm this
            # is intended.
            probe[0] = probe[0] + dprobe / deno
        probe = comm.pool.bcast([probe[0]])

    return psi, probe


def _get_nearplane_gradients(
    data: npt.NDArray,
    scan: npt.NDArray,
    psi: npt.NDArray,
    probe: npt.NDArray,
    measured_pixels: npt.NDArray,
    psi_update_numerator: typing.Union[None, npt.NDArray],
    probe_update_numerator: typing.Union[None, npt.NDArray],
    position_update_numerator: typing.Union[None, npt.NDArray],
    position_update_denominator: typing.Union[None, npt.NDArray],
    eigen_probe: typing.Union[None, npt.NDArray],
    eigen_weights: typing.Union[None, npt.NDArray],
    batches: typing.List[typing.List[int]],
    streams: typing.List[cp.cuda.Stream],
    *,
    n: int,
    op: tike.operators.Ptycho,
    object_options: typing.Union[None, ObjectOptions] = None,
    probe_options: typing.Union[None, ProbeOptions] = None,
    recover_probe: bool,
    position_options: typing.Union[None, PositionOptions],
    exitwave_options: ExitWaveOptions,
) -> typing.List[npt.NDArray]:
    """Accumulate cost and update numerators for batch `n` on one worker.

    Parameters
    ----------
    data, scan, psi, probe
        This worker's measured data, scan positions, object and probe.
    measured_pixels
        Mask selecting the detector pixels that were measured; its logical
        negation selects the unmeasured pixels.
    psi_update_numerator, probe_update_numerator
        Running accumulators. ``None`` means "start from zeros shaped like
        `psi` / `probe`".
    position_update_numerator, position_update_denominator
        Running per-position accumulators. ``None`` means "allocate with
        ``cp.empty_like(scan)``"; the contents are uninitialised until
        written.
    eigen_probe, eigen_weights
        Optional varying-probe terms passed to
        ``tike.ptycho.probe.get_varying_probe``. `eigen_weights` is updated
        in place when present.
    batches
        Index lists; only the first and last element of ``batches[n]`` are
        used, to form the half-open range ``[lo, hi)``.
    streams
        CUDA streams handed to ``stream_and_modify2``.
    n
        Index of the batch to process.
    op
        The ptychography operator.
    object_options, probe_options, position_options
        Truthiness selects which gradients are computed.
    recover_probe
        Whether to accumulate the probe numerator.
    exitwave_options
        Selects the noise model and the step length settings, and gives the
        scaling applied to unmeasured pixels.

    Returns
    -------
    tuple
        ``(cost, psi_update_numerator, probe_update_numerator,
        position_update_numerator, position_update_denominator,
        eigen_weights)``. Note that a tuple is returned although the
        annotation says list.
    """
    cost = 0.0
    # Weight that turns the summed per-pattern cost into an average over
    # len(batches[n]) patterns. Raises ZeroDivisionError for an empty batch.
    count = 1.0 / len(batches[n])
    psi_update_numerator = cp.zeros_like(
        psi) if psi_update_numerator is None else psi_update_numerator
    probe_update_numerator = cp.zeros_like(
        probe) if probe_update_numerator is None else probe_update_numerator
    # empty_like: values are undefined until the [lo:hi] rows are assigned in
    # the position branch below.
    position_update_numerator = cp.empty_like(
        scan
    ) if position_update_numerator is None else position_update_numerator
    position_update_denominator = cp.empty_like(
        scan
    ) if position_update_denominator is None else position_update_denominator

    def keep_some_args_constant(
        ind_args,
        lo: int,
        hi: int,
    ):
        """Process scan positions ``[lo, hi)`` and update the accumulators.

        `ind_args` carries the data for this call; everything else is taken
        from the enclosing scope.
        """
        # This `data` shadows the outer argument of the same name.
        (data,) = ind_args
        nonlocal cost, psi_update_numerator, probe_update_numerator
        nonlocal position_update_numerator, position_update_denominator
        nonlocal eigen_weights, scan

        unique_probe = tike.ptycho.probe.get_varying_probe(
            probe,
            eigen_probe,
            eigen_weights[lo:hi] if eigen_weights is not None else None,
        )

        farplane = op.fwd(probe=unique_probe, scan=scan[lo:hi], psi=psi)
        # Sum |farplane|^2 over every axis except the first and the last two.
        intensity = cp.sum(
            cp.square(cp.abs(farplane)),
            axis=list(range(1, farplane.ndim - 2)),
        )
        # Look up the per-pattern cost function by noise-model name.
        each_cost = getattr(
            tike.operators,
            f'{exitwave_options.noise_model}_each_pattern',
        )(
            data[:, measured_pixels][:, None, :],
            intensity[:, measured_pixels][:, None, :],
        )
        cost += cp.sum(each_cost) * count

        if exitwave_options.noise_model == 'poisson':

            xi = (1 - data / intensity)[:, None, None, :, :]
            grad_cost = farplane * xi

            step_length = cp.full(
                shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
                fill_value=exitwave_options.step_length_start,
            )

            if exitwave_options.step_length_usemodes == 'dominant_mode':

                step_length = tike.ptycho.exitwave.poisson_steplength_dominant_mode(
                    xi,
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )

            else:

                step_length = tike.ptycho.exitwave.poisson_steplength_all_modes(
                    xi,
                    cp.square(cp.abs(farplane)),
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )

            # From here on `farplane` holds the negative, step-scaled gradient
            # at measured pixels rather than the wavefield itself.
            farplane[..., measured_pixels] = (-step_length *
                                              grad_cost)[..., measured_pixels]

        else:

            # Gaussian noise model for exitwave updates, steplength = 1
            # TODO: optimal step lengths using 2nd order taylor expansion

            farplane[..., measured_pixels] = -getattr(
                tike.operators, f'{exitwave_options.noise_model}_grad')(
                    data,
                    farplane,
                    intensity,
                )[..., measured_pixels]

        # Unmeasured pixels still hold the forward wavefield; rescale it.
        unmeasured_pixels = cp.logical_not(measured_pixels)
        farplane[..., unmeasured_pixels] *= (
            exitwave_options.unmeasured_pixels_scaling - 1.0)

        pad, end = op.diffraction.pad, op.diffraction.end
        # overwrite=True: `farplane` must not be read after this call.
        diff = op.propagation.adj(farplane, overwrite=True)[..., pad:end,
                                                            pad:end]

        if object_options:
            # Flatten the position and mode axes into one patch axis; the
            # division by probe.shape[-3] scales by the size of that axis.
            grad_psi = (cp.conj(unique_probe) * diff / probe.shape[-3]).reshape(
                scan[lo:hi].shape[0] * probe.shape[-3], *probe.shape[-2:])
            psi_update_numerator = op.diffraction.patch.adj(
                patches=grad_psi,
                images=psi_update_numerator,
                positions=scan[lo:hi],
                nrepeat=probe.shape[-3],
            )

        # NOTE(review): the nesting of the three branches below was
        # reconstructed from a whitespace-collapsed listing. They are placed
        # under the guard that defines `patches`, with the eigen-weight update
        # inside `recover_probe`; confirm against the upstream file.
        if position_options or probe_options:

            # Extract object patches at the scan positions; both the probe
            # and the position gradients need them.
            patches = op.diffraction.patch.fwd(
                patches=cp.zeros_like(diff[..., 0, 0, :, :]),
                images=psi,
                positions=scan[lo:hi],
            )[..., None, None, :, :]

            if recover_probe:
                probe_update_numerator += cp.sum(
                    cp.conj(patches) * diff,
                    axis=-5,
                    keepdims=True,
                )

                if eigen_weights is not None:
                    # Only index 0 along the second-to-last-but-one axis of
                    # the probe is used for the weight update.
                    m: int = 0
                    OP = patches * probe[..., m:m + 1, :, :]
                    eigen_numerator = cp.sum(
                        cp.real(cp.conj(OP) * diff[..., m:m + 1, :, :]),
                        axis=(-1, -2),
                    )
                    eigen_denominator = cp.sum(
                        cp.abs(OP)**2,
                        axis=(-1, -2),
                    )
                    # In-place update with a fixed 0.1 step.
                    eigen_weights[lo:hi, ..., 0:1, m:m+1] += (
                        0.1 * (eigen_numerator / eigen_denominator)
                    )  # yapf: disable

            if position_options:
                grad_x, grad_y = tike.ptycho.position.gaussian_gradient(patches)

                # Ignore a border of a quarter of the probe width on each
                # side when summing.
                crop = probe.shape[-1] // 4

                # Index 0 of the last axis takes the grad_x terms.
                position_update_numerator[lo:hi, ..., 0] = cp.sum(
                    cp.real(
                        cp.conj(
                            grad_x[..., crop:-crop, crop:-crop] *
                            unique_probe[..., crop:-crop, crop:-crop]
                        ) *
                        diff[..., crop:-crop, crop:-crop]
                    ),
                    axis=(-4, -3, -2, -1),
                )
                position_update_denominator[lo:hi, ..., 0] = cp.sum(
                    cp.abs(
                        grad_x[..., crop:-crop, crop:-crop] *
                        unique_probe[..., crop:-crop, crop:-crop]
                    ) ** 2,
                    axis=(-4, -3, -2, -1),
                )
                # Index 1 of the last axis takes the grad_y terms.
                position_update_numerator[lo:hi, ..., 1] = cp.sum(
                    cp.real(
                        cp.conj(
                            grad_y[..., crop:-crop, crop:-crop] *
                            unique_probe[..., crop:-crop, crop:-crop]
                        ) *
                        diff[..., crop:-crop, crop:-crop]
                    ),
                    axis=(-4, -3, -2, -1),
                )
                position_update_denominator[lo:hi, ..., 1] = cp.sum(
                    cp.abs(
                        grad_y[..., crop:-crop, crop:-crop] *
                        unique_probe[..., crop:-crop, crop:-crop]
                    ) ** 2,
                    axis=(-4, -3, -2, -1),
                )

    # Presumably runs `keep_some_args_constant` over chunks of [lo, hi) on
    # the given streams — TODO confirm against stream_and_modify2.
    # NOTE(review): lo/hi come from the first and last index of the batch, so
    # this assumes the batch indices are contiguous and ascending.
    tike.communicators.stream.stream_and_modify2(
        f=keep_some_args_constant,
        ind_args=[
            data,
        ],
        streams=streams,
        lo=batches[n][0],
        hi=batches[n][-1] + 1,
    )

    return (
        cost,
        psi_update_numerator,
        probe_update_numerator,
        position_update_numerator,
        position_update_denominator,
        eigen_weights,
    )


def _update_position(
    scan: npt.NDArray,
    position_options: PositionOptions,
    position_update_numerator: npt.NDArray,
    position_update_denominator: npt.NDArray,
    *,
    alpha=0.05,
    max_shift=1,
    epoch=0,
):
    """Apply the accumulated position update to `scan` in place.

    Parameters
    ----------
    scan
        Scan positions; modified in place and also returned.
    position_options
        Supplies ``update_start``, ``update_magnitude_limit`` and the
        adaptive-moment settings; its ``v`` and ``m`` are overwritten when
        ``use_adaptive_moment`` is set.
    position_update_numerator, position_update_denominator
        Accumulated numerator and denominator of the position step.
    alpha
        Regularisation weight blending the denominator with its maximum.
    max_shift
        Accepted for the caller's sake but not used in this function.
    epoch
        Current epoch; no update is made while it is below
        ``position_options.update_start``.

    Returns
    -------
    scan, position_options
        The updated positions and options.
    """
    if epoch < position_options.update_start:
        return scan, position_options

    # Same blended-denominator form as the object/probe update; the 1e-6
    # floor keeps the denominator positive when every entry is zero.
    step = (position_update_numerator) / (
        (1 - alpha) * position_update_denominator +
        alpha * max(position_update_denominator.max(), 1e-6))

    if position_options.update_magnitude_limit > 0:
        step = cp.clip(
            step,
            a_min=-position_options.update_magnitude_limit,
            a_max=position_options.update_magnitude_limit,
        )

    # Remove outliers and subtract the mean
    step = step - cupyx.scipy.stats.trim_mean(step, 0.05)

    if position_options.use_adaptive_moment:
        (
            step,
            position_options.v,
            position_options.m,
        ) = tike.opt.adam(
            step,
            position_options.v,
            position_options.m,
            vdecay=position_options.vdecay,
            mdecay=position_options.mdecay,
        )

    # In-place subtraction: the caller's array is modified.
    scan -= step

    return scan, position_options