# Source code for tike.operators.cupy.patch

__author__ = "Daniel Ching"
__copyright__ = "Copyright (c) 2021, UChicago Argonne, LLC."

from importlib_resources import files

import cupy as cp
import numpy as np

from .operator import Operator

# Read the CUDA source file 'convolution.cu' that is packaged inside
# tike.operators.cupy and wrap two of its entry points, "fwd_patch" and
# "adj_patch", as CuPy raw kernels. These statements run once, when this
# module is imported; the kernel objects are shared by every Patch operator.
_cu_source = files('tike.operators.cupy').joinpath('convolution.cu').read_text()
_fwd_patch = cp.RawKernel(_cu_source, "fwd_patch")
_adj_patch = cp.RawKernel(_cu_source, "adj_patch")


def _next_power_two(v):
    """Return the next highest power of 2 of 32-bit v.

    https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
    """
    v -= 1
    v |= v >> 1
    v |= v >> 2
    v |= v >> 4
    v |= v >> 8
    v |= v >> 16
    return v + 1


class Patch(Operator):
    """Extract (zero-padded) patches from images at provided positions.

    Parameters
    ----------
    images : (..., H, W) complex64
        The complex wavefront modulation of the object.
    positions : (..., N, 2) float32
        Coordinates of the minimum corner of the patches in the image grid.
    patches : (..., N * nrepeat, width+, width+) complex64
        The extracted (zero-padded) patches.
    patch_width : int
        The width of the unpadded patches.
    """

    @staticmethod
    def _run_kernel(kernel, images, patches, positions, patch_width, nrepeat):
        """Check array shapes/dtypes and launch `kernel` on the three arrays.

        Shared by `fwd` and `adj`, which differ only in which raw kernel is
        launched and which array they return.

        The launch uses a grid of (N, nimage, patch_width) blocks, where N is
        ``positions.shape[-2]`` and nimage is the product of the leading
        (batch) dimensions of `images`. Each block is one-dimensional with
        the next power of two >= patch_width threads, capped at the kernel's
        ``max_threads_per_block`` attribute.
        """
        assert patch_width <= patches.shape[-1]
        assert images.shape[:-2] == positions.shape[:-2]
        assert positions.shape[:-2] == patches.shape[:-3], (
            positions.shape,
            patches.shape,
        )
        assert positions.shape[-2] * nrepeat == patches.shape[-3]
        assert positions.shape[-1] == 2
        assert images.dtype == 'complex64'
        assert patches.dtype == 'complex64'
        assert positions.dtype == 'float32'

        # np.prod of an empty shape tuple is the float 1.0 (this happens when
        # images is a single 2-D array). Grid dimensions and the kernel's
        # image-count argument must be integers, so convert explicitly.
        nimage = int(np.prod(images.shape[:-2]))
        npositions = positions.shape[-2]

        grids = (npositions, nimage, patch_width)
        blocks = (min(
            _next_power_two(patch_width),
            kernel.attributes['max_threads_per_block'],
        ),)
        kernel(
            grids,
            blocks,
            (
                images,
                patches,
                positions,
                nimage,
                *images.shape[-2:],
                npositions,
                nrepeat,
                patch_width,
                patches.shape[-1],
            ),
        )

    def fwd(
        self,
        images,
        positions,
        patches=None,
        patch_width=None,
        height=None,
        width=None,
        nrepeat=1,
    ):
        """Run the "fwd_patch" kernel and return `patches`.

        Parameters
        ----------
        images : (..., H, W) complex64
            Input images.
        positions : (..., N, 2) float32
            Patch positions.
        patches : (..., N * nrepeat, width+, width+) complex64, optional
            Output buffer handed to the kernel. When omitted, a zero-filled
            array whose last two dimensions equal `patch_width` is allocated.
        patch_width : int, optional
            Defaults to ``patches.shape[-1]``. Must not exceed it.
        height, width : optional
            Not used by this method.
        nrepeat : int
            Ratio between the number of patches and the number of positions.

        Returns
        -------
        patches
            The array passed in as `patches`, or the newly allocated one.

        Raises
        ------
        ValueError
            If neither `patches` nor `patch_width` is provided, because the
            patch size cannot be determined.
        AssertionError
            If the shapes or dtypes of the arrays are inconsistent.
        """
        if patch_width is None:
            if patches is None:
                raise ValueError(
                    "At least one of patches or patch_width must be provided.")
            patch_width = patches.shape[-1]
        if patches is None:
            patches = cp.zeros(
                (
                    *positions.shape[:-2],
                    positions.shape[-2] * nrepeat,
                    patch_width,
                    patch_width,
                ),
                dtype='complex64',
            )
        self._run_kernel(
            _fwd_patch,
            images,
            patches,
            positions,
            patch_width,
            nrepeat,
        )
        return patches

    def adj(
        self,
        positions,
        patches,
        images=None,
        patch_width=None,
        height=None,
        width=None,
        nrepeat=1,
    ):
        """Run the "adj_patch" kernel and return `images`.

        Parameters
        ----------
        positions : (..., N, 2) float32
            Patch positions.
        patches : (..., N * nrepeat, width+, width+) complex64
            Input patches.
        images : (..., H, W) complex64, optional
            Buffer handed to the kernel. When omitted, a zero-filled array of
            shape ``(*positions.shape[:-2], height, width)`` is allocated.
            NOTE(review): whether the kernel overwrites or accumulates into a
            caller-supplied buffer is decided in convolution.cu - confirm
            there before reusing a non-zero array.
        patch_width : int, optional
            Defaults to ``patches.shape[-1]``. Must not exceed it.
        height, width : int, optional
            Size of the last two dimensions of `images`; only read when
            `images` is omitted.
        nrepeat : int
            Ratio between the number of patches and the number of positions.

        Returns
        -------
        images
            The array passed in as `images`, or the newly allocated one.

        Raises
        ------
        ValueError
            If `images` is omitted and `height` or `width` is missing.
        AssertionError
            If the shapes or dtypes of the arrays are inconsistent.
        """
        if patch_width is None:
            patch_width = patches.shape[-1]
        # Checked before allocating so a bad width fails without touching
        # device memory.
        assert patch_width <= patches.shape[-1]
        if images is None:
            if height is None or width is None:
                raise ValueError(
                    "height and width are required when images is not "
                    "provided.")
            images = cp.zeros(
                (*positions.shape[:-2], height, width),
                dtype='complex64',
            )
        self._run_kernel(
            _adj_patch,
            images,
            patches,
            positions,
            patch_width,
            nrepeat,
        )
        return images