Source code for tike.operators.cupy.patch

__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC."

try:
    from importlib.resources import files
except ImportError:
    # Backport for python<3.9 available as importlib_resources package
    from importlib_resources import files

import cupy as cp
import numpy.typing as npt
import numpy as np

from .operator import Operator

# Template instantiations requested from the CUDA source, written as
# '<direction>_patch<patch type,image type,position type>'. Every combination
# of single/double precision patches, images and positions is listed, with the
# forward and adjoint kernels adjacent to each other.
kernels = [
    f'{direction}_patch<{patch_type},{image_type},{position_type}>'
    for position_type in ('float', 'double')
    for patch_type, image_type in (
        ('float2', 'float2'),
        ('double2', 'double2'),
        ('float2', 'double2'),
        ('double2', 'float2'),
    )
    for direction in ('fwd', 'adj')
]

# Module-level handle to the CUDA kernels used by this operator. The source
# text of ``convolution.cu`` is read from the ``tike.operators.cupy`` package
# data when this module is imported, and ``name_expressions`` registers the
# template instantiations listed in ``kernels`` so that they can be looked up
# by name with ``_patch_module.get_function``.
_patch_module = cp.RawModule(
    code=files('tike.operators.cupy').joinpath('convolution.cu').read_text(),
    name_expressions=kernels,
    options=('--std=c++11',),
)

# Maps a NumPy dtype to the CUDA type name used as a template argument when
# looking up a kernel instantiation.
typename = {
    np.dtype(numpy_name): cuda_name
    for numpy_name, cuda_name in (
        ('complex64', 'float2'),
        ('float32', 'float'),
        ('complex128', 'double2'),
        ('float64', 'double'),
    )
}


def _next_power_two(v: int) -> int:
    """Return the next highest power of 2 of 32-bit v.

    https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
    """
    v -= 1
    v |= v >> 1
    v |= v >> 2
    v |= v >> 4
    v |= v >> 8
    v |= v >> 16
    return v + 1


class Patch(Operator):
    """Extract (zero-padded) patches from images at provided positions.

    Parameters
    ----------
    images : (..., H, W) complex64 or complex128
        The complex wavefront modulation of the object.
    positions : (..., N, 2) float32 or float64
        Coordinates of the minimum corner of the patches in the image grid.
    patches : (..., N * nrepeat, width+, width+) complex64 or complex128
        The extracted (zero-padded) patches. `adj` also accepts
        (..., K, width+, width+) provided K divides N * nrepeat and
        K >= nrepeat.
    patch_width : int
        The width of the unpadded patches. When 0, the last dimension of
        `patches` is used instead.
    height, width : int
        The shape of `images`; only read by `adj` when it has to allocate
        `images` itself.
    nrepeat : int
        Multiplier relating the number of positions to the number of patches.

    """

    def fwd(
        self,
        images: npt.NDArray[np.csingle],
        positions: npt.NDArray[np.single],
        patches: npt.NDArray[np.csingle] = None,
        patch_width: int = 0,
        height: int = 0,
        width: int = 0,
        nrepeat: int = 1,
    ):
        """Launch the `fwd_patch` kernel and return the patches array.

        When `patches` is None, a zero-filled array of shape
        (..., N * nrepeat, patch_width, patch_width) with the dtype of
        `images` is allocated; otherwise the provided array is passed to the
        kernel and returned. `height` and `width` are accepted for symmetry
        with `adj` but are not used here.

        Raises
        ------
        ValueError
            If neither `patches` nor a nonzero `patch_width` is provided.
        AssertionError
            If the shapes of `images`, `positions` and `patches` disagree.
        KeyError
            If a dtype has no entry in `typename`.
        """
        if patch_width == 0:
            # The patch size can only be inferred from an existing array;
            # check for None first so the caller gets a meaningful error.
            if patches is None:
                raise ValueError(
                    "Either `patches` or a nonzero `patch_width` must be "
                    "provided; the patch size cannot be inferred otherwise.")
            patch_width = patches.shape[-1]
        if patches is None:
            patches = cp.zeros_like(
                images,
                shape=(
                    *positions.shape[:-2],
                    positions.shape[-2] * nrepeat,
                    patch_width,
                    patch_width,
                ),
            )
        # NOTE: these checks are assertions and are skipped under `python -O`.
        assert patch_width <= patches.shape[-1]
        assert images.shape[:-2] == positions.shape[:-2]
        assert positions.shape[:-2] == patches.shape[:-3], (positions.shape,
                                                            patches.shape)
        assert positions.shape[-2] * nrepeat == patches.shape[-3]
        assert positions.shape[-1] == 2, positions.shape
        # Product of the leading (batch) dimensions; 1 when there are none.
        nimage = int(np.prod(images.shape[:-2]))

        # Template arguments are ordered: patch type, image type, position
        # type.
        _fwd_patch = _patch_module.get_function(
            f'fwd_patch<{typename[patches.dtype]},'
            f'{typename[images.dtype]},'
            f'{typename[positions.dtype]}>')

        # Presumably one block per (position, image, patch row) -- confirm
        # against convolution.cu.
        grids = (
            positions.shape[-2],
            nimage,
            patch_width,
        )
        blocks = (min(
            _next_power_two(patch_width),
            _fwd_patch.attributes['max_threads_per_block'],
        ),)

        _fwd_patch(
            grids,
            blocks,
            (
                images,
                patches,
                positions,
                nimage,
                *images.shape[-2:],
                positions.shape[-2],
                nrepeat,
                patch_width,
                patches.shape[-1],
            ),
        )
        return patches

    def adj(
        self,
        positions: npt.NDArray[np.single],
        patches: npt.NDArray[np.csingle],
        images: npt.NDArray[np.csingle] = None,
        patch_width: int = 0,
        height: int = 0,
        width: int = 0,
        nrepeat: int = 1,
    ):
        """Launch the `adj_patch` kernel and return the images array.

        When `images` is None, a zero-filled array of shape
        (..., height, width) with the dtype of `patches` is allocated;
        otherwise the provided array is passed to the kernel and returned,
        and `height` and `width` are taken from its shape.

        Raises
        ------
        AssertionError
            If the shapes of `images`, `positions` and `patches` disagree.
        KeyError
            If a dtype has no entry in `typename`.
        """
        if patch_width == 0:
            patch_width = patches.shape[-1]
        # NOTE: these checks are assertions and are skipped under `python -O`.
        assert patch_width <= patches.shape[-1]
        if images is None:
            images = cp.zeros_like(
                patches,
                shape=(*positions.shape[:-2], height, width),
            )
        leading = images.shape[:-2]
        height, width = images.shape[-2:]
        assert positions.shape[:-2] == leading
        N = positions.shape[-2]
        assert positions.shape[-1] == 2
        assert patches.shape[:-3] == leading
        K = patches.shape[-3]
        # The number of patches must evenly divide the number of repeated
        # positions and be at least nrepeat.
        assert (N * nrepeat) % K == 0 and K >= nrepeat
        assert patches.shape[-1] == patches.shape[-2]
        # Product of the leading (batch) dimensions; 1 when there are none.
        nimage = int(np.prod(images.shape[:-2]))

        # Template arguments are ordered: patch type, image type, position
        # type.
        _adj_patch = _patch_module.get_function(
            f'adj_patch<{typename[patches.dtype]},'
            f'{typename[images.dtype]},'
            f'{typename[positions.dtype]}>')

        # Presumably one block per (position, image, patch row) -- confirm
        # against convolution.cu.
        grids = (
            positions.shape[-2],
            nimage,
            patch_width,
        )
        blocks = (min(
            _next_power_two(patch_width),
            _adj_patch.attributes['max_threads_per_block'],
        ),)

        _adj_patch(
            grids,
            blocks,
            (
                images,
                patches,
                positions,
                nimage,
                height,
                width,
                N,
                nrepeat,
                patch_width,
                patches.shape[-1],
                K,
            ),
        )
        return images