# Source code for tike.operators.cupy.propagation

"""Defines a free-space propagation operator based on the CuPy FFT module."""

__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

from .cache import CachedFFT
from .operator import Operator

import numpy as np


class Propagation(CachedFFT, Operator):
    """A Fourier-based free-space propagation using CuPy.

    Take an (..., N, N) array and apply the Fourier transform to the last two
    dimensions.

    Attributes
    ----------
    detector_shape : int
        The pixel width and height of the nearplane and farplane waves.
    model : string
        The type of noise model to use for the cost functions; selects the
        ``_{model}_cost`` / ``_{model}_grad`` method pair (``'gaussian'`` or
        ``'poisson'`` in this class).
    cost : (data, intensity) -> scalar
        The function to be minimized when solving a problem.
    grad : (data, farplane, intensity) -> farplane-like
        The gradient of cost with respect to `farplane`.

    Parameters
    ----------
    nearplane: (..., detector_shape, detector_shape) complex64
        The wavefronts after exiting the object.
    farplane: (..., detector_shape, detector_shape) complex64
        The wavefronts hitting the detector. Shape for cost functions and
        gradients is (ntheta, nscan, 1, 1, detector_shape, detector_shape).
    data, intensity : (ntheta, nscan, detector_shape, detector_shape)
        `intensity` is the square of the absolute value of `farplane`, i.e.
        the intensity of the modelled wavefront. `data` is the intensity that
        `intensity` is fitted to by the cost functions.

    """

    def __init__(self, detector_shape, model='gaussian', **kwargs):
        # NOTE: extra **kwargs are accepted for interface compatibility but
        # are not used here.
        self.detector_shape = detector_shape
        # Store the noise model name so the documented attribute exists.
        self.model = model
        # Bind the cost/gradient pair for the requested noise model.
        try:
            self.cost = getattr(self, f'_{model}_cost')
            self.grad = getattr(self, f'_{model}_grad')
        except AttributeError as error:
            # Keep AttributeError (what getattr raised before) but explain it.
            raise AttributeError(
                f"Unknown noise model {model!r}: this operator defines no "
                f"'_{model}_cost' / '_{model}_grad' methods."
            ) from error

    def fwd(self, nearplane, overwrite=False, **kwargs):
        """Forward Fourier-based free-space propagation operator.

        Parameters
        ----------
        nearplane : (..., detector_shape, detector_shape)
            The wavefronts after exiting the object.
        overwrite : bool
            Forwarded to ``self._fft2``; presumably permits the transform to
            reuse the input buffer (TODO confirm against CachedFFT).

        Returns
        -------
        farplane : array with the same shape as `nearplane`
            The orthonormally scaled (``norm='ortho'``) Fourier transform over
            the last two axes.

        """
        self._check_shape(nearplane)
        shape = nearplane.shape
        # Flatten all leading axes so the FFT operates on a stack of 2D frames,
        # then restore the caller's original shape.
        return self._fft2(
            nearplane.reshape(-1, self.detector_shape, self.detector_shape),
            norm='ortho',
            axes=(-2, -1),
            overwrite=overwrite,
        ).reshape(shape)

    def adj(self, farplane, overwrite=False, **kwargs):
        """Adjoint Fourier-based free-space propagation operator.

        With ``norm='ortho'`` the inverse transform is the adjoint of `fwd`.

        Parameters
        ----------
        farplane : (..., detector_shape, detector_shape)
            The wavefronts hitting the detector.
        overwrite : bool
            Forwarded to ``self._ifft2``; presumably permits the transform to
            reuse the input buffer (TODO confirm against CachedFFT).

        Returns
        -------
        nearplane : array with the same shape as `farplane`

        """
        self._check_shape(farplane)
        shape = farplane.shape
        # Same flatten / transform / restore pattern as fwd.
        return self._ifft2(
            farplane.reshape(-1, self.detector_shape, self.detector_shape),
            norm='ortho',
            axes=(-2, -1),
            overwrite=overwrite,
        ).reshape(shape)

    def _check_shape(self, x):
        """Validate that `x` is an array of this backend with square frames.

        Both checks only run when ``__debug__`` is true (the assert is
        stripped and the ``if`` is skipped under ``python -O``).

        Raises
        ------
        ValueError
            If the last two dimensions of `x` are not
            (detector_shape, detector_shape).

        """
        assert type(x) is self.xp.ndarray, type(x)
        shape = (-1, self.detector_shape, self.detector_shape)
        if (__debug__ and x.shape[-2:] != shape[-2:]):
            raise ValueError(f'waves must have shape {shape} not {x.shape}.')

    # COST FUNCTIONS AND GRADIENTS --------------------------------------------
    # NOTE: We use mean instead of sum so that cost functions may be compared
    # when mini-batches of different sizes are used.

    def _gaussian_cost(self, data, intensity):
        """Mean squared difference between the amplitudes sqrt(intensity)
        and sqrt(data)."""
        diff = np.sqrt(intensity) - np.sqrt(data)
        diff *= diff.conj()
        return np.mean(diff)

    def _gaussian_grad(self, data, farplane, intensity, overwrite=False):
        """Gradient of `_gaussian_cost` with respect to `farplane`.

        `overwrite` is accepted for interface compatibility but is unused.
        """
        # 1e-32 guards against division by zero where intensity is 0. The
        # two new axes broadcast the (ntheta, nscan, N, N) ratio over the
        # (ntheta, nscan, 1, 1, N, N) farplane.
        return farplane * (
            1 - np.sqrt(data) / (np.sqrt(intensity) + 1e-32)
        )[:, :, np.newaxis, np.newaxis]  # yapf: disable

    def _poisson_cost(self, data, intensity):
        """Mean Poisson negative log-likelihood, omitting terms that do not
        depend on `intensity`."""
        # 1e-32 avoids log(0).
        return np.mean(intensity - data * np.log(intensity + 1e-32))

    def _poisson_grad(self, data, farplane, intensity, overwrite=False):
        """Gradient of `_poisson_cost` with respect to `farplane`.

        `overwrite` is accepted for interface compatibility but is unused.
        """
        # Same broadcasting and zero-guard as _gaussian_grad.
        return farplane * (
            1 - data / (intensity + 1e-32)
        )[:, :, np.newaxis, np.newaxis]  # yapf: disable