import logging
import typing
import cupy as cp
import numpy.typing as npt
import tike.communicators
import tike.linalg
import tike.operators
import tike.opt
import tike.precision
import tike.ptycho.position
import tike.ptycho.probe
import tike.random
from .options import *
logger = logging.getLogger(__name__)
[docs]def dm(
op: tike.operators.Ptycho,
comm: tike.communicators.Comm,
data: typing.List[npt.NDArray],
batches: typing.List[typing.List[npt.NDArray[cp.intc]]],
*,
parameters: PtychoParameters,
epoch: int,
) -> PtychoParameters:
"""Solve the ptychography problem using the difference map approach.
Parameters
----------
op : :py:class:`tike.operators.Ptycho`
A ptychography operator.
comm : :py:class:`tike.communicators.Comm`
An object which manages communications between GPUs and nodes.
data : list((FRAME, WIDE, HIGH) float32, ...)
A list of unique CuPy arrays for each device containing
the intensity (square of the absolute value) of the propagated
wavefront; i.e. what the detector records. FFT-shifted so the
diffraction peak is at the corners.
batches : list(list((BATCH_SIZE, ) int, ...), ...)
A list of list of indices along the FRAME axis of `data` for
each device which define the batches of `data` to process
simultaneously.
parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
An object which contains reconstruction parameters.
Returns
-------
result : :py:class:`tike.ptycho.solvers.PtychoParameters`
An object which contains reconstruction parameters.
References
----------
Thibault, Pierre, Martin Dierolf, Oliver Bunk, Andreas Menzel, and Franz
Pfeiffer. "Probe retrieval in ptychographic coherent diffractive imaging."
Ultramicroscopy 109, no. 4 (2009): 338-343.
.. seealso:: :py:mod:`tike.ptycho`
"""
psi_update_numerator = [None] * comm.pool.num_workers
probe_update_numerator = [None] * comm.pool.num_workers
(
cost,
psi_update_numerator,
probe_update_numerator,
) = (list(a) for a in zip(*comm.pool.map(
_get_nearplane_gradients,
data,
parameters.scan,
parameters.psi,
parameters.probe,
parameters.exitwave_options.measured_pixels,
psi_update_numerator,
probe_update_numerator,
comm.streams,
op=op,
object_options=parameters.object_options,
probe_options=parameters.probe_options,
exitwave_options=parameters.exitwave_options,
)))
cost = comm.Allreduce_mean(cost).get()
(
parameters.psi,
parameters.probe,
) = _apply_update(
comm,
psi_update_numerator,
probe_update_numerator,
parameters.psi,
parameters.probe,
parameters.object_options,
parameters.probe_options,
)
parameters.algorithm_options.costs.append(cost)
return parameters
def _apply_update(
comm,
psi_update_numerator,
probe_update_numerator,
psi,
probe,
object_options,
probe_options,
):
if object_options:
psi_update_numerator = comm.Allreduce_reduce_gpu(
psi_update_numerator)[0]
new_psi = psi_update_numerator / (object_options.preconditioner[0] +
1e-9)
if object_options.use_adaptive_moment:
(
dpsi,
object_options.v,
object_options.m,
) = tike.opt.adam(
g=(new_psi - psi[0]),
v=object_options.v,
m=object_options.m,
vdecay=object_options.vdecay,
mdecay=object_options.mdecay,
)
new_psi = dpsi + psi[0]
psi = comm.pool.bcast([new_psi])
if probe_options:
probe_update_numerator = comm.Allreduce_reduce_gpu(
probe_update_numerator)[0]
new_probe = probe_update_numerator / (probe_options.preconditioner[0] +
1e-9)
if probe_options.use_adaptive_moment:
(
dprobe,
probe_options.v,
probe_options.m,
) = tike.opt.adam(
g=(new_probe - probe[0]),
v=probe_options.v,
m=probe_options.m,
vdecay=probe_options.vdecay,
mdecay=probe_options.mdecay,
)
new_probe = dprobe + probe[0]
probe = comm.pool.bcast([new_probe])
return psi, probe
def _get_nearplane_gradients(
data: npt.NDArray,
scan: npt.NDArray,
psi: npt.NDArray,
probe: npt.NDArray,
measured_pixels: npt.NDArray,
psi_update_numerator: typing.Union[None, npt.NDArray],
probe_update_numerator: typing.Union[None, npt.NDArray],
streams: typing.List[cp.cuda.Stream],
*,
op: tike.operators.Ptycho,
object_options: typing.Union[None, ObjectOptions] = None,
probe_options: typing.Union[None, ProbeOptions] = None,
exitwave_options: ExitWaveOptions,
) -> typing.List[npt.NDArray]:
cost = cp.zeros(1, dtype=tike.precision.floating)
count = cp.ones(1, dtype=tike.precision.floating) / len(data)
probe_update_numerator = cp.zeros_like(
probe) if probe_update_numerator is None else probe_update_numerator
psi_update_numerator = cp.zeros_like(
psi) if psi_update_numerator is None else psi_update_numerator
def keep_some_args_constant(
ind_args,
lo,
hi,
):
(data,) = ind_args
nonlocal cost, psi_update_numerator, probe_update_numerator
varying_probe = probe
farplane = op.fwd(probe=varying_probe, scan=scan[lo:hi], psi=psi)
intensity = cp.sum(
cp.square(cp.abs(farplane)),
axis=list(range(1, farplane.ndim - 2)),
)
each_cost = getattr(
tike.operators,
f'{exitwave_options.noise_model}_each_pattern',
)(
data[:, measured_pixels][:, None, :],
intensity[:, measured_pixels][:, None, :],
)
cost += cp.sum(each_cost) * count
farplane[..., measured_pixels] *= ((
cp.sqrt(data) / (cp.sqrt(intensity) + 1e-9))[..., None, None,
measured_pixels])
farplane[..., ~measured_pixels] = 0
pad, end = op.diffraction.pad, op.diffraction.end
nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end,
pad:end]
patches = op.diffraction.patch.fwd(
patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
images=psi,
positions=scan[lo:hi],
)[..., None, None, :, :]
if object_options:
grad_psi = (cp.conj(varying_probe) * nearplane).reshape(
(hi - lo) * probe.shape[-3], *probe.shape[-2:])
psi_update_numerator = op.diffraction.patch.adj(
patches=grad_psi,
images=psi_update_numerator,
positions=scan[lo:hi],
nrepeat=probe.shape[-3],
)
if probe_options:
probe_update_numerator += cp.sum(
cp.conj(patches) * nearplane,
axis=-5,
keepdims=True,
)
tike.communicators.stream.stream_and_modify2(
f=keep_some_args_constant,
ind_args=[
data,
],
streams=streams,
lo=0,
hi=len(data),
)
return [
cost,
psi_update_numerator,
probe_update_numerator,
]