# Source code for tike.operators.cupy.bucket

__author__ = "Daniel Ching, Xiaodong Yu"
__copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC."

try:
    from importlib.resources import files
except ImportError:
    # Backport for python<3.9 available as importlib_resources package
    from importlib_resources import files
from itertools import product
import logging

import cupy as cp
import numpy as np
import tike.precision

from .lamino import Lamino

# Name expressions (C++ template instantiations) requested from bucket.cu.
# Each entry must match, character for character, the string that fwd/adj
# later build with an f-string and pass to `_bucket_module.get_function`.
kernels = [
    'coordinates_and_weights<float,float3>',
    'coordinates_and_weights<double,double3>',
    'fwd<float2>',
    'adj<float2>',
    'fwd<double2>',
    'adj<double2>',
]

# CuPy module built from the text of bucket.cu, which is shipped inside the
# tike.operators.cupy package; kernels are retrieved from it by name.
_bucket_module = cp.RawModule(
    code=files('tike.operators.cupy').joinpath('bucket.cu').read_text(),
    name_expressions=kernels,
    options=('--std=c++11',),
)

# Maps the NumPy dtype of an array to the type name used when building the
# template name expressions above (e.g. complex64 -> 'fwd<float2>').
typename = {
    np.dtype('complex64'): 'float2',
    np.dtype('float32'): 'float',
    np.dtype('complex128'): 'double2',
    np.dtype('float64'): 'double',
}

logger = logging.getLogger(__name__)


class Bucket(Lamino):
    """A Laminography operator.

    Laminography operators to simulate propagation of the beam through the
    object for a defined tilt angle. An object rotates around its own vertical
    axis, nz, and the beam illuminates the object some tilt angle off this
    axis.

    Attributes
    ----------
    n : int
        The pixel width of the cubic reconstructed grid.
    tilt : float32
        The tilt angle; the angle between the rotation axis of the object and
        the light source. π / 2 for conventional tomography. 0 for a beam path
        along the rotation axis.
    precision : int16
        The number of samples taken along each edge of a grid cell; each grid
        cell is projected precision**3 times.
    weight : float64
        The weight of each sample; 1 / precision**3.

    Parameters
    ----------
    u : (nz, n, n) complex64
        The complex refractive index of the object. nz is the axis
        corresponding to the rotation axis.
    data : (ntheta, n, n) complex64
        The complex projection data of the object.
    theta : array-like float32
        The projection angles; rotation around the vertical axis of the object.
    """

    def __init__(self, n, tilt, eps=1, **kwargs):
        """Please see help(Lamino) for more info.

        Parameters
        ----------
        n : int
            The pixel width of the cubic reconstructed grid.
        tilt : float
            The tilt angle; stored as float32.
        eps : float
            The largest acceptable sample weight. The precision is increased
            until 1 / precision**3 <= eps. Must be positive.
        **kwargs
            Accepted for signature compatibility; not used here.

        Raises
        ------
        ValueError
            If eps is not positive. The search below would otherwise not
            terminate because 1 / precision**3 is always larger than eps.
        """
        # NOTE(review): Lamino.__init__ is not called here; confirm that is
        # intentional.
        if eps <= 0:
            raise ValueError(
                f"eps must be positive, but got {eps}; "
                "no finite precision yields weights this small.")
        self.n = n
        self.tilt = np.single(tilt)
        # Increase precision until weights are no larger than eps
        precision = 1
        while (1 / precision**3) > eps:
            precision += 1
        logger.info("Bucket operator using %d precision to reach %f eps.",
                    precision, eps)
        # NOTE(review): precision**3 is also used as a CUDA block size and is
        # computed in int16 below, so only small precisions are usable.
        self.precision = np.int16(precision)
        self.weight = np.double(1.0 / self.precision**3)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        pass

    def _check_inputs(self, grid):
        """Assert that the kernel arguments have the dtypes the kernels use."""
        assert grid.dtype == 'int16'
        assert self.tilt.dtype == np.single
        assert self.precision.dtype == 'int16'
        assert self.weight.dtype == np.double

    def _get_coordinates_kernel(self, theta):
        """Return the coordinates kernel matching the dtype of theta."""
        return _bucket_module.get_function(
            f'coordinates_and_weights<{typename[theta.dtype]},{typename[theta.dtype]}3>'
        )

    def _get_plane_index(self, kernel, grid, theta, t, plane_coords):
        """Project grid at angle theta[t]; return indices into the data plane.

        The kernel writes into plane_coords, which is reused between angles.
        """
        assert plane_coords.dtype == 'int16'
        kernel(
            (grid.shape[0],),
            (self.precision, self.precision, self.precision),
            (
                grid,
                grid.shape[0],
                self.tilt,
                theta,
                t,
                self.precision,
                plane_coords,
            ),
        )
        # Shift zero-centered coordinates to array indices; wrap negative
        # indices around
        plane_index = (plane_coords + self.n // 2) % self.n
        assert plane_index.dtype == 'int16'
        return plane_index

    def _get_grid_index(self, grid):
        """Return indices into u for each zero-centered grid coordinate.

        The index along the first axis is the offset from the smallest first
        coordinate present in grid, so it spans 0 .. (max - min) inclusive.
        The previous expression, (g + |min|) % (max - min), used a modulus one
        smaller than the number of planes, so the last plane aliased onto
        index 0; it also shifted the wrong way whenever min was positive.
        """
        first = grid[:, :1]
        grid_index = cp.concatenate(
            [
                first - first.min(),
                (grid[:, 1:] + self.n // 2) % self.n,
            ],
            axis=-1,
        )
        assert grid_index.dtype == 'int16'
        return grid_index

    def fwd(self, u: cp.ndarray, theta: cp.ndarray, grid: cp.ndarray,
            **kwargs):
        """Perform forward laminography operation.

        Parameters
        ----------
        u : (nz, n, n) complex64
            The complex refractive index of the object. nz is the axis
            corresponding to the rotation axis.
        theta : array-like float32
            The projection angles; rotation around the vertical axis of the
            object.
        grid : (N, 3) int16
            The zero-centered coordinates of the voxels of u.

        Return
        ------
        data : (ntheta, n, n) complex64
            The complex projection data of the object.

        """
        data = cp.zeros_like(u, shape=(len(theta), self.n, self.n))
        plane_coords = cp.zeros(
            (len(grid), self.precision**3, 2),
            dtype='int16',
        )
        _coords_weights_kernel = self._get_coordinates_kernel(theta)
        _bucket_fwd = _bucket_module.get_function(f'fwd<{typename[u.dtype]}>')
        self._check_inputs(grid)
        assert data.dtype == u.dtype
        # grid_index depends only on grid, so it is computed once for all
        # angles; assumes the kernels do not modify grid.
        grid_index = self._get_grid_index(grid)
        for t in range(len(theta)):
            plane_index = self._get_plane_index(
                _coords_weights_kernel,
                grid,
                theta,
                t,
                plane_coords,
            )
            _bucket_fwd(
                (grid.shape[0],),
                (self.precision**3,),
                (
                    data,
                    t,
                    data.shape[1],
                    data.shape[2],
                    self.weight,
                    u,
                    u.shape[0],
                    u.shape[1],
                    u.shape[2],
                    plane_index,
                    grid_index,
                    grid_index.shape[0],
                    self.precision,
                ),
            )
        return data

    def adj(self, data: cp.ndarray, theta: cp.ndarray, grid: cp.ndarray,
            **kwargs):
        """Perform adjoint laminography operation.

        Parameters
        ----------
        data : (ntheta, n, n) complex64
            The complex projection data of the object.
        theta : array-like float32
            The projection angles; rotation around the vertical axis of the
            object.
        grid : (N, 3) int16
            The zero-centered coordinates of the voxels of u. The number of
            planes in u is len(grid) // n**2.

        Return
        ------
        u : (nz, n, n) complex64
            The complex refractive index of the object. nz is the axis
            corresponding to the rotation axis.

        """
        u = cp.zeros_like(
            data,
            shape=(len(grid) // (self.n**2), self.n, self.n),
        )
        plane_coords = cp.zeros(
            (len(grid), self.precision**3, 2),
            dtype='int16',
        )
        _coords_weights_kernel = self._get_coordinates_kernel(theta)
        _bucket_adj = _bucket_module.get_function(
            f'adj<{typename[data.dtype]}>')
        self._check_inputs(grid)
        assert data.dtype == u.dtype
        # grid_index depends only on grid, so it is computed once for all
        # angles; assumes the kernels do not modify grid.
        grid_index = self._get_grid_index(grid)
        for t in range(len(theta)):
            plane_index = self._get_plane_index(
                _coords_weights_kernel,
                grid,
                theta,
                t,
                plane_coords,
            )
            _bucket_adj(
                (grid.shape[0],),
                (self.precision**3,),
                (
                    data,
                    t,
                    data.shape[1],
                    data.shape[2],
                    self.weight,
                    u,
                    u.shape[0],
                    u.shape[1],
                    u.shape[2],
                    plane_index,
                    grid_index,
                    grid_index.shape[0],
                    self.precision,
                ),
            )
        return u

    def cost(self, data, fwd_data):
        """Cost function for the least-squares laminography problem"""
        return self.xp.linalg.norm((fwd_data - data).ravel())**2

    def grad(self, data, theta, fwd_data, grid):
        """Gradient for the least-squares laminography problem"""
        out = self.adj(
            data=(fwd_data - data),
            theta=theta,
            grid=grid,
        )
        # NOTE: The division is kept as a separate in-place statement on
        # purpose; joining it with the expression above reportedly promotes
        # the dtype of the result.
        out /= (data.shape[-3] * self.n**3)
        return out

    def _make_grid(self, size=1, rank=0):
        """Return integer coordinates in the grid; origin centered.

        The full grid is a NumPy array of shape (n, n, n, 3). It is split into
        `size` pieces along the first axis and piece number `rank` is
        returned. fwd and adj index grid as a 2-D int16 CuPy array, so the
        caller is expected to reshape and convert the result before use.
        """
        lo, hi = -self.n // 2, self.n // 2
        grid = np.stack(
            np.mgrid[lo:hi, lo:hi, lo:hi],
            axis=-1,
        )
        return np.array_split(grid, size)[rank]
def _get_coordinates_and_weights(
    grid,
    tilt,
    theta,
    precision=1,
    plane_coords=None,
):
    """Return coordinates and weights of grid points projected onto plane.

    Assumes a grid with cells of unit size onto a plane with cells on unit
    size. For an arbitrary point, the coordinate of the cell that contains the
    point is computed by floor(point). i.e. the coordinates of a grid width 4
    would be [-2, -1, 0, 1]. i.e. the zeroth grid cell spans from [0, 1).

    Transformations defining the plane are rotations about the origin.

    Parameters
    ----------
    grid (N, 3) array int
        Coordinates of voxels on the grid. Coordinates are origin center.
    theta : float32
        The projection angle; rotation around the vertical axis of the object.
        NOTE(review): previously documented as a (T,) array, but
        _compute_transformation stores cos(theta) and sin(theta) into single
        matrix elements, so one angle per call is what the code supports.
    tilt : float32
        The tilt angle; the angle between the rotation axis of the object and
        the light source. π / 2 for conventional tomography. 0 for a beam path
        along the rotation axis.
    precision : int
        The desired precision of the operator. Scales the number of grid
        samples by precision^3.
    plane_coords : (N, precision**3, 2) array int, optional
        Output buffer to fill; a new array is allocated when omitted.

    Returns
    -------
    plane_coords(N, precision**3, 2): int
        The coordinates on the planes into which each grid point is projected.
    """
    N = len(grid)
    if plane_coords is None:
        plane_coords = cp.zeros((N, precision**3, 2), dtype='int')
    # The helpers are module-private; the previous unprefixed names did not
    # exist and raised NameError.
    transformation = _compute_transformation(tilt, theta)
    normal = transformation @ cp.array(
        (1, 0, 0),
        dtype=tike.precision.floating,
    )
    # Sample positions inside a unit cell are the same for every cell, so
    # build them once. product() varies k fastest, so the position of each
    # offset in this list equals k + precision * (j + precision * i).
    offsets = [
        cp.array([
            (i + 0.5) / precision,
            (j + 0.5) / precision,
            (k + 0.5) / precision,
        ]) for i, j, k in product(range(precision), repeat=3)
    ]
    for cell in range(N):
        for chunk, offset in enumerate(offsets):
            plane_coords[cell, chunk] = _project_point_to_plane(
                grid[cell] + offset,
                normal,
                transformation,
            )
    return plane_coords


def _compute_transformation(tilt, theta):
    """Return a transformation which aligns [1, 0, 0] with the plane normal.

    Parameters
    ----------
    tilt : float
        The tilt angle in radians.
    theta : float
        A single projection angle in radians.

    Returns
    -------
    transformation : (3, 3) array
        A matrix of dtype tike.precision.floating.
    """
    cos_tilt, sin_tilt = cp.cos(tilt), cp.sin(tilt)
    cos_theta, sin_theta = cp.cos(theta), cp.sin(theta)
    transformation = cp.zeros((3, 3), dtype=tike.precision.floating)
    transformation[0, 0] = cos_tilt
    transformation[0, 1] = sin_tilt
    transformation[1, 0] = -cos_theta * sin_tilt
    transformation[1, 1] = cos_theta * cos_tilt
    transformation[1, 2] = -sin_theta
    transformation[2, 0] = -sin_theta * sin_tilt
    transformation[2, 1] = sin_theta * cos_tilt
    transformation[2, 2] = cos_theta
    return transformation


def _project_point_to_plane(point, normal, transformation):
    """Return the integer coordinate of the point projected to the plane.

    The component of point along normal is removed, the remainder is mapped
    through the transpose of transformation, and the floor of the last two
    components is returned.
    """
    distance = cp.sum(point * normal)
    projected = point - distance * normal
    return cp.floor(transformation.T @ projected)[1:]