import typing
import cupy as cp
import numpy.typing as npt
import tike.communicators
import tike.operators
import tike.precision
from .options import ObjectOptions, ProbeOptions, PtychoParameters
def _rolling_average_object(parameters: PtychoParameters, new):
    """Store ``new`` as the object preconditioner and return ``parameters``.

    Despite its name, this function currently performs no averaging: any
    previous ``parameters.object_options.preconditioner`` is simply replaced
    by ``new``. (A 0.5-weighted running average existed previously but has
    been disabled.) ``parameters`` is modified in place and also returned so
    the function can be used with ``comm.pool.map``.
    """
    object_options = parameters.object_options
    object_options.preconditioner = new
    return parameters
def _rolling_average_probe(parameters: PtychoParameters, new):
    """Store ``new`` as the probe preconditioner and return ``parameters``.

    Despite its name, this function currently performs no averaging: any
    previous ``parameters.probe_options.preconditioner`` is simply replaced
    by ``new``. (A 0.5-weighted running average existed previously but has
    been disabled.) ``parameters`` is modified in place and also returned so
    the function can be used with ``comm.pool.map``.
    """
    probe_options = parameters.probe_options
    probe_options.preconditioner = new
    return parameters
@cp.fuse()
def _probe_amp_sum(probe):
    """Return ``probe * conj(probe)`` summed over axis -3 (fused kernel).

    The product keeps the input dtype, so for a complex ``probe`` the result
    is complex with a zero imaginary part rather than a real array.
    """
    squared_magnitude = probe * cp.conj(probe)
    return cp.sum(squared_magnitude, axis=-3)
def _psi_preconditioner(
    parameters: PtychoParameters,
    streams: typing.List[cp.cuda.Stream],
    *,
    operator: tike.operators.Ptycho,
) -> npt.NDArray:
    """Build the object (psi) preconditioner, one plane per entry of ``psi``.

    Plane 0 receives ``_probe_amp_sum(parameters.probe)[:, 0]`` scattered at
    the scan positions with ``operator.diffraction.patch.adj``.

    For each later plane ``i``, the wave starts from ``parameters.probe[:, 0]``.
    It is passed cumulatively through ``diffraction.fwd`` with ``psi[i - 1]``
    and then ``propagation.fwd``. The ``_probe_amp_sum`` of that wave is then
    scattered into plane ``i``.

    The scan positions are processed in chunks by
    ``tike.communicators.stream.stream_and_modify2`` on ``streams``.

    Returns
    -------
    An array with the shape and dtype of ``parameters.psi``.
    """
    denominator = cp.zeros(
        shape=parameters.psi.shape,
        dtype=parameters.psi.dtype,
    )

    def accumulate_chunk(
        ind_args,
        lo: int,
        hi: int,
    ) -> None:
        # Only item assignment is done on ``denominator`` here, so the name
        # is never rebound and no ``nonlocal`` declaration is needed.
        diff_op = operator.diffraction
        positions = parameters.scan[lo:hi]

        denominator[0] = diff_op.patch.adj(
            patches=_probe_amp_sum(parameters.probe)[:, 0],
            images=denominator[0],
            positions=positions,
        )

        # March the wave through the slices; plane ``i`` sees the wave after
        # it has interacted with slice ``i - 1`` and been propagated.
        wave = parameters.probe[:, 0]
        for i, upstream_slice in enumerate(parameters.psi[:-1], start=1):
            wave = diff_op.diffraction.fwd(
                probe=wave,
                scan=positions,
                psi=upstream_slice,
            )
            wave = diff_op.propagation.fwd(wave)
            denominator[i] = diff_op.patch.adj(
                patches=_probe_amp_sum(wave),
                images=denominator[i],
                positions=positions,
            )

    tike.communicators.stream.stream_and_modify2(
        f=accumulate_chunk,
        ind_args=[],
        streams=streams,
        lo=0,
        hi=len(parameters.scan),
    )
    return denominator
@cp.fuse()
def _patch_amp_sum(patches):
    """Return ``patches * conj(patches)`` summed over axis 0 (fused kernel).

    The product keeps the input dtype, so for complex ``patches`` the result
    is complex with a zero imaginary part rather than a real array.
    """
    squared_magnitude = patches * cp.conj(patches)
    return cp.sum(squared_magnitude, axis=0, keepdims=False)
def _probe_preconditioner(
    parameters: PtychoParameters,
    streams: typing.List[cp.cuda.Stream],
    *,
    operator: tike.operators.Ptycho,
) -> npt.NDArray:
    """Build the probe preconditioner, one plane per entry of ``psi``.

    For every plane ``i`` of ``parameters.psi``, the function:

    1. Extracts patches of width ``parameters.probe.shape[-1]`` at the scan
       positions with ``operator.diffraction.patch.fwd``.
    2. Adds ``patches * conj(patches)``, summed over the leading axis, into
       plane ``i`` of the result.

    The scan positions are processed in chunks by
    ``tike.communicators.stream.stream_and_modify2`` on ``streams``.

    Returns
    -------
    An array of shape ``(psi.shape[0], *probe.shape[-2:])`` with the probe's
    dtype.
    """
    denominator = cp.zeros(
        shape=(parameters.psi.shape[0], *parameters.probe.shape[-2:]),
        dtype=parameters.probe.dtype,
    )

    def accumulate_chunk(
        ind_args,
        lo: int,
        hi: int,
    ) -> None:
        # In-place item updates only; ``denominator`` is never rebound.
        positions = parameters.scan[lo:hi]
        width = parameters.probe.shape[-1]
        for i, psi_plane in enumerate(parameters.psi):
            patches = operator.diffraction.patch.fwd(
                images=psi_plane,
                positions=positions,
                patch_width=width,
            )
            denominator[i] += _patch_amp_sum(patches)
        assert denominator.ndim == 3

    tike.communicators.stream.stream_and_modify2(
        f=accumulate_chunk,
        ind_args=[],
        streams=streams,
        lo=0,
        hi=len(parameters.scan),
    )
    return denominator
def update_preconditioners(
    comm: tike.communicators.Comm,
    parameters: typing.List[PtychoParameters],
    operator: tike.operators.Ptycho,
) -> typing.List[PtychoParameters]:
    """Update the probe and object preconditioners.

    Parameters
    ----------
    comm
        Communicator whose ``pool`` maps the helper functions over
        ``parameters`` together with ``comm.pool.streams``.
    parameters
        A list of ``PtychoParameters``, presumably one per worker of
        ``comm.pool``. Which preconditioners are rebuilt is decided by
        ``parameters[0].object_options`` and
        ``parameters[0].probe_options``, so the list must be non-empty.
    operator
        Ptychography operator used to build the preconditioners.

    Returns
    -------
    parameters
        The result of ``comm.pool.map``. The newly computed object and/or
        probe preconditioner is stored in each entry's ``object_options`` /
        ``probe_options``.
    """
    if parameters[0].object_options:
        preconditioner = comm.pool.map(
            _psi_preconditioner,
            parameters,
            comm.pool.streams,
            operator=operator,
        )
        # NOTE(review): the cross-worker Allreduce of the preconditioner is
        # disabled, so each entry keeps the preconditioner built from its own
        # ``scan`` only; confirm this is intended:
        # preconditioner = comm.Allreduce(preconditioner)
        parameters = comm.pool.map(
            _rolling_average_object,
            parameters,
            preconditioner,
        )

    if parameters[0].probe_options:
        preconditioner = comm.pool.map(
            _probe_preconditioner,
            parameters,
            comm.pool.streams,
            operator=operator,
        )
        # NOTE(review): Allreduce disabled here as well; see the object
        # branch above for the same caveat.
        # preconditioner = comm.Allreduce(preconditioner)
        parameters = comm.pool.map(
            _rolling_average_probe,
            parameters,
            preconditioner,
        )

    return parameters