Source code for tike.ptycho.solvers.options

from __future__ import annotations
import abc
import dataclasses
import typing

import numpy as np
import numpy.typing as npt
import scipy.ndimage

from tike.ptycho.object import ObjectOptions
from tike.ptycho.position import PositionOptions, check_allowed_positions
from tike.ptycho.probe import ProbeOptions
from tike.ptycho.exitwave import ExitWaveOptions


@dataclasses.dataclass
class IterativeOptions(abc.ABC):
    """Base class holding the options shared by iterative algorithms.

    The `name`, `costs`, and `times` fields are excluded from the generated
    constructor; the remaining fields may be passed positionally or by
    keyword.

    .. versionadded:: 0.20.0
    """

    name: str = dataclasses.field(init=False, default='')
    """Name of the algorithm."""

    num_batch: int = 1
    """Number of groups the dataset is split into; the groups are processed
    one after another."""

    batch_method: str = 'wobbly_center'
    """Name of the batch selection method. Pick one of the cluster methods
    from the tike.cluster module."""

    costs: typing.List[typing.List[float]] = dataclasses.field(
        default_factory=list,
        init=False,
    )
    """Objective function values from previous iterations. One list is
    returned for each mini-batch."""

    num_iter: int = 1
    """Number of epochs to process before returning."""

    times: typing.List[float] = dataclasses.field(
        default_factory=list,
        init=False,
    )
    """Wall-time of each previous iteration."""

    convergence_window: int = 0
    """Number of epochs considered when monitoring convergence. Any value
    less than 2 disables the monitoring."""


@dataclasses.dataclass
class DmOptions(IterativeOptions):
    """Options for the algorithm registered under the name 'dm'."""

    name: str = dataclasses.field(init=False, default='dm')

    num_batch: int = 1
    """Number of groups the dataset is split into, where each group is
    processed simultaneously."""
@dataclasses.dataclass
class RpieOptions(IterativeOptions):
    """Options for the algorithm registered under the name 'rpie'."""

    name: str = dataclasses.field(init=False, default='rpie')

    num_batch: int = 5

    alpha: float = 0.05
    """Hyper-parameter controlling the step length. When it is set to 1,
    RPIE becomes EPIE."""
@dataclasses.dataclass
class LstsqOptions(IterativeOptions):
    """Options for the algorithm registered under the name 'lstsq_grad'."""

    name: str = dataclasses.field(init=False, default='lstsq_grad')
@dataclasses.dataclass
class PtychoParameters():
    """A class for storing the ptychography forward model parameters.

    The shapes of `scan`, `probe`, and `psi` are validated right after
    construction, and default exit wave options are created when none are
    supplied; see `__post_init__`.

    .. versionadded:: 0.22.0
    """
    probe: npt.NDArray[np.csingle]
    """(1, 1, SHARED, WIDE, HIGH) complex64 The shared illumination function
    amongst all positions."""

    psi: npt.NDArray[np.csingle]
    """(WIDE, HIGH) complex64 The wavefront modulation coefficients of
    the object."""

    scan: npt.NDArray[np.single]
    """(POSI, 2) float32 Coordinates of the minimum corner of the probe
    grid for each measurement in the coordinate system of psi. Coordinate order
    consistent with WIDE, HIGH order."""

    eigen_probe: typing.Union[npt.NDArray[np.csingle], None] = None
    """(EIGEN, SHARED, WIDE, HIGH) complex64
    The eigen probes for all positions."""

    eigen_weights: typing.Union[npt.NDArray[np.single], None] = None
    """(POSI, EIGEN, SHARED) float32
    The relative intensity of the eigen probes at each position."""

    algorithm_options: IterativeOptions = dataclasses.field(
        default_factory=RpieOptions,)
    """A class containing algorithm specific parameters"""

    exitwave_options: typing.Union[ExitWaveOptions, None] = None
    """A class containing settings related to exitwave updates."""

    probe_options: typing.Union[ProbeOptions, None] = None
    """A class containing settings related to probe updates."""

    object_options: typing.Union[ObjectOptions, None] = None
    """A class containing settings related to object updates."""

    position_options: typing.Union[PositionOptions, None] = None
    """A class containing settings related to position correction."""

    def __post_init__(self) -> None:
        """Validate the array shapes and fill in default exit wave options.

        After the shape checks pass, `check_allowed_positions` is called
        with the scan, psi, and probe shape; any error it raises propagates.
        If `exitwave_options` is None, it is replaced with an
        `ExitWaveOptions` whose `measured_pixels` is an all-True boolean
        mask the size of the probe grid.

        Raises
        ------
        ValueError
            If `scan` is not of shape (N, 2) with N >= 1; if `probe` is not
            of shape (1, 1, S, W, H) with every dimension >= 1 and W == H;
            or if `psi` is not 2D and strictly larger than the probe grid
            along both axes.
        """
        if (self.scan.ndim != 2 or self.scan.shape[1] != 2 or
                np.any(np.asarray(self.scan.shape) < 1)):
            raise ValueError(f"scan shape {self.scan.shape} is incorrect. "
                             "It should be (N, 2) "
                             "where N >= 1 is the number of scan positions.")
        if (self.probe.ndim != 5 or self.probe.shape[:2] != (1, 1) or
                np.any(np.asarray(self.probe.shape) < 1) or
                self.probe.shape[-2] != self.probe.shape[-1]):
            raise ValueError(f"probe shape {self.probe.shape} is incorrect. "
                             "It should be (1, 1, S, W, H) "
                             "where S >=1 is the number of probes, and "
                             "W, H >= 1 are the square probe grid dimensions.")
        if (self.psi.ndim != 2 or np.any(
                np.asarray(self.psi.shape) <= np.asarray(
                    self.probe.shape[-2:]))):
            raise ValueError(
                f"psi shape {self.psi.shape} is incorrect. "
                "It should be (W, H) where W, H > probe.shape[-2:].")
        check_allowed_positions(
            self.scan,
            self.psi,
            self.probe.shape,
        )
        if self.exitwave_options is None:
            # By default every pixel of the probe-sized grid counts as
            # measured.
            self.exitwave_options = ExitWaveOptions(
                measured_pixels=np.ones(self.probe.shape[-2:], dtype=np.bool_))

    def resample(
        self,
        factor: float,
        interp: None | typing.Callable[[np.ndarray, float], np.ndarray] = None,
    ) -> PtychoParameters:
        """Return a new `PtychoParameters` with the parameters rescaled.

        Parameters
        ----------
        factor
            The scale factor. The scan coordinates are multiplied by it and
            it is forwarded to every resize function and to the `resample`
            method of each options object that is not None.
        interp
            A function `interp(array, factor)` used to resize `probe` and
            `eigen_probe`; it is also forwarded to the probe and object
            options. When None, `_resize_fft` is used.

        Returns
        -------
        PtychoParameters
            A new instance. `psi` is always resized with `_resize_spline`
            regardless of `interp`. `eigen_weights` and `algorithm_options`
            are passed through unchanged (not copied).
        """
        interp = _resize_fft if interp is None else interp
        return PtychoParameters(
            probe=interp(self.probe, factor),
            psi=_resize_spline(self.psi, factor),
            scan=self.scan * factor,
            eigen_probe=interp(self.eigen_probe, factor)
            if self.eigen_probe is not None else None,
            eigen_weights=self.eigen_weights,
            algorithm_options=self.algorithm_options,
            probe_options=self.probe_options.resample(factor, interp)
            if self.probe_options is not None else None,
            object_options=self.object_options.resample(factor, interp)
            if self.object_options is not None else None,
            position_options=self.position_options.resample(factor)
            if self.position_options is not None else None,
            exitwave_options=self.exitwave_options.resample(factor)
            if self.exitwave_options is not None else None,
        )
def _resize_spline(x: np.ndarray, f: float) -> np.ndarray:
    """Rescale the last two axes of `x` by `f` using `scipy.ndimage.zoom`.

    Every leading axis is given a zoom of 1. The zoom is performed with
    `grid_mode=True` and `prefilter=False`.
    """
    num_leading_axes = x.ndim - 2
    zoom_per_axis = [1] * num_leading_axes + [f, f]
    return scipy.ndimage.zoom(
        x,
        zoom=zoom_per_axis,
        grid_mode=True,
        prefilter=False,
    )


def _resize_cv(x: np.ndarray, f: float, interpolation: int) -> np.ndarray:
    """Rescale the last two axes of `x` by `f` one 2D image at a time.

    All leading axes are flattened into a stack of 2D images, each image is
    passed to `tike.view.resize_complex_image` with `scale_factor=(f, f)`,
    and the leading axes are restored on the result.

    `interpolation` is forwarded unchanged to `resize_complex_image`.
    """
    # Imported lazily, as in the original code.
    import tike.view

    leading_shape = x.shape[:-2]
    image_stack = x.reshape(-1, *x.shape[-2:])
    resized = []
    for image in image_stack:
        resized.append(
            tike.view.resize_complex_image(
                image,
                scale_factor=(f, f),
                interpolation=interpolation,
            ))
    # The output grid size is read off the first resized image.
    new_grid = resized[0].shape[-2:]
    return np.asarray(resized).reshape(*leading_shape, *new_grid)


def _resize_linear(x: np.ndarray, f: float) -> np.ndarray:
    """Resize with `_resize_cv` using interpolation code 1."""
    # NOTE(review): presumably the linear interpolation flag of the backend
    # behind tike.view.resize_complex_image -- confirm.
    return _resize_cv(x, f, 1)


def _resize_cubic(x: np.ndarray, f: float) -> np.ndarray:
    """Resize with `_resize_cv` using interpolation code 2."""
    # NOTE(review): presumably the cubic interpolation flag -- confirm.
    return _resize_cv(x, f, 2)


def _resize_lanczos(x: np.ndarray, f: float) -> np.ndarray:
    """Resize with `_resize_cv` using interpolation code 4."""
    # NOTE(review): presumably the Lanczos interpolation flag -- confirm.
    return _resize_cv(x, f, 4)
def crop_fourier_space(x: np.ndarray, w: int) -> np.ndarray:
    """Crop x assuming a 2D frequency space image with zero frequency in corner.

    Along each of the last two axes, the leading `w - w // 2` entries and
    the trailing `w // 2` entries are kept, so the last two dimensions of
    the result have length `w`. The last axis is cropped first, then the
    second-to-last.
    """
    assert x.shape[-2] == x.shape[-1], "Only works on square arrays right now."
    num_trailing = w // 2
    num_leading = w - num_trailing
    # yapf: disable
    last_len = x.shape[-1]
    keep_last = np.r_[0:num_leading, (last_len - num_trailing):last_len]
    cropped_last = x[..., keep_last]
    other_len = x.shape[-2]
    keep_other = np.r_[0:num_leading, (other_len - num_trailing):other_len]
    return cropped_last[..., keep_other, :]
# yapf: enable


def pad_fourier_space(x: np.ndarray, w: int) -> np.ndarray:
    """Pad x assuming a 2D frequency space image with zero frequency in corner.

    A zero array whose last two dimensions are `(w, w)` is created with the
    dtype of `x`. With `n = x.shape[-1]`, the leading `n - n // 2` rows and
    columns of `x` are copied to the leading rows and columns of the output
    and the trailing `n // 2` rows and columns of `x` are copied to the
    trailing rows and columns of the output. Everything in between stays
    zero.

    Parameters
    ----------
    x
        An array whose last two dimensions are square.
    w
        The width of the last two dimensions of the output.

    Raises
    ------
    AssertionError
        If the last two dimensions of `x` are not equal.
    """
    assert x.shape[-2] == x.shape[-1], "Only works on square arrays right now."
    n = x.shape[-1]
    half1 = n // 2
    half0 = n - half1
    new_x = np.zeros_like(x, shape=(*x.shape[:-2], w, w))
    cols = np.r_[0:half0, (w - half1):w]
    new_x[..., 0:half0, cols] = x[..., 0:half0, :]
    # Explicit non-negative bounds are required here: when half1 is 0 (a 1x1
    # input) the slice `-half1:` is `0:` and selects EVERY row instead of
    # none, which would broadcast the single value down the whole column.
    new_x[..., (w - half1):w, cols] = x[..., (n - half1):n, :]
    return new_x


def _resize_fft(x: np.ndarray, f: float) -> np.ndarray:
    """Use Fourier interpolation to resize/resample the last 2 dimensions of x

    The input is returned unchanged (same object) when `f == 1`. Otherwise
    the 2D FFT of the last two axes is cropped (`f < 1`) or zero padded
    (`f > 1`) to a width of `int(x.shape[-1] * f)` and transformed back.
    Both transforms use `norm='ortho'`.
    """
    if f == 1:
        return x
    crop_or_pad = crop_fourier_space if f < 1 else pad_fourier_space
    spectrum = np.fft.fft2(x, norm='ortho', axes=(-2, -1))
    resized_spectrum = crop_or_pad(spectrum, w=int(x.shape[-1] * f))
    return np.fft.ifft2(resized_spectrum, norm='ortho', axes=(-2, -1))