"""Functions related to creating and manipulating ptychographic exitwave arrays.
Ptychographic exitwaves are stored as a single complex array which represents
the wavefield after any and all interaction with the sample; from there, only
free-space propagation to the detector remains.
"""
from __future__ import annotations
import copy
import dataclasses
import logging
import cupy as cp
import numpy as np
import numpy.typing as npt
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ExitWaveOptions:
    """Manage data and settings related to exitwave updates.

    .. versionadded:: 0.25.0

    """

    measured_pixels: npt.NDArray[np.bool_]
    """
    A boolean array that defines spatial regions on the detector where we have
    measured pixels. False for bad pixels and True for good pixels.
    """

    noise_model: str = "gaussian"
    """`'gaussian'` OR `'poisson'` The noise model for the exitwave updates"""

    step_length_weight: float = 0.5
    """
    When computing steplength, we use a weighted average of the previous
    calculated step and the current calculated step. Valid values are
    0.0 <= step_length_weight <= 1.0; values closer to 1.0 favor the current
    calculated step.
    """

    step_length_usemodes: str = "all_modes"
    """
    When computing steplength for exitwave updates, there are two ways we do
    this:

    "dominant_mode" - use the dominant mode to compute the steplength and
    use that steplength for the other less dominant modes

    "all_modes" - compute the steplength for each mode independently
    """

    step_length_start: float = 0.5
    """
    We use an iterative/recursive method for finding the steplengths, and this
    is what we use as initialization for that method.
    """

    unmeasured_pixels_scaling: float = 1.00
    """
    Depending on how we control scaling for the exitwaves, we might need to
    scale up or down the number of photons in the unmeasured regions for the
    exitwave updates in Fourier space. `1.0` for no scaling.
    """

    propagation_normalization: str = 'ortho'
    """Choose the scaling of the FFT in the forward model:

    "ortho" - the forward and adjoint operations are divided by sqrt(n)

    "forward" - the forward operation is divided by n

    "backward" - the backward operation is divided by n
    """

    def copy_to_device(self, comm) -> ExitWaveOptions:
        """Copy to the current GPU memory.

        Returns a shallow copy of these options. When `measured_pixels` is not
        None it is replaced by the result of
        ``comm.pool.bcast([self.measured_pixels])``; every other field is
        shared with the original.
        """
        options = copy.copy(self)
        if self.measured_pixels is not None:
            # NOTE(review): presumably bcast returns one array per device;
            # copy_to_host reads element 0 of this value - confirm.
            options.measured_pixels = comm.pool.bcast([self.measured_pixels])
        return options

    def copy_to_host(self) -> ExitWaveOptions:
        """Copy to the host CPU memory.

        Returns a shallow copy of these options. When `measured_pixels` is not
        None, its first element is converted with ``cupy.asnumpy``.
        """
        options = copy.copy(self)
        if self.measured_pixels is not None:
            options.measured_pixels = cp.asnumpy(self.measured_pixels[0])
        return options

    def resample(self, factor: float) -> ExitWaveOptions:
        """Return a new `ExitWaveOptions` with the parameters rescaled.

        Only `measured_pixels` depends on the sampling: it is cropped with
        `crop_fourier_space` to a width of
        ``int(measured_pixels.shape[-1] * factor)``. All other fields are
        carried over unchanged.

        If `measured_pixels` is None there is nothing to crop, so a plain copy
        is returned; this mirrors the None handling in `copy_to_device` and
        `copy_to_host`.
        """
        if self.measured_pixels is None:
            return dataclasses.replace(self)
        new_width = int(self.measured_pixels.shape[-1] * factor)
        # dataclasses.replace forwards every field, so options added to this
        # class later are not silently dropped by resampling.
        return dataclasses.replace(
            self,
            measured_pixels=crop_fourier_space(
                self.measured_pixels,
                new_width,
            ),
        )
def poisson_steplength_all_modes(
    xi,
    abs2_Psi,
    I_e,
    I_m,
    measured_pixels,
    step_length,
    weight_avg,
):
    """Compute the optimal steplength for each exitwave mode independently.

    The step length is refined with two update passes; each pass blends the
    previous value with a newly computed candidate using `weight_avg`.

    Parameters
    ----------
    xi : (FRAME, 1, 1, WIDE, HIGH) float32
        xi = 1 - I_m / I_e
    abs2_Psi : (FRAME, 1, SHARED, WIDE, HIGH) float32
        The squared absolute value of the calculated exitwaves.
    I_e : (FRAME, WIDE, HIGH) float32
        Calculated diffraction intensity.
    I_m : (FRAME, WIDE, HIGH) float32
        Measured diffraction intensity.
    measured_pixels : (WIDE, HIGH)
        Mask selecting the detector pixels that have defined measurements;
        only these pixels enter the sums. It is used to index the last two
        axes, so a boolean mask is expected.
    step_length : (FRAME, 1, SHARED, 1, 1) float32
        The steplength initializations.
    weight_avg : float
        The weight given to the newly computed candidate in the weighted
        average, with 0.0 <= weight_avg <= 1.0.

    Returns
    -------
    step_length : (FRAME, 1, SHARED, 1, 1) float32
        The updated step lengths.
    """
    # Add two singleton axes after FRAME so the intensities broadcast against
    # the five-dimensional exitwave arrays.
    estimated = I_e[:, None, None, ...]
    measured = I_m[:, None, None, ...]

    weighted_psi = xi * abs2_Psi

    # This normalizer does not depend on step_length; compute it once.
    normalizer = cp.sum(
        (xi * weighted_psi)[..., measured_pixels],
        axis=-1,
    )

    keep_previous = 1 - weight_avg
    for _ in range(2):
        shifted = xi * step_length - 1
        denom = abs2_Psi * cp.square(shifted) + estimated - abs2_Psi
        per_pixel = weighted_psi * (1 + (measured * shifted) / denom)
        numer = cp.sum(per_pixel[..., measured_pixels], axis=-1)
        # Restore the two trailing singleton axes dropped by the masked sum.
        candidate = (numer / normalizer)[..., None, None]
        step_length = step_length * keep_previous + candidate * weight_avg

    return step_length
def poisson_steplength_dominant_mode(
    xi,
    I_e,
    I_m,
    measured_pixels,
    step_length,
    weight_avg,
):
    """Compute the optimal steplength for each exitwave mode using only the
    dominant mode.

    The step length is refined with two update passes; each pass blends the
    previous value with a newly computed candidate using `weight_avg`. The
    returned step lengths are not clamped to [0, 1].

    Parameters
    ----------
    xi : (FRAME, 1, 1, WIDE, HIGH) float32
        xi = 1 - I_m / I_e
    I_e : (FRAME, WIDE, HIGH) float32
        Calculated diffraction intensity.
    I_m : (FRAME, WIDE, HIGH) float32
        Measured diffraction intensity.
    measured_pixels : (WIDE, HIGH)
        Mask selecting the detector pixels that have defined measurements;
        only these pixels enter the sums. It is used to index the last two
        axes, so a boolean mask is expected.
    step_length : (FRAME, 1, SHARED, 1, 1) float32
        The steplength initializations.
    weight_avg : float
        The weight given to the newly computed candidate in the weighted
        average, with 0.0 <= weight_avg <= 1.0.

    Returns
    -------
    step_length : (FRAME, 1, SHARED, 1, 1) float32
        The updated step lengths.
    """
    # Add two singleton axes after FRAME so the intensities broadcast against
    # xi and step_length.
    estimated = I_e[:, None, None, ...]
    measured = I_m[:, None, None, ...]

    # This normalizer does not depend on step_length; compute it once.
    normalizer = cp.sum(
        (cp.square(xi) * estimated)[..., measured_pixels],
        axis=-1,
    )

    keep_previous = 1 - weight_avg
    for _ in range(2):
        per_pixel = xi * (estimated - measured / (1 - step_length * xi))
        ratio = cp.sum(per_pixel[..., measured_pixels], axis=-1) / normalizer
        # Restore the two trailing singleton axes dropped by the masked sum.
        step_length = (keep_previous * step_length +
                       weight_avg * ratio[..., None, None])

    return step_length
def crop_fourier_space(x: np.ndarray, w: int) -> np.ndarray:
    """Crop x assuming a 2D frequency space image with zero frequency in corner.

    The last two axes of `x` are each reduced to `w` entries by keeping the
    first ``w - w // 2`` and the last ``w // 2`` indices along that axis.
    Leading axes are left untouched.

    Parameters
    ----------
    x : (..., N, N) array
        The array to crop; its last two axes must have equal length.
    w : int
        The number of entries to keep along each of the last two axes.

    Returns
    -------
    cropped : (..., w, w) array
        The selected corner regions of `x`.
    """
    rows, cols = x.shape[-2], x.shape[-1]
    assert rows == cols, "Only works on square arrays right now."

    tail = w // 2
    head = w - tail

    # The array is square, so one index vector serves both trailing axes.
    keep = np.concatenate((
        np.arange(0, head),
        np.arange(cols - tail, cols),
    ))
    return x[..., keep][..., keep, :]