Source code for tike.operators.cupy.shift
__author__ = "Daniel Ching, Viktor Nikitin"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."
from .cache import CachedFFT
from .operator import Operator
class Shift(CachedFFT, Operator):
    """Shift last two dimensions of an array using Fourier method."""

    def fwd(self, a, shift, overwrite=False, cval=None):
        """Apply shifts along last two dimensions of a.

        The array is transformed with a 2D FFT over its last two axes,
        multiplied by a linear phase ramp built from `shift`, and then
        transformed back with the inverse 2D FFT.

        Parameters
        ----------
        a : (..., H, W) array
            The array to be shifted.
        shift : (..., 2) array or None
            The shifts to be applied along the last two axes.
            `shift[..., 0]` is paired with the frequencies of axis -2 and
            `shift[..., 1]` with the frequencies of axis -1. When None, `a`
            is returned unchanged.
        overwrite : bool
            Forwarded as `overwrite_x` to the forward FFT.
            NOTE(review): presumably this lets the FFT destroy the contents
            of `a`; confirm against CachedFFT.
        cval : optional
            Not used by this method.

        Returns
        -------
        array
            The shifted array, reshaped to the shape of `a`.
            NOTE(review): presumably complex-valued because it is the output
            of an inverse FFT; confirm against CachedFFT.

        """
        if shift is None:
            return a
        shape = a.shape
        spectrum = self._fft2(
            a,
            axes=(-2, -1),
            overwrite_x=overwrite,
        )
        # fftfreq yields fractions in [-0.5, 0.5). Casting those to an
        # integer (or bool) dtype would truncate every frequency to zero and
        # silently disable the shift, so only follow shift.dtype when it is
        # a floating or complex type.
        freq_dtype = shift.dtype if shift.dtype.kind in "fc" else "float32"
        # Default meshgrid indexing: x varies along the last axis (W) and
        # y varies along the second-to-last axis (H).
        x, y = self.xp.meshgrid(
            self.xp.fft.fftfreq(spectrum.shape[-1]).astype(freq_dtype),
            self.xp.fft.fftfreq(spectrum.shape[-2]).astype(freq_dtype),
        )
        # The trailing None axes let each (..., 2) shift broadcast against
        # the (H, W) frequency grids.
        spectrum *= self.xp.exp(
            -2j
            * self.xp.pi
            * (x * shift[..., 1, None, None] + y * shift[..., 0, None, None])
        )
        # spectrum is a temporary owned by this method, so the inverse
        # transform may always overwrite it.
        shifted = self._ifft2(spectrum, axes=(-2, -1), overwrite_x=True)
        return shifted.reshape(*shape)

    def adj(self, a, shift, overwrite=False, cval=None):
        """Apply the negated shifts along last two dimensions of a.

        Equivalent to `fwd(a, -shift)`. Parameters and return value are the
        same as for `fwd`; when `shift` is None, `a` is returned unchanged.
        """
        # Guard here as well as in fwd because None cannot be negated.
        if shift is None:
            return a
        return self.fwd(a, -shift, overwrite=overwrite, cval=cval)

    # Shifting by -shift is used as both the adjoint and the inverse.
    inv = adj