# Source code for tike.ptycho.solvers.rpie

import logging
import typing

import cupy as cp
import cupyx.scipy.stats
import numpy as np
import numpy.typing as npt

import tike.communicators
import tike.linalg
import tike.operators
import tike.opt
import tike.ptycho.object
import tike.ptycho.position
import tike.ptycho.probe
import tike.ptycho.exitwave
import tike.precision
import tike.random

from .options import *
from .lstsq import _momentum_checked

logger = logging.getLogger(__name__)


def rpie(
    parameters: PtychoParameters,
    data: npt.NDArray,
    batches: typing.List[npt.NDArray[cp.intc]],
    streams: typing.List[cp.cuda.Stream],
    worker_index: int,
    *,
    op: tike.operators.Ptycho,
    epoch: int,
) -> PtychoParameters:
    """Solve the ptychography problem using regularized ptychographical engine.

    The rPIE update direction can be shown to be equivalent to a conventional
    gradient descent direction but rescaled by the preconditioner term. i.e. If
    the rPIE step size (alpha) is 0 and the preconditioner is zero, we have the
    vanilla gradient descent direction.

    When ``algorithm_options.batch_method`` is ``'compact'`` the batches are
    visited in index order, the update numerators are accumulated over all
    batches, and the object/probe are updated once after the last batch.
    Otherwise the batches are visited in the order returned by
    :py:func:`tike.random.randomizer_np.permutation` and the object/probe are
    updated after every batch.

    Parameters
    ----------
    parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains reconstruction parameters. It is updated in
        place and also returned.
    data : (FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records. FFT-shifted so the
        diffraction peak is at the corners. Presumably the portion of the
        measurements owned by this worker -- TODO confirm against the caller.
    batches : list((BATCH_SIZE, ) int, ...)
        A list of indices along the FRAME axis of `data` which define the
        batches of `data` to process simultaneously. Only the first and last
        index of each batch are used to slice `data`, so each batch is
        expected to be a contiguous, increasing range.
    streams : list(cupy.cuda.Stream)
        The CUDA streams handed to
        :py:func:`tike.communicators.stream.stream_and_modify2` when
        processing each batch.
    worker_index : int
        The index used to select this worker's entry from each of the last
        three records in ``algorithm_options.costs`` when running the
        ``'compact'`` batch method.
    op : :py:class:`tike.operators.Ptycho`
        A ptychography operator.
    epoch : int
        The current epoch number. The probe is only updated once
        ``epoch >= probe_options.update_start``.

    Returns
    -------
    result : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains the updated reconstruction parameters.

    References
    ----------
    Maiden, Andrew M., and John M. Rodenburg. 2009. “An Improved
    Ptychographical Phase Retrieval Algorithm for Diffractive Imaging.”
    Ultramicroscopy 109 (10): 1256–62.
    https://doi.org/10.1016/j.ultramic.2009.05.012.

    Andrew Maiden, Daniel Johnson, and Peng Li, "Further improvements to the
    ptychographical iterative engine," Optica 4, 736-745 (2017)
    https://doi.org/10.1364/OPTICA.4.000736

    .. seealso:: :py:mod:`tike.ptycho`

    """
    scan = parameters.scan
    psi = parameters.psi
    probe = parameters.probe
    algorithm_options = parameters.algorithm_options
    eigen_weights = parameters.eigen_weights
    eigen_probe = parameters.eigen_probe
    measured_pixels = parameters.exitwave_options.measured_pixels
    exitwave_options = parameters.exitwave_options
    position_options = parameters.position_options
    object_options = parameters.object_options
    probe_options = parameters.probe_options
    recover_probe = (probe_options is not None
                     and epoch >= probe_options.update_start)

    # CONVERSION AREA ABOVE ---------------------------------------

    if parameters.algorithm_options.batch_method == 'compact':
        order = range
    else:
        order = tike.random.randomizer_np.permutation

    # Accumulators are created lazily by _get_nearplane_gradients when None.
    psi_update_numerator = None
    probe_update_numerator = None
    position_update_numerator = None
    position_update_denominator = None

    batch_cost = cp.empty(
        algorithm_options.num_batch,
        dtype=tike.precision.floating,
    )
    for n in order(algorithm_options.num_batch):

        (
            costs,
            psi_update_numerator,
            probe_update_numerator,
            position_update_numerator,
            position_update_denominator,
            eigen_weights,
        ) = _get_nearplane_gradients(
            data,
            scan,
            psi,
            probe,
            measured_pixels,
            psi_update_numerator,
            probe_update_numerator,
            position_update_numerator,
            position_update_denominator,
            eigen_probe,
            eigen_weights,
            batches,
            streams,
            n=n,
            op=op,
            object_options=object_options,
            probe_options=probe_options,
            recover_probe=recover_probe,
            position_options=position_options,
            exitwave_options=exitwave_options,
        )

        batch_cost[n] = cp.mean(costs)

        if algorithm_options.batch_method != 'compact':
            # Update after every batch, then drop the accumulators so the
            # next batch starts from zero.
            (
                psi,
                probe,
            ) = _update(
                psi,
                probe,
                psi_update_numerator,
                probe_update_numerator,
                object_options,
                probe_options,
                recover_probe,
                algorithm_options,
            )
            psi_update_numerator = None
            probe_update_numerator = None

    # Record the mean cost over all batches as a one-element list.
    algorithm_options.costs.append([float(batch_cost.mean().get())])

    # if position_options is not None:
    #     (
    #         scan,
    #         position_options,
    #     ) = _update_position(
    #         scan,
    #         position_options,
    #         position_update_numerator,
    #         position_update_denominator,
    #         max_shift=probe[0].shape[-1] * 0.1,
    #         alpha=algorithm_options.alpha,
    #         epoch=epoch,
    #     )

    if algorithm_options.batch_method == 'compact':
        # NOTE(review): the entry appended just above has a single element,
        # so x[worker_index] only succeeds for it when worker_index == 0
        # unless the caller re-shapes `costs` elsewhere -- confirm.
        (
            psi,
            probe,
        ) = _update(
            psi,
            probe,
            psi_update_numerator,
            probe_update_numerator,
            object_options,
            probe_options,
            recover_probe,
            algorithm_options,
            errors=[
                float(x[worker_index]) for x in algorithm_options.costs[-3:]
            ],
        )

    if eigen_weights is not None:
        eigen_weights = _normalize_eigen_weights(
            eigen_weights,
        )

    # CONVERSION AREA BELOW ----------------------

    parameters.scan = scan
    parameters.psi = psi
    parameters.probe = probe
    parameters.algorithm_options = algorithm_options
    parameters.eigen_weights = eigen_weights
    parameters.eigen_probe = eigen_probe
    parameters.exitwave_options = exitwave_options
    parameters.position_options = position_options
    parameters.object_options = object_options
    parameters.probe_options = probe_options
    return parameters
def _normalize_eigen_weights(eigen_weights):
    """Return `eigen_weights` divided by its `tike.linalg.mnorm` over axis -3.

    The norm is computed with ``keepdims=True`` so the division broadcasts
    back over the normalized axis.
    """
    return eigen_weights / tike.linalg.mnorm(
        eigen_weights,
        axis=(-3),
        keepdims=True,
    )


def _update(
    psi: npt.NDArray[cp.csingle],
    probe: npt.NDArray[cp.csingle],
    psi_update_numerator: npt.NDArray[cp.csingle],
    probe_update_numerator: npt.NDArray[cp.csingle],
    object_options: ObjectOptions,
    probe_options: ProbeOptions,
    recover_probe: bool,
    algorithm_options: RpieOptions,
    errors: typing.Union[None, typing.Sequence[float], npt.NDArray] = None,
) -> typing.Tuple[npt.NDArray[cp.csingle], npt.NDArray[cp.csingle]]:
    """Apply the accumulated rPIE numerators to the object and the probe.

    Parameters
    ----------
    psi, probe
        The current object and probe estimates. New arrays are returned; the
        inputs are not modified in place.
    psi_update_numerator, probe_update_numerator
        The accumulated update numerators. Only index 0 along the first axis
        of `probe_update_numerator` is used for the probe update. The
        momentum step writes into `probe_update_numerator` in place (main
        probe mode only).
    object_options, probe_options
        Options holding the preconditioners and the adaptive-moment state
        (``v``, ``m``), which is updated in place when
        ``use_adaptive_moment`` is set.
    recover_probe
        Whether the probe is updated at all.
    algorithm_options
        Provides the rPIE regularization ``alpha``.
    errors
        Recent cost values. When provided and non-empty, momentum is computed
        with `_momentum_checked` (``memory_length=3``); otherwise
        `tike.opt.adam` is used.

    Returns
    -------
    psi, probe
        The updated object and probe.
    """
    # Explicit emptiness test: `if errors:` raises for multi-element arrays.
    have_errors = errors is not None and len(errors) > 0

    if object_options:
        dpsi = psi_update_numerator
        deno = (
            (1 - algorithm_options.alpha) * object_options.preconditioner
            + algorithm_options.alpha * object_options.preconditioner.max(
                axis=(-2, -1),
                keepdims=True,
            )
        )
        psi = psi + dpsi / deno
        if object_options.use_adaptive_moment:
            if have_errors:
                (
                    dpsi,
                    object_options.v,
                    object_options.m,
                ) = _momentum_checked(
                    g=dpsi,
                    v=object_options.v,
                    m=object_options.m,
                    mdecay=object_options.mdecay,
                    errors=errors,
                    memory_length=3,
                )
            else:
                (
                    dpsi,
                    object_options.v,
                    object_options.m,
                ) = tike.opt.adam(
                    g=dpsi,
                    v=object_options.v,
                    m=object_options.m,
                    vdecay=object_options.vdecay,
                    mdecay=object_options.mdecay,
                )
            # The momentum term is added on top of the plain step above.
            psi = psi + dpsi / deno

    if recover_probe:
        dprobe = probe_update_numerator[0, ...]
        # deno = (
        #     (1 - algorithm_options.alpha) * probe_options.preconditioner[0, ...]
        #     + algorithm_options.alpha * probe_options.preconditioner[0, ...].max(axis=(-2, -1), keepdims=True)
        # )
        deno = algorithm_options.alpha * probe_options.preconditioner[
            0, ...].max(
                axis=(-2, -1),
                keepdims=True,
            )
        probe = probe + dprobe / deno
        if probe_options.use_adaptive_moment:
            # ptychoshelves only applies momentum to the main probe
            mode = 0
            if have_errors:
                (
                    dprobe[0, 0, mode, :, :],
                    probe_options.v,
                    probe_options.m,
                ) = _momentum_checked(
                    g=(dprobe)[0, 0, mode, :, :],
                    v=probe_options.v,
                    m=probe_options.m,
                    mdecay=probe_options.mdecay,
                    errors=errors,
                    memory_length=3,
                )
            else:
                (
                    dprobe[0, 0, mode, :, :],
                    probe_options.v,
                    probe_options.m,
                ) = tike.opt.adam(
                    g=(dprobe)[0, 0, mode, :, :],
                    v=probe_options.v,
                    m=probe_options.m,
                    vdecay=probe_options.vdecay,
                    mdecay=probe_options.mdecay,
                )
            probe = probe + dprobe / deno

    return psi, probe


def _get_nearplane_gradients(
    data: npt.NDArray,
    scan: npt.NDArray,
    psi: npt.NDArray,
    probe: npt.NDArray,
    measured_pixels: npt.NDArray,
    psi_update_numerator: typing.Union[None, npt.NDArray],
    probe_update_numerator: typing.Union[None, npt.NDArray],
    position_update_numerator: typing.Union[None, npt.NDArray],
    position_update_denominator: typing.Union[None, npt.NDArray],
    eigen_probe: typing.Union[None, npt.NDArray],
    eigen_weights: typing.Union[None, npt.NDArray],
    batches: typing.List[npt.NDArray[np.intc]],
    streams: typing.List[cp.cuda.Stream],
    *,
    n: int,
    op: tike.operators.Ptycho,
    object_options: typing.Union[None, ObjectOptions] = None,
    probe_options: typing.Union[None, ProbeOptions] = None,
    recover_probe: bool,
    position_options: typing.Union[None, PositionOptions],
    exitwave_options: ExitWaveOptions,
) -> typing.Tuple[
        npt.NDArray,
        npt.NDArray,
        npt.NDArray,
        npt.NDArray,
        npt.NDArray,
        typing.Union[npt.NDArray, None],
]:
    """Accumulate the rPIE update numerators for batch `n`.

    The frames ``batches[n][0]`` through ``batches[n][-1]`` (inclusive) are
    processed in chunks by
    :py:func:`tike.communicators.stream.stream_and_modify2`. For every chunk
    the far-plane wavefront is computed, replaced by the (negative) exit-wave
    gradient for the configured noise model, propagated back, and folded into
    the object and probe numerators slice by slice, from the last entry along
    the first axis of `psi` to the first.

    Any accumulator passed as None is allocated here; an accumulator that is
    passed in is added to, so repeated calls accumulate across batches. The
    position accumulators are allocated with ``cp.empty_like`` and are not
    written because the position update is currently disabled.

    Returns
    -------
    bcosts
        One cost value per frame of the batch, from the
        ``{noise_model}_each_pattern`` function of `tike.operators`.
    psi_update_numerator, probe_update_numerator
        The object and probe accumulators. The probe accumulator has shape
        ``(psi.shape[0], *probe.shape)`` and is only written when
        `object_options` is truthy.
    position_update_numerator, position_update_denominator
        Placeholders for the disabled position update.
    eigen_weights
        The (possibly updated in place) eigen weights, or None.
    """
    batch_start = batches[n][0]
    batch_size = len(batches[n])

    bcosts = cp.empty(shape=batch_size, dtype=tike.precision.floating)

    if psi_update_numerator is None:
        psi_update_numerator = cp.zeros_like(psi)

    # Allocate only when no accumulator was passed in. Re-zeroing on every
    # call would discard the contributions of earlier batches when the caller
    # accumulates over several batches before updating.
    if probe_update_numerator is None:
        probe_update_numerator = cp.zeros(
            (psi.shape[0], *probe.shape),
            dtype=probe.dtype,
        )

    if position_update_numerator is None:
        position_update_numerator = cp.empty_like(scan)

    if position_update_denominator is None:
        position_update_denominator = cp.empty_like(scan)

    def keep_some_args_constant(
        ind_args,
        lo: int,
        hi: int,
    ):
        """Process frames ``lo:hi`` and add them to the accumulators."""
        (data,) = ind_args
        nonlocal bcosts, psi_update_numerator, probe_update_numerator
        nonlocal position_update_numerator, position_update_denominator
        nonlocal eigen_weights, scan

        # Offsets of this chunk inside the per-batch cost array.
        blo = lo - batch_start
        bhi = hi - batch_start

        unique_probe = tike.ptycho.probe.get_varying_probe(
            probe,
            eigen_probe,
            eigen_weights[lo:hi] if eigen_weights is not None else None,
        )

        farplane, multislice_probes = op.fwd_return_intermediate_probes(
            probe=unique_probe,
            scan=scan[lo:hi],
            psi=psi,
        )
        intensity = cp.sum(
            cp.square(cp.abs(farplane)),
            axis=list(range(1, farplane.ndim - 2)),
        )
        bcosts[blo:bhi] = getattr(
            tike.operators,
            f'{exitwave_options.noise_model}_each_pattern',
        )(
            data[:, measured_pixels][:, None, :],
            intensity[:, measured_pixels][:, None, :],
        )

        if exitwave_options.noise_model == 'poisson':

            xi = (1 - data / intensity)[:, None, None, :, :]
            grad_cost = farplane * xi

            step_length = cp.full(
                shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
                fill_value=tike.precision.floating(
                    exitwave_options.step_length_start),
                dtype=tike.precision.floating,
            )

            if exitwave_options.step_length_usemodes == 'dominant_mode':
                step_length = tike.ptycho.exitwave.poisson_steplength_dominant_mode(
                    xi,
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )
            else:
                step_length = tike.ptycho.exitwave.poisson_steplength_all_modes(
                    xi,
                    cp.square(cp.abs(farplane)),
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )

            farplane[..., measured_pixels] = (
                -step_length * grad_cost)[..., measured_pixels]

        else:

            # Gaussian noise model for exitwave updates, steplength = 1
            # TODO: optimal step lengths using 2nd order taylor expansion

            farplane[..., measured_pixels] = -getattr(
                tike.operators,
                f'{exitwave_options.noise_model}_grad',
            )(
                data,
                farplane,
                intensity,
            )[..., measured_pixels]

        unmeasured_pixels = cp.logical_not(measured_pixels)
        farplane[..., unmeasured_pixels] *= (
            exitwave_options.unmeasured_pixels_scaling - 1.0)

        pad, end = op.diffraction.pad, op.diffraction.end
        diff = op.propagation.adj(farplane, overwrite=True)[..., pad:end,
                                                            pad:end]

        if object_options:
            # Walk the slices from last to first, back-propagating `diff`
            # between consecutive slices.
            # for tt in cp.arange(psi.shape[0] - 1, -1, -1):
            for tt in range(len(psi) - 1, -1, -1):

                grad_psi = (
                    cp.conj(multislice_probes[tt, :, None, ...]) * diff
                    / probe.shape[-3]
                ).reshape(
                    scan[lo:hi].shape[0] * probe.shape[-3],
                    *probe.shape[-2:],
                )
                psi_update_numerator[tt, ...] = op.diffraction.patch.adj(
                    patches=grad_psi,
                    images=psi_update_numerator[tt, ...],
                    positions=scan[lo:hi],
                    nrepeat=probe.shape[-3],
                )

                patches = op.diffraction.patch.fwd(
                    patches=cp.zeros_like(diff[..., 0, 0, :, :]),
                    images=psi[tt, ...],
                    positions=scan[lo:hi],
                )[..., None, None, :, :]

                probe_update_numerator[tt, ...] += cp.sum(
                    cp.conj(patches) * diff,
                    axis=-5,
                    keepdims=True,
                )

                if tt == 0:
                    # Nothing lies before the first slice.
                    break

                diff = op.diffraction.propagation.adj(diff)

        if position_options or probe_options:

            patches = op.diffraction.patch.fwd(
                patches=cp.zeros_like(diff[..., 0, 0, :, :]),
                images=psi[0],
                positions=scan[lo:hi],
            )[..., None, None, :, :]

            if recover_probe:

                # probe_update_numerator += cp.sum(
                #     cp.conj(patches) * diff,
                #     axis=-5,
                #     keepdims=True,
                # )

                if eigen_weights is not None:
                    m: int = 0
                    OP = patches * probe[..., m:m + 1, :, :]
                    eigen_numerator = cp.sum(
                        cp.real(cp.conj(OP) * diff[..., m:m + 1, :, :]),
                        axis=(-1, -2),
                    )
                    eigen_denominator = cp.sum(
                        cp.abs(OP)**2,
                        axis=(-1, -2),
                    )
                    eigen_weights[lo:hi, ..., 0:1, m:m+1] += (
                        0.1 * (eigen_numerator / eigen_denominator)
                    )  # yapf: disable

            # if position_options:
            #     grad_x, grad_y = tike.ptycho.position.gaussian_gradient(patches)
            #     crop = probe.shape[-1] // 4
            #     position_update_numerator[lo:hi, ..., 0] = cp.sum(
            #         cp.real(
            #             cp.conj(
            #                 grad_x[..., crop:-crop, crop:-crop]
            #                 * unique_probe[..., crop:-crop, crop:-crop]
            #             )
            #             * diff[..., crop:-crop, crop:-crop]
            #         ),
            #         axis=(-4, -3, -2, -1),
            #     )
            #     position_update_denominator[lo:hi, ..., 0] = cp.sum(
            #         cp.abs(
            #             grad_x[..., crop:-crop, crop:-crop]
            #             * unique_probe[..., crop:-crop, crop:-crop]
            #         )
            #         ** 2,
            #         axis=(-4, -3, -2, -1),
            #     )
            #     position_update_numerator[lo:hi, ..., 1] = cp.sum(
            #         cp.real(
            #             cp.conj(
            #                 grad_y[..., crop:-crop, crop:-crop]
            #                 * unique_probe[..., crop:-crop, crop:-crop]
            #             )
            #             * diff[..., crop:-crop, crop:-crop]
            #         ),
            #         axis=(-4, -3, -2, -1),
            #     )
            #     position_update_denominator[lo:hi, ..., 1] = cp.sum(
            #         cp.abs(
            #             grad_y[..., crop:-crop, crop:-crop]
            #             * unique_probe[..., crop:-crop, crop:-crop]
            #         )
            #         ** 2,
            #         axis=(-4, -3, -2, -1),
            #     )

    tike.communicators.stream.stream_and_modify2(
        f=keep_some_args_constant,
        ind_args=[
            data,
        ],
        streams=streams,
        lo=batches[n][0],
        hi=batches[n][-1] + 1,
    )

    return (
        bcosts,
        psi_update_numerator,
        probe_update_numerator,
        position_update_numerator,
        position_update_denominator,
        eigen_weights,
    )


# def _update_position(
#     scan: npt.NDArray,
#     position_options: PositionOptions,
#     position_update_numerator: npt.NDArray,
#     position_update_denominator: npt.NDArray,
#     *,
#     alpha: float = 0.05,
#     max_shift: float = 1.0,
#     epoch: int = 0,
# ) -> typing.Tuple[cp.ndarray, PositionOptions]:

#     if epoch < position_options.update_start:
#         return scan, position_options

#     step = (position_update_numerator) / (
#         (1 - alpha) * position_update_denominator +
#         alpha * max(position_update_denominator.max(), 1e-6))

#     if position_options.update_magnitude_limit > 0:
#         step = cp.clip(
#             step,
#             a_min=-position_options.update_magnitude_limit,
#             a_max=position_options.update_magnitude_limit,
#         )

#     # Remove outliers and subtract the mean
#     step = step - cupyx.scipy.stats.trim_mean(step, 0.05)

#     if position_options.use_adaptive_moment:
#         (
#             step,
#             position_options.v,
#             position_options.m,
#         ) = tike.opt.adam(
#             step,
#             position_options.v,
#             position_options.m,
#             vdecay=position_options.vdecay,
#             mdecay=position_options.mdecay,
#         )

#     scan -= step

#     return scan, position_options