Source code for tike.operators.cupy.ptycho

"""Defines a ptychography operator based on the CuPy FFT module."""

__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

import typing

import numpy.typing as npt
import numpy as np
import cupy as cp

from .operator import Operator
from .propagation import Propagation
from .convolution import Convolution
from . import objective


@cp.fuse()
def _intensity_from_farplane(farplane):
    """Return the squared magnitude of `farplane` summed over its middle axes.

    The squared magnitude is computed elementwise as the real part of
    ``farplane * conj(farplane)``. The sum runs over every axis from index 1
    up to, but not including, the last two axes; the first axis and the last
    two axes are kept.
    """
    squared_magnitude = cp.real(farplane * cp.conj(farplane))
    summed_axes = tuple(range(1, farplane.ndim - 2))
    return cp.sum(squared_magnitude, axis=summed_axes)


class Ptycho(Operator):
    """A Ptychography operator.

    Compose a diffraction and propagation operator to simulate the interaction
    of an illumination wavefront with an object followed by the propagation of
    the wavefront to a detector plane.

    Parameters
    ----------
    detector_shape : int
        The pixel width and height of the (square) detector grid.
    nz, n : int
        The pixel width and height of the reconstructed grid.
    probe_shape : int
        The pixel width and height of the (square) probe illumination.
    propagation : :py:class:`Operator`
        The wave propagation operator being used.
    diffraction : :py:class:`Operator`
        The object probe interaction operator being used.
    norm : str
        Forwarded to the propagation operator constructor as ``norm``.
    data : (..., FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records.
    farplane: (..., POSI, 1, SHARED, detector_shape, detector_shape) complex64
        The wavefronts hitting the detector respectively.
    probe : {(..., 1, 1, SHARED, WIDE, HIGH), (..., POSI, 1, SHARED, WIDE, HIGH)} complex64
        The complex illumination function.
    psi : (..., WIDE, HIGH) complex64
        The wavefront modulation coefficients of the object.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Coordinate order
        consistent with WIDE, HIGH order.


    .. versionchanged:: 0.25.0 Removed the model and ntheta parameters.

    """

    def __init__(
        self,
        detector_shape: int,
        probe_shape: int,
        nz: int,
        n: int,
        propagation: typing.Type[Propagation] = Propagation,
        diffraction: typing.Type[Convolution] = Convolution,
        norm: str = 'ortho',
        **kwargs,
    ):
        """Construct the propagation and diffraction operators.

        Any extra keyword arguments are forwarded to both the propagation
        and the diffraction constructors. Please see help(Ptycho) for more
        info.
        """
        self.propagation = propagation(
            detector_shape=detector_shape,
            norm=norm,
            **kwargs,
        )
        self.diffraction = diffraction(
            probe_shape=probe_shape,
            detector_shape=detector_shape,
            nz=nz,
            n=n,
            **kwargs,
        )
        # TODO: Replace these with @property functions
        self.probe_shape = probe_shape
        self.detector_shape = detector_shape
        self.nz = nz
        self.n = n

    def __enter__(self):
        """Enter both composed operators and return this operator."""
        self.propagation.__enter__()
        try:
            self.diffraction.__enter__()
        except BaseException as error:
            # Do not leave the propagation operator entered when the
            # diffraction operator fails to enter.
            self.propagation.__exit__(
                error.__class__,
                error,
                error.__traceback__,
            )
            raise
        return self

    def __exit__(self, type, value, traceback):
        """Exit both composed operators."""
        try:
            self.propagation.__exit__(type, value, traceback)
        finally:
            # Always exit the diffraction operator, even when exiting the
            # propagation operator raised.
            self.diffraction.__exit__(type, value, traceback)

    def fwd(
        self,
        probe: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        psi: npt.NDArray[np.csingle],
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Apply the diffraction operator and then the propagation operator.

        The axis at position -4 of `probe` is dropped before the diffraction
        operator is called and a length-1 axis is inserted at the same
        position of the result. Please see help(Ptycho) for more info.
        """
        return self.propagation.fwd(
            self.diffraction.fwd(
                psi=psi,
                scan=scan,
                probe=probe[..., 0, :, :, :],
            ),
            overwrite=True,
        )[..., None, :, :, :]

    def adj(
        self,
        farplane: npt.NDArray[np.csingle],
        probe: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        psi: typing.Optional[npt.NDArray[np.csingle]] = None,
        overwrite: bool = False,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Apply the propagation adjoint and then the diffraction adjoint.

        `overwrite` is forwarded to the propagation adjoint only; the
        diffraction adjoint is always called with ``overwrite=True``.
        Please see help(Ptycho) for more info.
        """
        return self.diffraction.adj(
            nearplane=self.propagation.adj(
                farplane,
                overwrite=overwrite,
            )[..., 0, :, :, :],
            probe=probe[..., 0, :, :, :],
            scan=scan,
            overwrite=True,
            psi=psi,
        )

    def adj_probe(
        self,
        farplane: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        psi: npt.NDArray[np.csingle],
        overwrite: bool = False,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Apply the propagation adjoint and then the probe adjoint.

        A length-1 axis is inserted at position -4 of the result of the
        diffraction operator's ``adj_probe``. Please see help(Ptycho) for
        more info.
        """
        return self.diffraction.adj_probe(
            psi=psi,
            scan=scan,
            nearplane=self.propagation.adj(
                farplane=farplane,
                overwrite=overwrite,
            )[..., 0, :, :, :],
            overwrite=True,
        )[..., None, :, :, :]

    def _compute_intensity(
        self,
        data: npt.NDArray,
        psi: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        probe: npt.NDArray[np.csingle],
    ) -> typing.Tuple[npt.NDArray[np.single], npt.NDArray[np.csingle]]:
        """Compute detector intensities and the farplane they came from.

        `data` is accepted for signature compatibility but is not used.

        Returns
        -------
        intensity
            The result of `_intensity_from_farplane` applied to `farplane`.
        farplane
            The result of `self.fwd` for the given `psi`, `scan`, `probe`.
        """
        farplane = self.fwd(
            psi=psi,
            scan=scan,
            probe=probe,
        )
        return _intensity_from_farplane(farplane), farplane

    def cost(
        self,
        data: npt.NDArray,
        psi: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        probe: npt.NDArray[np.csingle],
        *,
        model: str,
    ) -> float:
        """Evaluate the objective function named `model`.

        The function is looked up by name in the `objective` module and
        called with `data` and the simulated intensity. Please see
        help(Ptycho) for more info.
        """
        intensity, _ = self._compute_intensity(data, psi, scan, probe)
        return getattr(objective, model)(data, intensity)

    def grad_psi(
        self,
        data: npt.NDArray,
        psi: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        probe: npt.NDArray[np.csingle],
        *,
        model: str,
    ) -> npt.NDArray[np.csingle]:
        """Compute the gradient with respect to psi.

        The function named ``f'{model}_grad'`` in the `objective` module
        is applied to the farplane and the result is passed through
        `self.adj` together with a zero array shaped like `psi`. Please see
        help(Ptycho) for more info.
        """
        intensity, farplane = self._compute_intensity(data, psi, scan, probe)
        grad_obj = self.xp.zeros_like(psi)
        grad_obj = self.adj(
            farplane=getattr(objective, f'{model}_grad')(
                data,
                farplane,
                intensity,
            ),
            probe=probe,
            scan=scan,
            psi=grad_obj,
            overwrite=True,
        )
        return grad_obj

    def grad_probe(
        self,
        data: npt.NDArray,
        psi: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        probe: npt.NDArray[np.csingle],
        mode: typing.Optional[typing.List[int]] = None,
        *,
        model: str,
    ) -> npt.NDArray[np.csingle]:
        """Compute the gradient with respect to the probe(s).

        Parameters
        ----------
        mode : list(int)
            Only return the gradient with respect to these probes. These
            indices select along axis -3 of the farplane. When None, every
            index along axis -3 of `probe` is used.

        """
        mode = list(range(probe.shape[-3])) if mode is None else mode
        intensity, farplane = self._compute_intensity(data, psi, scan, probe)
        # Use the average gradient for all probe positions
        # NOTE(review): the mean is taken over axis 0; presumably inputs have
        # no leading batch dimensions so axis 0 is the position axis - confirm.
        return self.xp.mean(
            self.adj_probe(
                farplane=getattr(objective, f'{model}_grad')(
                    data,
                    farplane[..., mode, :, :],
                    intensity,
                ),
                psi=psi,
                scan=scan,
                overwrite=True,
            ),
            axis=0,
            keepdims=True,
        )

    def adj_all(
        self,
        farplane: npt.NDArray[np.csingle],
        probe: npt.NDArray[np.csingle],
        scan: npt.NDArray[np.single],
        psi: npt.NDArray[np.csingle],
        overwrite: bool = False,
        rpie: bool = False,
    ) -> typing.Tuple[npt.NDArray, ...]:
        """Apply the propagation adjoint and then the diffraction `adj_all`.

        The returned tuple is the diffraction operator's result with a
        length-1 axis inserted at position -4 of its second element; all
        other elements are passed through unchanged. Please see help(Ptycho)
        for more info.
        """
        result = self.diffraction.adj_all(
            nearplane=self.propagation.adj(
                farplane,
                overwrite=overwrite,
            )[..., 0, :, :, :],
            probe=probe[..., 0, :, :, :],
            scan=scan,
            overwrite=True,
            psi=psi,
            rpie=rpie,
        )
        return (result[0], result[1][..., None, :, :, :], *result[2:])