__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."
import cupy as cp
from .operator import Operator
from .patch import Patch
from .shift import Shift
class Convolution(Operator):
    """A 2D Convolution operator with linear interpolation.

    Compute the product of two arrays at specific relative positions.

    Attributes
    ----------
    probe_shape : int
        The pixel width and height of the (square) probe illumination.
    nz, n : int
        The pixel width and height of the reconstructed grid.
    detector_shape : int
        The pixel width and height of the (square) patches produced by
        `fwd` and consumed by `adj`. Equal to `probe_shape` unless given.
    pad, end : int
        Start and stop indices of the probe-sized window inside a
        detector-sized patch; the probe is applied to
        ``[..., pad:end, pad:end]``.
    patch : Patch
        The operator used to extract patches from, and accumulate patches
        into, the object grid.

    Parameters
    ----------
    psi : (..., nz, n) complex64
        The complex wavefront modulation of the object.
    probe : complex64
        The (..., nscan, nprobe, probe_shape, probe_shape) or
        (..., 1, nprobe, probe_shape, probe_shape) complex illumination
        function.
    nearplane : complex64
        The (..., nscan, nprobe, detector_shape, detector_shape)
        wavefronts after exiting the object.
    scan : (..., nscan, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Vertical coordinates
        first, horizontal coordinates second.

    Notes
    -----
    The array module ``self.xp`` is not defined in this class; it is
    expected to be provided by the `Operator` base class.

    """

    def __init__(self, probe_shape, nz, n, ntheta=None,
                 detector_shape=None, **kwargs):  # yapf: disable
        """Store the grid sizes and derive the probe window.

        Parameters
        ----------
        probe_shape : int
            The pixel width and height of the (square) probe.
        nz, n : int
            The pixel width and height of the reconstructed grid.
        ntheta : int, optional
            Accepted but not used by this constructor.
        detector_shape : int, optional
            The pixel width and height of the patches. Defaults to
            `probe_shape`.
        **kwargs
            Accepted and ignored by this constructor.

        """
        self.probe_shape = probe_shape
        self.nz = nz
        self.n = n
        if detector_shape is None:
            self.detector_shape = probe_shape
        else:
            self.detector_shape = detector_shape
        # Split the size difference evenly so that the probe window sits in
        # the middle of the detector-sized patch (floor division on odd
        # differences).
        self.pad = (self.detector_shape - self.probe_shape) // 2
        self.end = self.probe_shape + self.pad
        self.patch = Patch()

    def fwd(self, psi, scan, probe):
        """Extract probe shaped patches from the psi at each scan position.

        The patches within the bounds of psi are linearly interpolated, and
        indices outside the bounds of psi are not allowed.

        Returns
        -------
        patches : (..., nscan, nprobe, detector_shape, detector_shape)
            The extracted patches with the probe-sized window
            ``[..., pad:end, pad:end]`` multiplied by `probe`.

        """
        assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
        assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape)
        assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
        # When the patch is larger than the probe the border must be zero,
        # so the buffer is zero-filled; otherwise an uninitialized buffer
        # is requested because the whole patch is written by Patch.fwd.
        if self.detector_shape == self.probe_shape:
            allocate = self.xp.empty_like
        else:
            allocate = self.xp.zeros_like
        patches = allocate(
            psi,
            shape=(
                *scan.shape[:-2],
                scan.shape[-2] * probe.shape[-3],
                self.detector_shape,
                self.detector_shape,
            ),
        )
        patches = self.patch.fwd(
            patches=patches,
            images=psi,
            positions=scan,
            patch_width=self.probe_shape,
            nrepeat=probe.shape[-3],
        )
        # Separate the flattened (nscan * nprobe) axis into two axes.
        patches = patches.reshape((
            *scan.shape[:-1],
            probe.shape[-3],
            self.detector_shape,
            self.detector_shape,
        ))
        patches[..., self.pad:self.end, self.pad:self.end] *= probe
        return patches

    def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
        """Combine probe shaped patches into a psi shaped grid by addition.

        Parameters
        ----------
        psi : (..., nz, n), optional
            The grid passed to Patch.adj as `images`. A zero-filled grid
            is created when None.
        overwrite : bool
            When False, `nearplane` is copied before it is multiplied by
            the conjugate probe; when True, it is modified in place.

        Returns
        -------
        The value returned by Patch.adj.

        """
        assert probe.shape[:-4] == scan.shape[:-2]
        assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
        assert nearplane.shape[:-3] == scan.shape[:-1]
        if not overwrite:
            nearplane = nearplane.copy()
        nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()
        if psi is None:
            psi = self.xp.zeros_like(
                nearplane,
                shape=(*scan.shape[:-2], self.nz, self.n),
            )
        assert psi.shape[:-2] == scan.shape[:-2]
        return self.patch.adj(
            # Merge the scan and probe axes into one patch axis.
            patches=nearplane.reshape((
                *scan.shape[:-2],
                scan.shape[-2] * nearplane.shape[-3],
                *nearplane.shape[-2:],
            )),
            images=psi,
            positions=scan,
            patch_width=self.probe_shape,
            nrepeat=nearplane.shape[-3],
        )

    def adj_probe(self, nearplane, scan, psi, overwrite=False):
        """Combine probe shaped patches into a probe.

        Parameters
        ----------
        overwrite : bool
            Not used by this method; `nearplane` is only read.

        Returns
        -------
        patches : (..., nscan, nprobe, probe_shape, probe_shape)
            The conjugated patches of psi multiplied by the probe-sized
            window of `nearplane`.

        """
        assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape,
                                                         scan.shape)
        assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
        patches = self.xp.zeros_like(
            psi,
            shape=(
                *scan.shape[:-2],
                scan.shape[-2] * nearplane.shape[-3],
                self.probe_shape,
                self.probe_shape,
            ),
        )
        patches = self.patch.fwd(
            patches=patches,
            images=psi,
            positions=scan,
            patch_width=self.probe_shape,
            nrepeat=nearplane.shape[-3],
        )
        patches = patches.reshape((*scan.shape[:-1], nearplane.shape[-3],
                                   self.probe_shape, self.probe_shape))
        patches = patches.conj()
        patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
        return patches

    def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False):
        """Perform adj and adj_probe at the same time.

        Parameters
        ----------
        overwrite : bool
            When False, `nearplane` is copied before it is multiplied by
            the conjugate probe; when True, it is modified in place.
        rpie : bool
            When True, two additional amplitude arrays are returned.

        Returns
        -------
        apsi
            The value returned by Patch.adj for the probe-weighted
            nearplane (the same quantity as `adj` with ``psi=None``).
        patches : (..., nscan, nprobe, probe_shape, probe_shape)
            The same quantity as returned by `adj_probe`.
        patches_amp : optional
            Only when `rpie`. The squared magnitude of the psi patches
            summed over the scan axis (kept with length 1).
        probe_amp : optional
            Only when `rpie`. The value returned by Patch.adj for the
            squared magnitude of the probe.

        """
        assert probe.shape[:-4] == scan.shape[:-2]
        assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
        assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
        assert nearplane.shape[:-3] == scan.shape[:-1], (nearplane.shape,
                                                         scan.shape)
        patches = self.patch.fwd(
            # Could be xp.empty if scan positions are all in bounds
            patches=self.xp.zeros_like(
                psi,
                shape=(
                    *scan.shape[:-2],
                    scan.shape[-2] * nearplane.shape[-3],
                    self.probe_shape,
                    self.probe_shape,
                ),
            ),
            images=psi,
            positions=scan,
            patch_width=self.probe_shape,
            nrepeat=nearplane.shape[-3],
        )
        patches = patches.reshape((*scan.shape[:-1], nearplane.shape[-3],
                                   self.probe_shape, self.probe_shape))
        if rpie:
            patches_amp = self.xp.sum(
                patches * patches.conj(),
                axis=-4,
                keepdims=True,
            )
        patches = patches.conj()
        # This product must use nearplane BEFORE it is weighted by the
        # conjugate probe below (which may happen in place).
        patches *= nearplane[..., self.pad:self.end, self.pad:self.end]
        if not overwrite:
            nearplane = nearplane.copy()
        nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()
        if rpie:
            probe_amp = probe * probe.conj()
            # Merge only the two leading (scan, probe) axes. The trailing
            # shape is taken from probe_amp itself because the probe is
            # probe_shape wide while nearplane is detector_shape wide; the
            # two differ whenever detector_shape != probe_shape.
            probe_amp = probe_amp.reshape(
                (*scan.shape[:-2], -1, *probe_amp.shape[-2:])
                # (..., nscan * nprobe, probe_shape, probe_shape)
                # (..., nprobe, probe_shape, probe_shape)
            )
            # NOTE(review): when probe.shape[-4] == 1 only nprobe patches
            # are passed for nscan positions -- confirm Patch.adj accepts
            # this.
            probe_amp = self.patch.adj(
                patches=probe_amp,
                images=self.xp.zeros_like(
                    psi,
                    shape=(*scan.shape[:-2], self.nz, self.n),
                ),
                positions=scan,
                patch_width=self.probe_shape,
                nrepeat=nearplane.shape[-3],
            )
        apsi = self.patch.adj(
            patches=nearplane.reshape((
                *scan.shape[:-2],
                scan.shape[-2] * nearplane.shape[-3],
                *nearplane.shape[-2:],
            )),
            images=self.xp.zeros_like(
                psi,
                shape=(*scan.shape[:-2], self.nz, self.n),
            ),
            positions=scan,
            patch_width=self.probe_shape,
            nrepeat=nearplane.shape[-3],
        )
        if rpie:
            return apsi, patches, patches_amp, probe_amp
        return apsi, patches
class ConvolutionFFT(Operator):
    """A 2D Convolution operator that delegates sub-pixel offsets to Shift.

    Compute the product of two arrays at specific relative positions. Each
    scan position is split into a whole-pixel part, which is handed to the
    Patch operator, and a fractional part, which is handed to the Shift
    operator. (Shift is presumably Fourier based, per this class's name --
    confirm against the Shift implementation.)

    Attributes
    ----------
    probe_shape : int
        The pixel width and height of the (square) probe illumination.
    nz, n : int
        The pixel width and height of the reconstructed grid.
    detector_shape : int
        The pixel width and height of the (square) patches. Equal to
        `probe_shape` unless given.
    pad, end : int
        Start and stop indices of the probe-sized window inside a
        detector-sized patch.
    patch : Patch
        Extracts and accumulates patches at whole-pixel positions.
    shift : Shift
        Applies the fractional part of each position.

    Parameters
    ----------
    psi : (..., nz, n) complex64
        The complex wavefront modulation of the object.
    probe : complex64
        The (..., nscan, nprobe, probe_shape, probe_shape) or
        (..., 1, nprobe, probe_shape, probe_shape) complex illumination
        function.
    nearplane : complex64
        The (..., nscan, nprobe, detector_shape, detector_shape)
        wavefronts after exiting the object.
    scan : (..., nscan, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Vertical coordinates
        first, horizontal coordinates second.

    Notes
    -----
    The array module ``self.xp`` is not defined in this class; it is
    expected to be provided by the `Operator` base class.

    """

    def __init__(self, probe_shape, nz, n, ntheta=None,
                 detector_shape=None, **kwargs):  # yapf: disable
        """Store the grid sizes, derive the probe window, build helpers.

        `ntheta` and `**kwargs` are accepted but not used here.
        """
        self.probe_shape = probe_shape
        self.nz = nz
        self.n = n
        self.detector_shape = (probe_shape
                               if detector_shape is None else detector_shape)
        # Half of the size difference (floored) goes before the window.
        self.pad = (self.detector_shape - self.probe_shape) // 2
        self.end = self.probe_shape + self.pad
        self.patch = Patch()
        self.shift = Shift()

    def __enter__(self):
        """Enter the context of the owned Shift operator; return self."""
        self.shift.__enter__()
        return self

    def __exit__(self, type, value, traceback):
        """Leave the context of the owned Shift operator.

        The return value of Shift.__exit__ is discarded, so this method
        returns None.
        """
        self.shift.__exit__(type, value, traceback)

    def fwd(self, psi, scan, probe):
        """Extract probe shaped patches from the psi at each scan position.

        Patches are taken at the whole-pixel part of each position, passed
        through Shift.adj with the fractional part, and then the probe
        window ``[..., pad:end, pad:end]`` is multiplied by `probe`.

        Returns
        -------
        patches : (..., nscan, nprobe, detector_shape, detector_shape)

        """
        assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
        assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape)
        assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
        nprobe = probe.shape[-3]
        side = self.detector_shape
        # A zero-filled buffer is only requested when the patch is larger
        # than the probe; otherwise an uninitialized one is used.
        make_buffer = (self.xp.empty_like
                       if side == self.probe_shape else self.xp.zeros_like)
        buffer = make_buffer(
            psi,
            shape=(*scan.shape[:-2], scan.shape[-2] * nprobe, side, side),
        )
        # Floor and remainder of each coordinate.
        whole, fraction = self.xp.divmod(scan, 1.0)
        # Insert an axis so the remainder lines up with the probe axis.
        fraction = fraction.reshape((*scan.shape[:-1], 1, 2))
        extracted = self.patch.fwd(
            patches=buffer,
            images=psi,
            positions=whole,
            patch_width=self.probe_shape,
            nrepeat=nprobe,
        )
        extracted = extracted.reshape((*scan.shape[:-1], nprobe, side, side))
        extracted = self.shift.adj(extracted, fraction, overwrite=False)
        extracted[..., self.pad:self.end, self.pad:self.end] *= probe
        return extracted

    def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
        """Combine probe shaped patches into a psi shaped grid by addition.

        The probe window of `nearplane` is multiplied by the conjugate
        probe, passed through Shift.fwd with the fractional part of each
        position, and accumulated by Patch.adj at the whole-pixel part.

        Parameters
        ----------
        psi : (..., nz, n), optional
            The grid passed to Patch.adj as `images`. A zero-filled grid
            is created when None.
        overwrite : bool
            When False, `nearplane` is copied before it is modified.

        Returns
        -------
        The value returned by Patch.adj.

        """
        assert probe.shape[:-4] == scan.shape[:-2]
        assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
        assert nearplane.shape[:-3] == scan.shape[:-1]
        if not overwrite:
            nearplane = nearplane.copy()
        nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()
        whole, fraction = self.xp.divmod(scan, 1.0)
        fraction = fraction.reshape((*scan.shape[:-1], 1, 2))
        # nearplane is already private to this call (or the caller allowed
        # modification), so Shift may work in place.
        nearplane = self.shift.fwd(nearplane, fraction, overwrite=True)
        if psi is None:
            psi = self.xp.zeros_like(
                nearplane,
                shape=(*scan.shape[:-2], self.nz, self.n),
            )
        assert psi.shape[:-2] == scan.shape[:-2]
        nprobe = nearplane.shape[-3]
        # Merge the scan and probe axes into one patch axis.
        flattened = nearplane.reshape((
            *scan.shape[:-2],
            scan.shape[-2] * nprobe,
            *nearplane.shape[-2:],
        ))
        return self.patch.adj(
            patches=flattened,
            images=psi,
            positions=whole,
            patch_width=self.probe_shape,
            nrepeat=nprobe,
        )