Source code for tike.operators.cupy.propagation

"""Defines a free-space propagation operator based on the CuPy FFT module."""

__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

import numpy.typing as npt
import numpy as np

import tike.operators.cupy.objective as objective

from .cache import CachedFFT
from .operator import Operator


class Propagation(CachedFFT, Operator):
    """Free-space propagation implemented as a 2D Fourier transform.

    The operator accepts an array of shape (..., N, N) and transforms only
    its last two dimensions; every leading dimension is treated as a batch
    dimension and is restored on output.

    Attributes
    ----------
    detector_shape : int
        The pixel width and height of the nearplane and farplane waves.
    norm : str
        The normalization mode forwarded to the FFT routines.

    Parameters
    ----------
    nearplane: (..., detector_shape, detector_shape) complex64
        The wavefronts after exiting the object.
    farplane: (..., detector_shape, detector_shape) complex64
        The wavefronts hitting the detector respectively. Shape for cost
        functions and gradients is
        (nscan, 1, 1, detector_shape, detector_shape).


    .. versionchanged:: 0.25.0
        Removed the model parameter and the cost(), grad() functions. Use
        the cost and gradient functions directly instead.

    """

    def __init__(self, detector_shape: int, norm: str = "ortho", **kwargs):
        """Store the detector width and the FFT normalization mode.

        Extra keyword arguments are accepted and ignored by this method.
        """
        self.detector_shape = detector_shape
        self.norm = norm

    def fwd(
        self,
        nearplane: npt.NDArray[np.csingle],
        overwrite: bool = False,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Propagate nearplane waves to the farplane with a forward 2D FFT.

        The input is validated, collapsed to a stack of 2D waves, transformed
        over its last two axes, and returned in its original shape.
        `overwrite` is forwarded to the FFT as `overwrite_x`.
        """
        self._check_shape(nearplane)
        original_shape = nearplane.shape
        width = self.detector_shape
        # Collapse every leading dimension into one batch axis for the FFT.
        stacked = nearplane.reshape(-1, width, width)
        transformed = self._fft2(
            stacked,
            norm=self.norm,
            axes=(-2, -1),
            overwrite_x=overwrite,
        )
        return transformed.reshape(original_shape)

    def adj(
        self,
        farplane: npt.NDArray[np.csingle],
        overwrite: bool = False,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Propagate farplane waves back to the nearplane with an inverse 2D FFT.

        The input is validated, collapsed to a stack of 2D waves, transformed
        over its last two axes, and returned in its original shape.
        `overwrite` is forwarded to the inverse FFT as `overwrite_x`.
        """
        self._check_shape(farplane)
        original_shape = farplane.shape
        width = self.detector_shape
        # Collapse every leading dimension into one batch axis for the FFT.
        stacked = farplane.reshape(-1, width, width)
        transformed = self._ifft2(
            stacked,
            norm=self.norm,
            axes=(-2, -1),
            overwrite_x=overwrite,
        )
        return transformed.reshape(original_shape)

    def _check_shape(self, x: npt.NDArray) -> None:
        """Validate the array type and trailing dimensions of `x`.

        Both checks are debugging aids: the assertion and the shape test are
        skipped when Python runs with assertions disabled.

        Raises
        ------
        AssertionError
            If `x` is not exactly an instance of `self.xp.ndarray`.
        ValueError
            If the last two dimensions of `x` are not both `detector_shape`.
        """
        assert type(x) is self.xp.ndarray, type(x)
        expected = (-1, self.detector_shape, self.detector_shape)
        # Only the two trailing dimensions are compared; -1 marks the batch.
        if __debug__ and x.shape[-2:] != expected[1:]:
            raise ValueError(f"waves must have shape {expected} not {x.shape}.")