Source code for tike.operators.cupy.ptycho

"""Defines a ptychography operator based on the CuPy FFT module."""

__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

from .operator import Operator
from .propagation import Propagation
from .convolution import Convolution


class Ptycho(Operator):
    """A Ptychography operator.

    Compose a diffraction and propagation operator to simulate the interaction
    of an illumination wavefront with an object followed by the propagation of
    the wavefront to a detector plane.

    Parameters
    ----------
    detector_shape : int
        The pixel width and height of the (square) detector grid.
    nz, n : int
        The pixel width and height of the reconstructed grid.
    probe_shape : int
        The pixel width and height of the (square) probe illumination.
    ntheta : int
        Accepted by the constructor but not used by it.
    propagation : :py:class:`Operator`
        The wave propagation operator being used.
    diffraction : :py:class:`Operator`
        The object probe interaction operator being used.
    model : string
        The type of noise model to use for the cost functions.
    data : (..., FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records.
    farplane: (..., POSI, 1, SHARED, detector_shape, detector_shape) complex64
        The wavefronts hitting the detector respectively.
    probe : {(..., 1, 1, SHARED, WIDE, HIGH), (..., POSI, 1, SHARED, WIDE, HIGH)} complex64
        The complex illumination function.
    psi : (..., WIDE, HIGH) complex64
        The wavefront modulation coefficients of the object.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Coordinate order
        consistent with WIDE, HIGH order.

    """

    def __init__(self, detector_shape, probe_shape, nz, n,
                 ntheta=1, model='gaussian',
                 propagation=Propagation, diffraction=Convolution,
                 **kwargs):  # noqa: D102 yapf: disable
        """Please see help(Ptycho) for more info."""
        # Both sub-operators receive the same model and extra keyword
        # arguments; each is instantiated from the class passed in.
        self.propagation = propagation(
            detector_shape=detector_shape,
            model=model,
            **kwargs,
        )
        self.diffraction = diffraction(
            probe_shape=probe_shape,
            detector_shape=detector_shape,
            nz=nz,
            n=n,
            model=model,
            **kwargs,
        )
        # TODO: Replace these with @property functions
        self.probe_shape = probe_shape
        self.detector_shape = detector_shape
        self.nz = nz
        self.n = n

    def __enter__(self):
        """Enter both sub-operators and return this operator."""
        self.propagation.__enter__()
        try:
            self.diffraction.__enter__()
        except BaseException as error:
            # Do not leave the propagation operator entered when the
            # diffraction operator fails to start.
            self.propagation.__exit__(
                error.__class__,
                error,
                error.__traceback__,
            )
            raise
        return self

    def __exit__(self, type, value, traceback):
        """Exit both sub-operators, in reverse order of entry.

        The parameter names (including ``type``, which shadows the builtin)
        are kept unchanged for interface compatibility.
        """
        # try/finally guarantees the propagation operator is released even
        # when exiting the diffraction operator raises.
        try:
            self.diffraction.__exit__(type, value, traceback)
        finally:
            self.propagation.__exit__(type, value, traceback)

    def fwd(self, probe, scan, psi, **kwargs):
        """Please see help(Ptycho) for more info."""
        # Drop the singleton axis of the probe before the diffraction step.
        nearplane = self.diffraction.fwd(
            psi=psi,
            scan=scan,
            probe=probe[..., 0, :, :, :],
        )
        farplane = self.propagation.fwd(nearplane, overwrite=True)
        # Restore the singleton axis so the result is laid out like farplane.
        return farplane[..., None, :, :, :]

    def adj(self, farplane, probe, scan, psi=None, overwrite=False, **kwargs):
        """Please see help(Ptycho) for more info."""
        nearplane = self.propagation.adj(
            farplane,
            overwrite=overwrite,
        )[..., 0, :, :, :]
        return self.diffraction.adj(
            nearplane=nearplane,
            probe=probe[..., 0, :, :, :],
            scan=scan,
            overwrite=True,
            psi=psi,
        )

    def adj_probe(self, farplane, scan, psi, overwrite=False, **kwargs):
        """Please see help(Ptycho) for more info."""
        nearplane = self.propagation.adj(
            farplane=farplane,
            overwrite=overwrite,
        )[..., 0, :, :, :]
        return self.diffraction.adj_probe(
            psi=psi,
            scan=scan,
            nearplane=nearplane,
            overwrite=True,
        )[..., None, :, :, :]

    def _compute_intensity(self, data, psi, scan, probe):
        """Return the detector intensity and the farplane wavefronts.

        The intensity is the squared magnitude of the farplane summed over
        the singleton axis inserted by :meth:`fwd` and the axis before the
        last two (SHARED in the class docstring). ``data`` is not used.

        Returns
        -------
        intensity, farplane : tuple
            The summed intensity and the wavefronts it was computed from.
        """
        farplane = self.fwd(
            psi=psi,
            scan=scan,
            probe=probe,
        )
        # Negative axes equal the former (2, 3) for six-dimensional input
        # and stay correct for any number of leading dimensions.
        intensity = self.xp.sum(
            (farplane * farplane.conj()).real,
            axis=(-4, -3),
        )
        return intensity, farplane

    def cost(self, data, psi, scan, probe) -> float:
        """Please see help(Ptycho) for more info."""
        intensity, _ = self._compute_intensity(data, psi, scan, probe)
        return self.propagation.cost(data, intensity)

    def grad_psi(self, data, psi, scan, probe):
        """Please see help(Ptycho) for more info."""
        intensity, farplane = self._compute_intensity(data, psi, scan, probe)
        # A zeroed array shaped like psi is handed to adj as its psi argument.
        grad_obj = self.xp.zeros_like(psi)
        grad_obj = self.adj(
            farplane=self.propagation.grad(
                data,
                farplane,
                intensity,
            ),
            probe=probe,
            scan=scan,
            psi=grad_obj,
            overwrite=True,
        )
        return grad_obj

    def grad_probe(self, data, psi, scan, probe, mode=None):
        """Compute the gradient with respect to the probe(s).

        Parameters
        ----------
        mode : list(int)
            Only return the gradient with respect to these probes. When
            None, every index along ``probe.shape[-3]`` is used.

        Returns
        -------
        The gradient averaged over the scan position axis, which is kept as
        a singleton dimension.

        """
        mode = list(range(probe.shape[-3])) if mode is None else mode
        intensity, farplane = self._compute_intensity(data, psi, scan, probe)
        # Use the average gradient for all probe positions. The position axis
        # is fifth from the end; this equals the former axis=1 for
        # six-dimensional arrays.
        return self.xp.mean(
            self.adj_probe(
                farplane=self.propagation.grad(
                    data,
                    farplane[..., mode, :, :],
                    intensity,
                ),
                psi=psi,
                scan=scan,
                overwrite=True,
            ),
            axis=-5,
            keepdims=True,
        )