# Source code for tike.ptycho.ptycho

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# #########################################################################
# Copyright (c) 2018, UChicago Argonne, LLC. All rights reserved.    #
#                                                                         #
# Copyright 2018. UChicago Argonne, LLC. This software was produced       #
# under U.S. Government contract DE-AC02-06CH11357 for Argonne National   #
# Laboratory (ANL), which is operated by UChicago Argonne, LLC for the    #
# U.S. Department of Energy. The U.S. Government has rights to use,       #
# reproduce, and distribute this software.  NEITHER THE GOVERNMENT NOR    #
# UChicago Argonne, LLC MAKES ANY WARRANTY, EXPRESS OR IMPLIED, OR        #
# ASSUMES ANY LIABILITY FOR THE USE OF THIS SOFTWARE.  If software is     #
# modified to produce derivative works, such modified software should     #
# be clearly marked, so as not to confuse it with the version available   #
# from ANL.                                                               #
#                                                                         #
# Additionally, redistribution and use in source and binary forms, with   #
# or without modification, are permitted provided that the following      #
# conditions are met:                                                     #
#                                                                         #
#     * Redistributions of source code must retain the above copyright    #
#       notice, this list of conditions and the following disclaimer.     #
#                                                                         #
#     * Redistributions in binary form must reproduce the above copyright #
#       notice, this list of conditions and the following disclaimer in   #
#       the documentation and/or other materials provided with the        #
#       distribution.                                                     #
#                                                                         #
#     * Neither the name of UChicago Argonne, LLC, Argonne National       #
#       Laboratory, ANL, the U.S. Government, nor the names of its        #
#       contributors may be used to endorse or promote products derived   #
#       from this software without specific prior written permission.     #
#                                                                         #
# THIS SOFTWARE IS PROVIDED BY UChicago Argonne, LLC AND CONTRIBUTORS     #
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT       #
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS       #
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL UChicago     #
# Argonne, LLC OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,        #
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,    #
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;        #
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER        #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT      #
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN       #
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE         #
# POSSIBILITY OF SUCH DAMAGE.                                             #
# #########################################################################

__author__ = "Doga Gursoy, Daniel Ching, Xiaodong Yu"
__copyright__ = "Copyright (c) 2018, UChicago Argonne, LLC."
__docformat__ = 'restructuredtext en'
__all__ = [
    "reconstruct",
    "simulate",
]

from itertools import product, chain
import logging
import time

import numpy as np
import cupy as cp

from tike.operators import Ptycho
from tike.communicators import Comm, MPIComm
from tike.opt import batch_indicies
from tike.ptycho import solvers
from .position import check_allowed_positions, get_padded_object
from .probe import get_varying_probe

logger = logging.getLogger(__name__)


def _compute_intensity(
    operator,
    psi,
    scan,
    probe,
    eigen_weights=None,
    eigen_probe=None,
    fly=1,
):
    """Accumulate the detector intensity over every shared probe mode.

    For each index along ``probe.shape[-3]`` the probe returned by
    `get_varying_probe` is propagated with ``operator.fwd``. The squared
    magnitude of the result is reshaped so that the `fly` positions which
    belong to one frame sit on their own axis, and that axis is summed.

    Parameters
    ----------
    operator
        Object providing ``fwd(probe=, scan=, psi=)`` and a
        ``detector_shape`` attribute.
    psi, scan, probe
        Arrays forwarded to ``operator.fwd``; only their shapes are
        inspected here.
    eigen_weights, eigen_probe
        Optional arrays forwarded to `get_varying_probe`.
    fly : int
        The number of scan positions which combine for one detector frame.

    Returns
    -------
    intensity
        The sum over modes of the per-frame intensity; the integer ``0``
        when ``probe.shape[-3]`` is zero.
    """
    leading = psi.shape[:-2]
    intensity = 0
    for mode in range(probe.shape[-3]):
        mode_probe = get_varying_probe(
            probe,
            eigen_probe,
            eigen_weights,
            m=mode,
        )
        farplane = operator.fwd(probe=mode_probe, scan=scan, psi=psi)
        # Group the positions of each fly scan on axis -3 before summing.
        frame_shape = (
            *leading,
            scan.shape[-2] // fly,
            fly,
            operator.detector_shape,
            operator.detector_shape,
        )
        per_position = np.square(np.abs(farplane)).reshape(frame_shape)
        intensity += np.sum(per_position, axis=-3, keepdims=False)
    return intensity


def simulate(
        detector_shape,
        probe, scan,
        psi,
        fly=1,
        eigen_probe=None,
        eigen_weights=None,
        **kwargs
):  # yapf: disable
    """Return real-valued detector counts of simulated ptychography data.

    Parameters
    ----------
    detector_shape : int
        The pixel width of the detector.
    probe : (..., 1, 1, SHARED, WIDE, HIGH) complex64
        The shared complex illumination function amongst all positions.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi.
    psi : (..., WIDE, HIGH) complex64
        The complex wavefront modulation of the object.
    fly : int
        The number of scan positions which combine for one detector frame.
    eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64
        The eigen probes for all positions.
    eigen_weights : (..., POSI, EIGEN, SHARED) float32
        The relative intensity of the eigen probes at each position.
    kwargs
        Additional keyword arguments forwarded to the `Ptycho` operator
        constructor.

    Returns
    -------
    data : (..., FRAME, WIDE, HIGH) float32
        The simulated intensity on the detector.

    """
    check_allowed_positions(scan, psi, probe)
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=int(detector_shape),
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            **kwargs,
    ) as operator:
        scan = operator.asarray(scan, dtype='float32')
        psi = operator.asarray(psi, dtype='complex64')
        probe = operator.asarray(probe, dtype='complex64')
        if eigen_weights is not None:
            eigen_weights = operator.asarray(eigen_weights, dtype='float32')
        if eigen_probe is not None:
            # Convert like every other array input so the eigen probes are
            # the same array type as the weights they are combined with.
            eigen_probe = operator.asarray(eigen_probe, dtype='complex64')
        data = _compute_intensity(
            operator,
            psi,
            scan,
            probe,
            eigen_weights,
            eigen_probe,
            fly,
        )
        return operator.asnumpy(data.real)
def reconstruct(
        data,
        probe, scan,
        algorithm,
        psi=None, num_gpu=1, num_iter=1, rtol=-1,
        model='gaussian', use_mpi=False, cost=None, times=None,
        eigen_probe=None, eigen_weights=None,
        batch_size=None,
        **kwargs
):  # yapf: disable
    """Solve the ptychography problem using the given `algorithm`.

    Parameters
    ----------
    data : (..., FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records.
    eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64
        The eigen probes for all positions.
    eigen_weights : (..., POSI, EIGEN, SHARED) float32
        The relative intensity of the eigen probes at each position.
    psi : (..., WIDE, HIGH) complex64
        The wavefront modulation coefficients of the object. When None, an
        object and adjusted scan are obtained from `get_padded_object`.
    probe : (..., 1, 1, SHARED, WIDE, HIGH) complex64
        The shared complex illumination function amongst all positions.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Coordinate order
        consistent with WIDE, HIGH order.
    algorithm : string
        The name of one algorithms from :py:mod:`.ptycho.solvers`.
    num_gpu : int
        Passed to `Comm` as its first argument.
    num_iter : int
        The number of times the solver is called (epochs).
    rtol : float
        Terminate early if the relative decrease of the cost function is
        less than this amount.
    model : string
        Passed to the `Ptycho` operator as ``model``.
    use_mpi : bool
        When True, `MPIComm` is passed to `Comm`; otherwise None is passed.
    cost, times
        Not read by this function; the returned ``'cost'`` and ``'times'``
        entries are always rebuilt from this call's iterations.
    batch_size : int
        The approximate number of scan positions processed by each GPU
        simultaneously per view.
    kwargs
        Extra keyword arguments forwarded to the solver. Values with one or
        more dimensions are broadcast with ``comm.pool.bcast`` first.

    Returns
    -------
    result : dict
        The output of the last solver call with every value passed through
        ``operator.asnumpy``. ``'cost'`` holds the non-None costs reported
        by the solver and ``'times'`` holds the wall time of each epoch.

    Raises
    ------
    ValueError
        If `algorithm` is not listed in ``solvers.__all__``.

    """
    (psi, scan) = get_padded_object(scan, probe) if psi is None else (psi, scan)
    check_allowed_positions(scan, psi, probe)
    if use_mpi is True:
        mpi = MPIComm
    else:
        mpi = None
    if algorithm not in solvers.__all__:
        raise ValueError(f"The '{algorithm}' algorithm is not an option.\n"
                         f"\tAvailable algorithms are : {solvers.__all__}")
    # Initialize an operator.
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=data.shape[-1],
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            model=model,
    ) as operator, Comm(num_gpu, mpi) as comm:
        logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                    "iterations.".format(algorithm, *data.shape[-3:],
                                         num_iter))
        num_batch = 1 if batch_size is None else max(
            1,
            int(data.shape[-3] / batch_size / comm.pool.num_workers),
        )
        # Divide the inputs into regions
        odd_pool = comm.pool.num_workers % 2
        order, scan, data, eigen_weights = split_by_scan_grid(
            comm.pool,
            (
                comm.pool.num_workers
                if odd_pool else comm.pool.num_workers // 2,
                1 if odd_pool else 2,
            ),
            scan,
            data,
            eigen_weights,
        )
        result = {
            'psi':
                comm.pool.bcast(psi.astype('complex64')),
            'probe':
                comm.pool.bcast(probe.astype('complex64')),
            'eigen_probe':
                comm.pool.bcast(eigen_probe.astype('complex64'))
                if eigen_probe is not None else None,
            'scan':
                scan,
            'eigen_weights':
                eigen_weights,
        }
        # Only values are replaced, so iterating over kwargs here is safe.
        for key, value in kwargs.items():
            if np.ndim(value) > 0:
                kwargs[key] = comm.pool.bcast(value)

        result['probe'] = comm.pool.bcast(
            _rescale_obj_probe(
                operator,
                comm,
                data[0],
                result['psi'][0],
                scan[0],
                result['probe'][0],
                num_batch=num_batch,
            ))

        costs = []
        times = []
        start = time.perf_counter()
        for i in range(num_iter):

            logger.info(f"{algorithm} epoch {i:,d}")

            kwargs.update(result)
            result = getattr(solvers, algorithm)(
                operator,
                comm,
                data=data,
                num_batch=num_batch,
                **kwargs,
            )
            if result['cost'] is not None:
                costs.append(result['cost'])

            times.append(time.perf_counter() - start)
            start = time.perf_counter()

            # Check for early termination. A solver may report no cost, so
            # require two recorded costs instead of relying on the epoch
            # index, which would raise IndexError on a short list.
            if (len(costs) > 1 and
                    abs((costs[-1] - costs[-2]) / costs[-2]) < rtol):
                logger.info(
                    "Cost function rtol < %g reached at %d "
                    "iterations.", rtol, i)
                break

        # Undo the spatial split so positions return in their input order.
        # NOTE(review): the lists gathered below are the ones produced by
        # split_by_scan_grid, not result['scan'] / result['eigen_weights'];
        # solver updates reach the output only if made in place - confirm.
        reorder = np.argsort(np.concatenate(order))
        result['scan'] = comm.pool.gather(scan, axis=1)[:, reorder]
        # The split yields a list of None when no eigen weights were given;
        # there is nothing to gather or unwrap in that case.
        has_eigen_weights = all(w is not None for w in eigen_weights)
        if 'eigen_weights' in result and has_eigen_weights:
            result['eigen_weights'] = comm.pool.gather(
                eigen_weights,
                axis=1,
            )[:, reorder]
            if result.get('eigen_probe') is not None:
                result['eigen_probe'] = result['eigen_probe'][0]
        result['probe'] = result['probe'][0]
        result['cost'] = operator.asarray(costs)
        result['times'] = operator.asarray(times)
        # Remaining per-worker lists are reduced to their first entry.
        for k, v in result.items():
            if isinstance(v, list):
                result[k] = v[0]
        return {k: operator.asnumpy(v) for k, v in result.items()}
def _rescale_obj_probe(operator, comm, data, psi, scan, probe, num_batch):
    """Keep the object amplitude around 1 by scaling probe by a constant.

    The first batch of indices returned by `batch_indicies` is used to
    compare the norm of the square root of the measured `data` against the
    norm of the square root of the intensity returned by
    ``operator._compute_intensity``. `probe` is multiplied in place by that
    ratio and returned. `comm` is accepted but not used here.
    """
    i = batch_indicies(data.shape[-3], num_batch, use_random=True)[0]

    intensity, _ = operator._compute_intensity(
        data[..., i, :, :],
        psi,
        scan[..., i, :],
        probe,
    )

    rescale = (np.linalg.norm(np.ravel(np.sqrt(data[..., i, :, :]))) /
               np.linalg.norm(np.ravel(np.sqrt(intensity))))

    logger.info("object and probe rescaled by %f", rescale)

    probe *= rescale

    return probe


def split_by_scan_grid(pool, shape, scan, *args, fly=1):
    """Split the field of view into a 2D grid.

    Divide the data, using boolean masks, into a 2D grid of spatially
    contiguous regions.

    Parameters
    ----------
    pool
        Object whose ``map(func, masks, x=array)`` applies `func` to each
        mask; the results are collected into a list.
    shape : tuple of int
        The number of grid divisions along each dimension.
    scan : (ntheta, nscan, 2) float32
        The 2D coordinates of the scan positions.
    args : (ntheta, nscan, ...) float32
        The arrays to be split by scan position. A None entry yields a
        list of None.
    fly : int
        The number of scan positions per frame.

    Returns
    -------
    order : List[array[int]]
        The locations of the inputs in the original arrays.
    scan : List[array[float32]]
        The divided 2D coordinates of the scan positions.
    args : List[array[float32]]
        Each input divided into regions.

    Raises
    ------
    ValueError
        If `shape` does not have exactly two entries.
    """
    if len(shape) != 2:
        raise ValueError('The grid shape must have two dimensions.')
    vstripes = split_by_scan_stripes(scan, shape[0], axis=0, fly=fly)
    hstripes = split_by_scan_stripes(scan, shape[1], axis=1, fly=fly)
    # Each grid cell is the intersection of one stripe from each axis.
    mask = [np.logical_and(*pair) for pair in product(vstripes, hstripes)]

    order = np.arange(scan.shape[1])
    order = [order[m] for m in mask]

    def split(m, x):
        return None if x is None else cp.asarray(x[:, m], dtype='float32')

    split_args = [list(pool.map(split, mask, x=arg)) for arg in [scan, *args]]

    return (order, *split_args)


def split_by_scan_stripes(scan, n, fly=1, axis=0):
    """Return `n` boolean masks that split the field of view into stripes.

    Divide the data, using boolean masks, into spatially contiguous regions
    along the position axis.

    Split scan into three stripes:

    >>> [scan[:, s] for s in split_by_scan_stripes(scan, 3)]

    FIXME: Only uses the first view to divide the positions. Assumes the
    positions on all angles are distributed similarly.

    Parameters
    ----------
    scan : (ntheta, nscan, 2) float32
        The 2D coordinates of the scan positions.
    n : int
        The number of stripes.
    fly : int
        The number of scan positions per frame.
    axis : int (0 or 1)
        Which spatial dimension to divide along. i.e. horizontal or
        vertical.

    Returns
    -------
    mask : list of (nscan, ) boolean
        A list of boolean arrays which divide the scan positions into `n`
        stripes.

    Raises
    ------
    ValueError
        If `scan` is not three dimensional, if `n` or `fly` is less than
        one, or if the number of positions is not a multiple of `fly`.
    """
    if scan.ndim != 3:
        raise ValueError('scan must have three dimensions.')
    if n < 1:
        raise ValueError('The number of stripes must be > 0.')
    if fly < 1:
        # Reject this here; otherwise fly=0 fails below with an unrelated
        # ZeroDivisionError.
        raise ValueError('The number of fly positions must be > 0.')

    ntheta, nscan, _ = scan.shape
    if (nscan // fly) * fly != nscan:
        raise ValueError('The number of scan positions must be an '
                         'integer multiple of the number of fly positions.')

    # Reshape scan so positions in the same fly scan are not separated
    scan = scan.reshape(ntheta, nscan // fly, fly, 2)

    # Determine the edges of the stripes along the chosen axis
    edges = np.linspace(
        scan[..., axis].min(),
        scan[..., axis].max(),
        n + 1,
        endpoint=True,
    )

    # Move the outer edges to include all points
    edges[0] -= 1
    edges[-1] += 1

    # Generate masks which put points into stripes; the first position of
    # each fly scan in the first view decides the stripe for the whole scan.
    return [
        np.logical_and(
            edges[i] < scan[0, :, 0, axis],
            scan[0, :, 0, axis] <= edges[i + 1],
        ).repeat(fly) for i in range(n)
    ]