# Source code for tike.operators.cupy.cache

__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

import typing

from cupyx.scipy.fft import fftn, ifftn, get_fft_plan
import cupy.cuda.cufft
import cupy.cuda.runtime
import numpy.typing as npt
import numpy as np


class CachedFFT():
    """Provides a multi-plan per-device cache for CuPy FFT.

    A class which inherits from this class gains the _fft2, _ifft2, _fftn,
    and _ifftn methods which provide automatic plan caching for the CuPy
    FFTs.

    This plan cache differs from the cache included in CuPy>=8 because it is
    NOT per-thread. This allows us to use threadpool.map() and allows us to
    destroy the cache manually.

    The cache only exists inside a ``with`` block: ``__enter__`` creates it
    and ``__exit__`` clears and deletes it. Calling the FFT methods outside
    of a ``with`` block raises AttributeError because ``plan_cache`` does
    not exist.
    """

    def __enter__(self):
        self.plan_cache = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop every cached plan (presumably so the resources held by the
        # plans can be released). Returning None means exceptions raised
        # inside the with-block are not suppressed.
        self.plan_cache.clear()
        del self.plan_cache

    def _get_fft_plan(
        self,
        a: npt.NDArray,
        axes: typing.Optional[typing.Sequence[int]] = (),
        **kwargs,
    ) -> typing.Union[cupy.cuda.cufft.Plan1d, cupy.cuda.cufft.PlanNd]:
        """Return a cached FFT plan for `a`, creating it on a cache miss.

        Parameters
        ----------
        a : array
            The array that will be transformed.
        axes : sequence of int, optional
            The axes to transform. An empty sequence or None means all axes
            of `a`.
        kwargs :
            Ignored. Accepted so that callers can forward the keyword
            arguments meant for the FFT function unchanged.

        Returns
        -------
        plan
            The CuPy FFT plan for this array geometry on the current device.
        """
        if axes is None or len(axes) == 0:
            axes = tuple(range(a.ndim))
        else:
            # Lists are not hashable; the key needs a tuple.
            axes = tuple(axes)
        # Shape and axes are kept as separate tuples. Splatting both into one
        # flat tuple lets distinct geometries collide, e.g. shape (2, 3, 1)
        # with axes (1, 2) and shape (2, 3, 1, 1) with axes (2,) would both
        # flatten to (2, 3, 1, 1, 2).
        # The contiguity flags are included because get_fft_plan builds the
        # plan for the C- or F-ordered layout of its input.
        key = (
            tuple(a.shape),
            axes,
            a.dtype,
            a.flags.c_contiguous,
            a.flags.f_contiguous,
            cupy.cuda.runtime.getDevice(),
        )
        plan = self.plan_cache.get(key)
        if plan is None:
            plan = get_fft_plan(a, axes=axes)
            self.plan_cache[key] = plan
        return plan

    def _fft2(
        self,
        a: npt.NDArray,
        *args,
        axes: typing.Tuple[int, int] = (-2, -1),
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Forward FFT over the last two axes (by default) using a cached plan."""
        return self._fftn(a, *args, axes=axes, **kwargs)

    def _ifft2(
        self,
        a: npt.NDArray,
        *args,
        axes: typing.Tuple[int, int] = (-2, -1),
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Inverse FFT over the last two axes (by default) using a cached plan."""
        return self._ifftn(a, *args, axes=axes, **kwargs)

    def _ifftn(
        self,
        a: npt.NDArray,
        *args,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Inverse N-D FFT of `a` executed under a cached plan.

        Only the ``axes`` keyword is used to select the plan; it must be
        passed by keyword (not positionally in `args`) for the cached plan
        to match the transform.
        """
        with self._get_fft_plan(a, **kwargs):
            return ifftn(a, *args, **kwargs)

    def _fftn(
        self,
        a: npt.NDArray,
        *args,
        **kwargs,
    ) -> npt.NDArray[np.csingle]:
        """Forward N-D FFT of `a` executed under a cached plan.

        Only the ``axes`` keyword is used to select the plan; it must be
        passed by keyword (not positionally in `args`) for the cached plan
        to match the transform.
        """
        with self._get_fft_plan(a, **kwargs):
            return fftn(a, *args, **kwargs)