import logging
import typing
import cupy as cp
import numpy.typing as npt
import tike.linalg
import tike.opt
import tike.ptycho.position
import tike.ptycho.probe
import tike.random
from .options import *
# Logger named after this module; none of the solver functions in this
# section currently log through it.
logger = logging.getLogger(__name__)
def dm(
    op: tike.operators.Ptycho,
    comm: tike.communicators.Comm,
    data: typing.List[npt.NDArray],
    batches: typing.List[typing.List[npt.NDArray[cp.intc]]],
    *,
    parameters: PtychoParameters,
) -> PtychoParameters:
    """Solve the ptychography problem using the difference map approach.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        A ptychography operator.
    comm : :py:class:`tike.communicators.Comm`
        An object which manages communications between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        A list of unique CuPy arrays for each device containing
        the intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records. FFT-shifted so the
        diffraction peak is at the corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        A list of list of indices along the FRAME axis of `data` for
        each device which define the batches of `data` to process
        simultaneously.
    parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains reconstruction parameters.

    Returns
    -------
    result : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains reconstruction parameters.

    Notes
    -----
    Batches are visited in a random order. The object and probe update
    numerators are accumulated over every batch, and ``parameters.psi`` /
    ``parameters.probe`` are replaced once, after all batches have been
    processed. The per-batch costs are appended, as a single list, to
    ``parameters.algorithm_options.costs``.

    References
    ----------
    Thibault, Pierre, Martin Dierolf, Oliver Bunk, Andreas Menzel, and Franz
    Pfeiffer. "Probe retrieval in ptychographic coherent diffractive imaging."
    Ultramicroscopy 109, no. 4 (2009): 338-343.

    .. seealso:: :py:mod:`tike.ptycho`

    """
    # One accumulator per device; None tells _get_nearplane_gradients to
    # start from a zero-filled array.
    psi_update_numerator = [None] * comm.pool.num_workers
    probe_update_numerator = [None] * comm.pool.num_workers

    # The objective function value for each batch
    batch_cost: typing.List[float] = []
    for n in tike.random.randomizer_np.permutation(len(batches[0])):

        # Every device processes its n-th batch and hands back its partial
        # cost plus its (further accumulated) numerators; zip(*...) turns the
        # per-device result triples into per-quantity lists.
        (
            cost,
            psi_update_numerator,
            probe_update_numerator,
        ) = (list(a) for a in zip(*comm.pool.map(
            _get_nearplane_gradients,
            data,
            parameters.scan,
            parameters.psi,
            parameters.probe,
            parameters.exitwave_options.measured_pixels,
            psi_update_numerator,
            probe_update_numerator,
            batches,
            comm.streams,
            n=n,
            op=op,
            object_options=parameters.object_options,
            probe_options=parameters.probe_options,
            exitwave_options=parameters.exitwave_options,
        )))

        # Combine the per-device costs and bring the result to the host.
        cost = comm.Allreduce_mean(cost, axis=None).get()
        batch_cost.append(cost)

    # Difference map updates object and probe once per call, from the
    # numerators accumulated over all batches above.
    (
        parameters.psi,
        parameters.probe,
    ) = _apply_update(
        comm,
        psi_update_numerator,
        probe_update_numerator,
        parameters.psi,
        parameters.probe,
        parameters.object_options,
        parameters.probe_options,
    )

    parameters.algorithm_options.costs.append(batch_cost)
    return parameters
def _apply_update(
comm,
psi_update_numerator,
probe_update_numerator,
psi,
probe,
object_options,
probe_options,
):
if object_options:
psi_update_numerator = comm.Allreduce_reduce_gpu(
psi_update_numerator)[0]
new_psi = psi_update_numerator / (object_options.preconditioner[0] +
1e-9)
if object_options.use_adaptive_moment:
(
dpsi,
object_options.v,
object_options.m,
) = tike.opt.adam(
g=(new_psi - psi[0]),
v=object_options.v,
m=object_options.m,
vdecay=object_options.vdecay,
mdecay=object_options.mdecay,
)
new_psi = dpsi + psi[0]
psi = comm.pool.bcast([new_psi])
if probe_options:
probe_update_numerator = comm.Allreduce_reduce_gpu(
probe_update_numerator)[0]
new_probe = probe_update_numerator / (probe_options.preconditioner[0] +
1e-9)
if probe_options.use_adaptive_moment:
(
dprobe,
probe_options.v,
probe_options.m,
) = tike.opt.adam(
g=(new_probe - probe[0]),
v=probe_options.v,
m=probe_options.m,
vdecay=probe_options.vdecay,
mdecay=probe_options.mdecay,
)
new_probe = dprobe + probe[0]
probe = comm.pool.bcast([new_probe])
return psi, probe
def _get_nearplane_gradients(
    data: npt.NDArray,
    scan: npt.NDArray,
    psi: npt.NDArray,
    probe: npt.NDArray,
    measured_pixels: npt.NDArray,
    psi_update_numerator: typing.Union[None, npt.NDArray],
    probe_update_numerator: typing.Union[None, npt.NDArray],
    batches: typing.List[typing.List[int]],
    streams: typing.List[cp.cuda.Stream],
    *,
    n: int,
    op: tike.operators.Ptycho,
    object_options: typing.Union[None, ObjectOptions] = None,
    probe_options: typing.Union[None, ProbeOptions] = None,
    exitwave_options: ExitWaveOptions,
) -> typing.List[typing.Any]:
    """Accumulate difference-map update numerators for batch ``n``.

    Called by :func:`dm` through ``comm.pool.map`` with one device's share of
    the arguments. The frames in ``batches[n]`` are processed in chunks via
    ``tike.communicators.stream.stream_and_modify``. For each chunk:

    1. the far-field wave is computed with ``op.fwd`` and its intensity is
       scored against ``data`` with the noise model named by
       ``exitwave_options.noise_model``;
    2. inside ``measured_pixels`` the far-field modulus is replaced by
       ``sqrt(data)``; outside it the field is zeroed;
    3. the field is propagated back and cropped to the near plane;
    4. ``conj(probe) * nearplane`` is added into ``psi_update_numerator``
       (if ``object_options``) and ``conj(patches) * nearplane`` summed over
       axis -5 is added into ``probe_update_numerator`` (if
       ``probe_options``).

    A ``None`` numerator is replaced by zeros shaped like ``psi`` / ``probe``.

    Returns
    -------
    list
        ``[cost, psi_update_numerator, probe_update_numerator]`` where
        ``cost`` is the summed per-pattern cost divided by
        ``len(batches[n])``.
    """

    def keep_some_args_constant(
        ind_args,
        mod_args,
        _,
    ):
        (data, scan) = ind_args
        (cost, psi_update_numerator, probe_update_numerator) = mod_args

        farplane = op.fwd(probe=probe, scan=scan, psi=psi)
        # Sum |farplane|^2 over every axis between the first (frame) axis and
        # the last two (detector) axes.
        intensity = cp.sum(
            cp.square(cp.abs(farplane)),
            axis=list(range(1, farplane.ndim - 2)),
        )
        each_cost = getattr(
            tike.operators,
            f'{exitwave_options.noise_model}_each_pattern',
        )(
            data[:, measured_pixels][:, None, :],
            intensity[:, measured_pixels][:, None, :],
        )
        cost += cp.sum(each_cost)

        # Modulus constraint: rescale measured pixels so the modelled
        # amplitude matches sqrt(data); unmeasured pixels carry no signal.
        farplane[..., measured_pixels] *= ((
            cp.sqrt(data) / (cp.sqrt(intensity) + 1e-9))[..., None, None,
                                                         measured_pixels])
        farplane[..., ~measured_pixels] = 0

        pad, end = op.diffraction.pad, op.diffraction.end
        nearplane = op.propagation.adj(farplane, overwrite=True)[..., pad:end,
                                                                  pad:end]

        if object_options:
            grad_psi = (cp.conj(probe) * nearplane).reshape(
                scan.shape[0] * probe.shape[-3], *probe.shape[-2:])
            psi_update_numerator = op.diffraction.patch.adj(
                patches=grad_psi,
                images=psi_update_numerator,
                positions=scan,
                nrepeat=probe.shape[-3],
            )

        if probe_options:
            # Object patches are only needed for the probe update, so they
            # are not extracted when the probe is held fixed.
            patches = op.diffraction.patch.fwd(
                patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
                images=psi,
                positions=scan,
            )[..., None, None, :, :]
            probe_update_numerator += cp.sum(
                cp.conj(patches) * nearplane,
                axis=-5,
                keepdims=True,
            )

        return [
            cost,
            psi_update_numerator,
            probe_update_numerator,
        ]

    if psi_update_numerator is None:
        psi_update_numerator = cp.zeros_like(psi)
    if probe_update_numerator is None:
        probe_update_numerator = cp.zeros_like(probe)

    (
        cost,
        psi_update_numerator,
        probe_update_numerator,
    ) = tike.communicators.stream.stream_and_modify(
        f=keep_some_args_constant,
        ind_args=[
            data,
            scan,
        ],
        mod_args=[
            0.0,
            psi_update_numerator,
            probe_update_numerator,
        ],
        streams=streams,
        indices=batches[n],
    )

    return [
        cost / len(batches[n]),
        psi_update_numerator,
        probe_update_numerator,
    ]