Source code for tike.align.solvers.cross_correlation

# Copyright (C) 2019, the scikit-image team
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
# 3. Neither the name of skimage nor the names of its contributors may be used
#    to endorse or promote products derived from this software without specific
#    prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
# EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Implements a cross_correlation 2D alignment algorithm based on
phase_cross_correlation from skimage.registration."""

import numpy as np


[docs]def cross_correlation( op, original, unaligned, upsample_factor=1, space="real", num_iter=None, reg_weight=1e-9, ): """Efficient subpixel image translation alignment by cross-correlation. This code gives the same precision as the FFT upsampled cross-correlation in a fraction of the computation time and with reduced memory requirements. It obtains an initial estimate of the cross-correlation peak by an FFT and then refines the shift estimation by upsampling the DFT only in a small neighborhood of that estimate by means of a matrix-multiply DFT. References ---------- Stéfan van der Walt, Johannes L. Schönberger, Juan Nunez-Iglesias, François Boulogne, Joshua D. Warner, Neil Yager, Emmanuelle Gouillart, Tony Yu and the scikit-image contributors. scikit-image: Image processing in Python. PeerJ 2:e453 (2014) :doi:`10.7717/peerj.453` Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel image registration algorithms," Optics Letters 33, 156-158 (2008). :doi:`10.1364/OL.33.000156` James R. Fienup, "Invariant error metrics for image reconstruction" Optics Letters 36, 8352-8357 (1997). :doi:`10.1364/AO.36.008352` """ # assume complex data is already in Fourier space if space.lower() == 'fourier': src_freq = unaligned target_freq = original # real data needs to be fft'd. elif space.lower() == 'real': src_freq = op.xp.fft.fft2(unaligned) target_freq = op.xp.fft.fft2(original) else: raise ValueError(f"space must be 'fourier' or 'real' not '{space}'.") # Whole-pixel shift - Compute cross-correlation by an IFFT shape = src_freq.shape image_product = src_freq * target_freq.conj() cross_correlation = op.xp.fft.ifft2(image_product) # Add a small regularization term so that smaller shifts are preferred when # the cross_correlation is the same for multiple shifts. if reg_weight > 0: w = _area_overlap(op, cross_correlation) w = op.xp.fft.fftshift(w) * reg_weight else: w = 0 A = np.abs(cross_correlation) + w maxima = A.reshape(A.shape[0], -1).argmax(1) maxima = np.column_stack(np.unravel_index(maxima, A[0, :, :].shape)) shifts = op.xp.array(maxima, dtype='float32') midpoints = [x // 2 for x in shape[1:]] ids = np.where(shifts[:, 0] > midpoints[0]) shifts[ids[0], 0] -= shape[1] ids = np.where(shifts[:, 1] > midpoints[1]) shifts[ids[0], 1] -= shape[2] if upsample_factor > 1: # Initial shift estimate in upsampled grid shifts = np.round(shifts * upsample_factor) / upsample_factor upsampled_region_size = np.ceil(upsample_factor * 1.5) # Center of output array at dftshift + 1 dftshift = np.fix(upsampled_region_size / 2.0) normalization = (src_freq[0].size * upsample_factor**2) # Matrix multiply DFT around the current shift estimate sample_region_offset = dftshift - shifts * upsample_factor cross_correlation = _upsampled_dft( op, image_product.conj(), upsampled_region_size, upsample_factor, sample_region_offset, ).conj() cross_correlation /= normalization # Locate maximum and map back to original pixel grid A = np.abs(cross_correlation) maxima = A.reshape(A.shape[0], -1).argmax(1) maxima = np.column_stack(np.unravel_index(maxima, A[0, :, :].shape)) maxima = maxima - dftshift shifts = shifts + maxima / upsample_factor return {'shift': shifts.astype('float32'), 'cost': -1}
def _upsampled_dft(op, data, ups, upsample_factor, axis_offsets): im2pi = -2j * np.pi shape = data.shape kernel = ((op.xp.arange(ups) - axis_offsets[:, 1:2])[:, :, None] * op.xp.fft.fftfreq(shape[2], upsample_factor)) kernel = np.exp(im2pi * kernel) data = np.einsum('ijk,ipk->ijp', kernel, data) kernel = ((op.xp.arange(ups) - axis_offsets[:, 0:1])[:, :, None] * op.xp.fft.fftfreq(shape[1], upsample_factor)) kernel = np.exp(im2pi * kernel) return np.einsum('ijk,ipk->ijp', kernel, data) def _triangle(op, N): """Return N samples from the triangle function.""" x = op.xp.linspace(0, 1, N, endpoint=False) + 0.5 / N return 1 - abs(x - 0.5) def _area_overlap(op, A): """Return overlapping area of A with itself. Create overlap arrays for higher dimensions using matrix multiplication. >>> _area_overlap(np.empty(4)) array([0.625, 0.875, 0.875, 0.625]) >>> _area_overlap(np.empty((3, 5))) array([[0.4 , 0.53333333, 0.66666667, 0.53333333, 0.4 ], [0.6 , 0.8 , 1. , 0.8 , 0.6 ], [0.4 , 0.53333333, 0.66666667, 0.53333333, 0.4 ]]) """ for dim, shape in enumerate(A.shape[-2:]): if dim == 0: w = _triangle(op, shape) else: w = w[..., np.newaxis] @ _triangle(op, shape)[np.newaxis, ...] return w