"""Difference map solver for ptychography (``tike.ptycho.solvers.dm``)."""

import logging
import typing

import cupy as cp
import numpy.typing as npt

import tike.linalg
import tike.opt
import tike.ptycho.position
import tike.ptycho.probe
import tike.random

from .options import *

logger = logging.getLogger(__name__)


def dm(
    op: tike.operators.Ptycho,
    comm: tike.communicators.Comm,
    data: typing.List[npt.NDArray],
    batches: typing.List[typing.List[npt.NDArray[cp.intc]]],
    *,
    parameters: PtychoParameters,
) -> PtychoParameters:
    """Run one epoch of ptychographic reconstruction by the difference map.

    Batches are visited in a random order. For each batch index, every
    device accumulates object and probe update numerators from its own
    batch. The accumulated numerators are applied once, after all batches
    have been visited.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        The ptychography operator.
    comm : :py:class:`tike.communicators.Comm`
        Manages communication between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        One CuPy array per device holding the measured intensity (the
        square of the absolute value of the propagated wavefront), i.e.
        what the detector records. FFT-shifted so that the diffraction
        peak sits at the corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        For each device, a list of index arrays along the FRAME axis of
        `data`. Each index array is one batch that is processed together.
    parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
        Reconstruction parameters. This object is updated in place and
        returned.

    Returns
    -------
    result : :py:class:`tike.ptycho.solvers.PtychoParameters`
        The same parameters object. ``psi`` and/or ``probe`` are updated
        according to the object/probe options. This epoch's list of
        per-batch costs is appended to ``algorithm_options.costs``.

    References
    ----------
    Thibault, Pierre, Martin Dierolf, Oliver Bunk, Andreas Menzel, and
    Franz Pfeiffer. "Probe retrieval in ptychographic coherent diffractive
    imaging." Ultramicroscopy 109, no. 4 (2009): 338-343.

    .. seealso:: :py:mod:`tike.ptycho`
    """
    num_workers = comm.pool.num_workers
    # One accumulator per device. None tells the worker to allocate a zero
    # accumulator the first time it runs.
    psi_numerators = [None] * num_workers
    probe_numerators = [None] * num_workers

    # Objective function value of each batch, in visiting order.
    epoch_costs: typing.List[float] = []

    batch_order = tike.random.randomizer_np.permutation(len(batches[0]))
    for batch_index in batch_order:
        worker_results = comm.pool.map(
            _get_nearplane_gradients,
            data,
            parameters.scan,
            parameters.psi,
            parameters.probe,
            parameters.exitwave_options.measured_pixels,
            psi_numerators,
            probe_numerators,
            batches,
            comm.streams,
            n=batch_index,
            op=op,
            object_options=parameters.object_options,
            probe_options=parameters.probe_options,
            exitwave_options=parameters.exitwave_options,
        )
        # Transpose the per-worker [cost, psi, probe] triples into three
        # per-worker lists.
        worker_costs, psi_numerators, probe_numerators = map(
            list,
            zip(*worker_results),
        )
        epoch_costs.append(comm.Allreduce_mean(worker_costs, axis=None).get())

    # The difference map applies a single update per epoch, from the
    # numerators accumulated over every batch.
    parameters.psi, parameters.probe = _apply_update(
        comm,
        psi_numerators,
        probe_numerators,
        parameters.psi,
        parameters.probe,
        parameters.object_options,
        parameters.probe_options,
    )

    parameters.algorithm_options.costs.append(epoch_costs)
    return parameters
def _precondition_and_step(comm, update_numerator, current, options):
    """Turn per-device update numerators into a new, broadcast estimate.

    Steps:

    - Combine the numerators across devices with
      ``comm.Allreduce_reduce_gpu`` and keep the first entry.
    - Divide by the first entry of ``options.preconditioner``; ``1e-9``
      guards against division by zero.
    - If ``options.use_adaptive_moment`` is set, route the step from
      ``current[0]`` through :py:func:`tike.opt.adam`. The ADAM moment
      buffers are written back to ``options.v`` and ``options.m``.
    - Broadcast the result to every device with ``comm.pool.bcast``.
    """
    combined = comm.Allreduce_reduce_gpu(update_numerator)[0]
    estimate = combined / (options.preconditioner[0] + 1e-9)
    if options.use_adaptive_moment:
        (
            step,
            options.v,
            options.m,
        ) = tike.opt.adam(
            g=(estimate - current[0]),
            v=options.v,
            m=options.m,
            vdecay=options.vdecay,
            mdecay=options.mdecay,
        )
        estimate = step + current[0]
    return comm.pool.bcast([estimate])


def _apply_update(
    comm,
    psi_update_numerator,
    probe_update_numerator,
    psi,
    probe,
    object_options,
    probe_options,
):
    """Apply the accumulated difference-map numerators to object and probe.

    The object is updated only when `object_options` is truthy, and the
    probe only when `probe_options` is truthy. Otherwise the input is
    returned unchanged.

    Returns
    -------
    psi, probe
        The (possibly updated) object and probe. An updated value is the
        per-device list produced by ``comm.pool.bcast``.
    """
    if object_options:
        psi = _precondition_and_step(
            comm,
            psi_update_numerator,
            psi,
            object_options,
        )
    if probe_options:
        probe = _precondition_and_step(
            comm,
            probe_update_numerator,
            probe,
            probe_options,
        )
    return psi, probe


def _get_nearplane_gradients(
    data: npt.NDArray,
    scan: npt.NDArray,
    psi: npt.NDArray,
    probe: npt.NDArray,
    measured_pixels: npt.NDArray,
    psi_update_numerator: typing.Union[None, npt.NDArray],
    probe_update_numerator: typing.Union[None, npt.NDArray],
    batches: typing.List[typing.List[int]],
    streams: typing.List[cp.cuda.Stream],
    *,
    n: int,
    op: tike.operators.Ptycho,
    object_options: typing.Union[None, ObjectOptions] = None,
    probe_options: typing.Union[None, ProbeOptions] = None,
    exitwave_options: ExitWaveOptions,
) -> typing.List[npt.NDArray]:
    """Accumulate DM update numerators over batch `n` on one device.

    The frames in ``batches[n]`` are processed with
    ``tike.communicators.stream.stream_and_modify``. For each chunk:

    - Propagate the current `probe` and `psi` to the far field.
    - Evaluate the ``{noise_model}_each_pattern`` cost on the measured
      pixels.
    - Rescale the far field on the measured pixels by
      ``sqrt(data) / (sqrt(intensity) + 1e-9)`` and zero the unmeasured
      pixels.
    - Back-propagate to the near plane.
    - Add the object and/or probe contributions into the accumulators.

    Parameters
    ----------
    psi_update_numerator, probe_update_numerator : array or None
        Running accumulators from earlier batches. None allocates zeros
        shaped like `psi` / `probe`.

    Returns
    -------
    list
        ``[cost averaged over the frames of batches[n],
        psi_update_numerator, probe_update_numerator]``.
    """

    def _accumulate_chunk(ind_args, mod_args, _):
        (chunk_data, chunk_scan) = ind_args
        (cost, psi_numerator, probe_numerator) = mod_args

        farplane = op.fwd(probe=probe, scan=chunk_scan, psi=psi)
        # Sum the squared magnitude over every axis between the first
        # (frame) axis and the last two (detector) axes.
        intensity = cp.sum(
            cp.square(cp.abs(farplane)),
            axis=list(range(1, farplane.ndim - 2)),
        )
        each_cost = getattr(
            tike.operators,
            f'{exitwave_options.noise_model}_each_pattern',
        )(
            chunk_data[:, measured_pixels][:, None, :],
            intensity[:, measured_pixels][:, None, :],
        )
        cost += cp.sum(each_cost)

        # Amplitude constraint: impose the measured amplitude on measured
        # pixels and discard unmeasured ones.
        farplane[..., measured_pixels] *= ((
            cp.sqrt(chunk_data) /
            (cp.sqrt(intensity) + 1e-9))[..., None, None, measured_pixels])
        farplane[..., ~measured_pixels] = 0

        pad, end = op.diffraction.pad, op.diffraction.end
        nearplane = op.propagation.adj(
            farplane,
            overwrite=True,
        )[..., pad:end, pad:end]

        if object_options:
            grad_psi = (cp.conj(probe) * nearplane).reshape(
                chunk_scan.shape[0] * probe.shape[-3],
                *probe.shape[-2:],
            )
            psi_numerator = op.diffraction.patch.adj(
                patches=grad_psi,
                images=psi_numerator,
                positions=chunk_scan,
                nrepeat=probe.shape[-3],
            )

        if probe_options:
            # Only the probe update needs the object patches, so skip
            # extracting them when the probe is held fixed.
            patches = op.diffraction.patch.fwd(
                patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
                images=psi,
                positions=chunk_scan,
            )[..., None, None, :, :]
            probe_numerator += cp.sum(
                cp.conj(patches) * nearplane,
                axis=-5,
                keepdims=True,
            )

        return [cost, psi_numerator, probe_numerator]

    if psi_update_numerator is None:
        psi_update_numerator = cp.zeros_like(psi)
    if probe_update_numerator is None:
        probe_update_numerator = cp.zeros_like(probe)

    (
        cost,
        psi_update_numerator,
        probe_update_numerator,
    ) = tike.communicators.stream.stream_and_modify(
        f=_accumulate_chunk,
        ind_args=[
            data,
            scan,
        ],
        mod_args=[
            0.0,
            psi_update_numerator,
            probe_update_numerator,
        ],
        streams=streams,
        indices=batches[n],
    )

    return [
        cost / len(batches[n]),
        psi_update_numerator,
        probe_update_numerator,
    ]