# Source code for tike.ptycho.solvers.lstsq

import logging
import typing

import cupy as cp
import cupyx.scipy.stats
import numpy as np
import numpy.typing as npt

import tike.communicators
import tike.linalg
import tike.operators
import tike.opt
import tike.random
import tike.ptycho.position
import tike.ptycho.probe
import tike.ptycho.object
import tike.ptycho.exitwave
import tike.precision

from .options import *

logger = logging.getLogger(__name__)


def lstsq_grad(
    op: tike.operators.Ptycho,
    comm: tike.communicators.Comm,
    data: typing.List[npt.NDArray],
    batches: typing.List[npt.NDArray[cp.intc]],
    *,
    parameters: PtychoParameters,
    epoch: int,
):
    """Solve the ptychography problem using Odstrcil et al's approach.

    Object and probe are updated simultaneously using optimal step sizes
    computed using a least squares approach.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        A ptychography operator.
    comm : :py:class:`tike.communicators.Comm`
        An object which manages communications between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        A list of unique CuPy arrays for each device containing
        the intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records. FFT-shifted so the
        diffraction peak is at the corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        A list of list of indices along the FRAME axis of `data` for
        each device which define the batches of `data` to process
        simultaneously.
    parameters : :py:class:`tike.ptycho.solvers.PtychoParameters`
        An object which contains reconstruction parameters.

    Returns
    -------
    result : dict
        A dictionary containing the updated keyword-only arguments passed
        to this function.

    References
    ----------
    Michal Odstrcil, Andreas Menzel, and Manuel Guizar-Sicaros. Iterative
    least-squares solver for generalized maximum-likelihood ptychography.
    Optics Express. 2018.

    .. seealso:: :py:mod:`tike.ptycho`
    """
    probe = parameters.probe
    scan = parameters.scan
    psi = parameters.psi
    algorithm_options = parameters.algorithm_options
    probe_options = parameters.probe_options
    if probe_options is None:
        recover_probe = False
    else:
        recover_probe = probe_options.recover_probe
    position_options = parameters.position_options
    object_options = parameters.object_options
    exitwave_options = parameters.exitwave_options
    eigen_probe = parameters.eigen_probe
    eigen_weights = parameters.eigen_weights

    # Per-worker position-update accumulators; allocated lazily on each GPU
    # inside _get_nearplane_gradients on the first batch.
    position_update_numerator = [None] * comm.pool.num_workers
    position_update_denominator = [None] * comm.pool.num_workers

    # Use placeholder lists when eigen probes/weights are disabled so that
    # comm.pool.map always receives one element per worker.
    if eigen_probe is None:
        beigen_probe = [None] * comm.pool.num_workers
    else:
        beigen_probe = eigen_probe
    if eigen_weights is None:
        beigen_weights = [None] * comm.pool.num_workers
    else:
        beigen_weights = eigen_weights

    if object_options is not None:
        if algorithm_options.batch_method == 'compact':
            object_options.combined_update = cp.zeros_like(psi[0])

    if recover_probe:
        probe_options.probe_update_sum = cp.zeros_like(probe[0])

    # 'compact' batches are visited in order; otherwise batch order is
    # shuffled each epoch.
    if parameters.algorithm_options.batch_method == 'compact':
        order = range
    else:
        order = tike.random.randomizer_np.permutation

    batch_cost = []
    beta_object = []
    beta_probe = []
    for batch_index in order(algorithm_options.num_batch):

        (
            diff,
            unique_probe,
            probe_update,
            object_upd_sum,
            m_probe_update,
            costs,
            patches,
            position_update_numerator,
            position_update_denominator,
            position_options,
        ) = (list(a) for a in zip(*comm.pool.map(
            _get_nearplane_gradients,
            data,
            psi,
            scan,
            probe,
            beigen_probe,
            beigen_weights,
            batches,
            position_update_numerator,
            position_update_denominator,
            [None] * comm.pool.num_workers
            if position_options is None else position_options,
            comm.streams,
            exitwave_options.measured_pixels,
            object_options.preconditioner,
            batch_index=batch_index,
            num_batch=algorithm_options.num_batch,
            exitwave_options=exitwave_options,
            op=op,
            recover_psi=object_options is not None,
            recover_probe=recover_probe,
            recover_positions=position_options is not None,
        )))

        # Workers return position_options unchanged; collapse the list of
        # Nones back into a single None when position recovery is disabled.
        position_options = None if position_options[
            0] is None else position_options

        if object_options is not None:
            object_upd_sum = comm.Allreduce(object_upd_sum)

        if recover_probe:
            m_probe_update = comm.pool.bcast([comm.Allreduce_mean(
                m_probe_update,
                axis=-5,
            )])

            (
                beigen_probe,
                beigen_weights,
            ) = _update_nearplane(
                comm,
                diff,
                probe_update,
                m_probe_update,
                probe,
                beigen_probe,
                beigen_weights,
                patches,
                batches,
                batch_index=batch_index,
                num_batch=algorithm_options.num_batch,
            )

        (
            object_update_precond,
            A1,
            A2,
            A4,
            b1,
            b2,
        ) = (list(a) for a in zip(*comm.pool.map(
            _precondition_nearplane_gradients,
            diff,
            scan,
            unique_probe,
            probe,
            object_upd_sum,
            m_probe_update,
            object_options.preconditioner,
            patches,
            batches,
            batch_index=batch_index,
            op=op,
            m=0,
            recover_psi=object_options is not None,
            recover_probe=recover_probe,
            probe_options=probe_options,
        )))

        if object_options is not None:
            A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)])
        else:
            A1_delta = [None] * comm.pool.num_workers

        if recover_probe:
            A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)])
        else:
            A4_delta = [None] * comm.pool.num_workers

        (
            weighted_step_psi,
            weighted_step_probe,
        ) = (list(a) for a in zip(*comm.pool.map(
            _get_nearplane_steps,
            A1,
            A2,
            A4,
            b1,
            b2,
            A1_delta,
            A4_delta,
            recover_psi=object_options is not None,
            recover_probe=recover_probe,
            m=0,
        )))

        if object_options is not None:
            bbeta_object = comm.Allreduce_mean(
                weighted_step_psi,
                axis=-5,
            )[..., 0, 0, 0]

        if recover_probe:
            bbeta_probe = comm.Allreduce_mean(
                weighted_step_probe,
                axis=-5,
            )

        # Update each direction
        if object_options is not None:
            if algorithm_options.batch_method != 'compact':
                # (27b) Object update
                dpsi = bbeta_object[0] * object_update_precond[0]
                if object_options.use_adaptive_moment:
                    (
                        dpsi,
                        object_options.v,
                        object_options.m,
                    ) = tike.opt.momentum(
                        g=dpsi,
                        v=object_options.v,
                        m=object_options.m,
                        vdecay=object_options.vdecay,
                        mdecay=object_options.mdecay,
                    )
                psi[0] = psi[0] + dpsi
                psi = comm.pool.bcast([psi[0]])
            else:
                # 'compact' defers the object update until all batches finish.
                object_options.combined_update += object_upd_sum[0]

        if recover_probe:
            dprobe = bbeta_probe[0] * m_probe_update[0]
            probe_options.probe_update_sum += dprobe / algorithm_options.num_batch
            # (27a) Probe update
            probe[0] += dprobe
            probe = comm.pool.bcast([probe[0]])

        for c in costs:
            batch_cost = batch_cost + c.tolist()

        if object_options is not None:
            beta_object.append(bbeta_object)

        if recover_probe:
            beta_probe.append(bbeta_probe)

    if eigen_probe is not None:
        eigen_probe = beigen_probe

    if eigen_weights is not None:
        eigen_weights = beigen_weights

    if position_options:
        scan, position_options = zip(*comm.pool.map(
            _update_position,
            scan,
            position_options,
            position_update_numerator,
            position_update_denominator,
            epoch=epoch,
        ))

    algorithm_options.costs.append(batch_cost)

    if object_options and algorithm_options.batch_method == 'compact':
        object_update_precond = _precondition_object_update(
            object_options.combined_update,
            object_options.preconditioner[0],
        )

        # (27b) Object update
        beta_object = cp.mean(cp.stack(beta_object))
        dpsi = beta_object * object_update_precond
        psi[0] = psi[0] + dpsi

        if object_options.use_adaptive_moment:
            (
                dpsi,
                object_options.v,
                object_options.m,
            ) = _momentum_checked(
                g=dpsi,
                v=object_options.v,
                m=object_options.m,
                mdecay=object_options.mdecay,
                errors=list(np.mean(x) for x in algorithm_options.costs[-3:]),
                beta=beta_object,
                memory_length=3,
            )
            # Weight the momentum update by the relative illumination so
            # poorly-illuminated regions move less.
            weight = object_options.preconditioner[0]
            weight = weight / (0.1 * weight.max() + weight)
            psi[0] = psi[0] + weight * dpsi

        psi = comm.pool.bcast([psi[0]])

    if recover_probe:
        if probe_options.use_adaptive_moment:
            beta_probe = cp.mean(cp.stack(beta_probe))
            dprobe = probe_options.probe_update_sum
            if probe_options.v is None:
                probe_options.v = np.zeros_like(
                    dprobe,
                    shape=(3, *dprobe.shape),
                )
            if probe_options.m is None:
                probe_options.m = np.zeros_like(dprobe,)
            # ptychoshelves only applies momentum to the main probe
            mode = 0
            (
                d,
                probe_options.v[..., mode, :, :],
                probe_options.m[..., mode, :, :],
            ) = _momentum_checked(
                g=dprobe[..., mode, :, :],
                v=probe_options.v[..., mode, :, :],
                m=probe_options.m[..., mode, :, :],
                mdecay=probe_options.mdecay,
                errors=list(np.mean(x) for x in algorithm_options.costs[-3:]),
                beta=beta_probe,
                memory_length=3,
            )
            probe[0][..., mode, :, :] = probe[0][..., mode, :, :] + d

        probe = comm.pool.bcast([probe[0]])

    parameters.probe = probe
    parameters.psi = psi
    parameters.scan = scan
    parameters.algorithm_options = algorithm_options
    parameters.probe_options = probe_options
    parameters.object_options = object_options
    parameters.position_options = position_options
    parameters.eigen_weights = eigen_weights
    parameters.eigen_probe = eigen_probe
    return parameters
def _update_nearplane(
    comm: tike.communicators.Comm,
    diff,
    probe_update,
    m_probe_update,
    probe: typing.List[npt.NDArray[cp.csingle]],
    eigen_probe: typing.List[npt.NDArray[cp.csingle]],
    eigen_weights: typing.List[npt.NDArray[cp.single]],
    patches,
    batches,
    *,
    batch_index: int,
    num_batch: int,
):
    """Update eigen probes and their weights from this batch's residuals."""
    m = 0

    if eigen_weights[0] is not None:

        eigen_weights = comm.pool.map(
            _get_coefs_intensity,
            eigen_weights,
            diff,
            probe,
            patches,
            batches,
            batch_index=batch_index,
            m=m,
        )

        # (30) residual probe updates
        if eigen_weights[0].shape[-2] > 1:
            R = comm.pool.map(
                _get_residuals,
                probe_update,
                m_probe_update,
                m=m,
            )

        if eigen_probe[0] is not None and m < eigen_probe[0].shape[-3]:
            assert eigen_weights[0].shape[-2] == eigen_probe[0].shape[-4] + 1
            for eigen_index in range(1, eigen_probe[0].shape[-4] + 1):

                (
                    eigen_probe,
                    eigen_weights,
                ) = tike.ptycho.probe.update_eigen_probe(
                    comm,
                    R,
                    eigen_probe,
                    eigen_weights,
                    patches,
                    diff,
                    batches,
                    batch_index=batch_index,
                    β=min(0.1, 1.0 / num_batch),
                    c=eigen_index,
                    m=m,
                )

                if eigen_index + 1 < eigen_weights[0].shape[-2]:
                    # Subtract projection of R onto new probe from R
                    R = comm.pool.map(
                        _update_residuals,
                        R,
                        eigen_probe,
                        batches,
                        batch_index=batch_index,
                        axis=(-2, -1),
                        c=eigen_index - 1,
                        m=m,
                    )

    return (
        eigen_probe,
        eigen_weights,
    )


def _get_nearplane_gradients(
    data: npt.NDArray,
    psi: npt.NDArray[cp.csingle],
    scan: npt.NDArray[cp.single],
    probe: npt.NDArray[cp.csingle],
    eigen_probe,
    eigen_weights,
    batches,
    position_update_numerator,
    position_update_denominator,
    position_options: PositionOptions,
    streams: typing.List[cp.cuda.Stream],
    measured_pixels: npt.NDArray,
    object_preconditioner: npt.NDArray[cp.csingle],
    *,
    batch_index: int,
    num_batch: int,
    op: tike.operators.Ptycho,
    recover_psi: bool,
    recover_probe: bool,
    recover_positions: bool,
    exitwave_options: ExitWaveOptions,
):
    """Compute nearplane gradients for one batch on one worker.

    Returns the exit-wave residuals, per-position probes, probe/object
    gradient accumulators, per-pattern costs, object patches, and
    position-update accumulators.
    """
    batch_start = batches[batch_index][0]
    batch_size = len(batches[batch_index])

    # These variables are only as large as the batch
    bcosts = cp.empty(shape=batch_size, dtype=tike.precision.floating)
    bchi = cp.empty_like(
        probe,
        shape=(batch_size, 1, *probe.shape[-3:]),
    )
    bpatches = cp.empty_like(
        probe,
        shape=(batch_size, 1, 1, *probe.shape[-2:]),
    )
    bprobe_update = cp.empty_like(
        probe,
        shape=bchi.shape,
    )
    bunique_probe = cp.empty_like(
        probe,
        shape=(batch_size, 1, *probe.shape[-3:]),
    )

    # These variables are as large as the entire dataset
    m_probe_update = cp.zeros_like(probe)
    object_upd_sum = cp.zeros_like(psi)
    position_update_numerator = cp.empty_like(
        scan
    ) if position_update_numerator is None else position_update_numerator
    position_update_denominator = cp.empty_like(
        scan
    ) if position_update_denominator is None else position_update_denominator

    def keep_some_args_constant(
        ind_args,
        lo: int,
        hi: int,
    ) -> None:
        # Process one stream-chunk [lo, hi) of the batch; writes results
        # into the enclosing buffers via nonlocal.
        (data,) = ind_args
        nonlocal bchi, bunique_probe, bprobe_update, object_upd_sum
        nonlocal m_probe_update, bcosts, bpatches, position_update_numerator
        nonlocal position_update_denominator

        blo = lo - batch_start
        bhi = hi - batch_start

        bunique_probe[blo:bhi] = tike.ptycho.probe.get_varying_probe(
            probe,
            eigen_probe,
            eigen_weights[lo:hi] if eigen_weights is not None else None,
        )

        farplane = op.fwd(probe=bunique_probe[blo:bhi],
                          scan=scan[lo:hi],
                          psi=psi)
        intensity = cp.sum(
            cp.square(cp.abs(farplane)),
            axis=list(range(1, farplane.ndim - 2)),
        )
        # Dispatch to the cost function matching the configured noise model.
        bcosts[blo:bhi] = getattr(
            tike.operators, f'{exitwave_options.noise_model}_each_pattern')(
                data[:, measured_pixels][:, None, :],
                intensity[:, measured_pixels][:, None, :],
            )

        if exitwave_options.noise_model == 'poisson':

            xi = (1 - data / (intensity + 1e-9))[:, None, None, ...]
            grad_cost = farplane * xi

            step_length = cp.full(
                shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
                fill_value=exitwave_options.step_length_start,
            )

            if exitwave_options.step_length_usemodes == 'dominant_mode':

                step_length = tike.ptycho.exitwave.poisson_steplength_dominant_mode(
                    xi,
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )

            else:

                step_length = tike.ptycho.exitwave.poisson_steplength_all_modes(
                    xi,
                    cp.square(cp.abs(farplane)),
                    intensity,
                    data,
                    measured_pixels,
                    step_length,
                    exitwave_options.step_length_weight,
                )

            farplane[..., measured_pixels] = (-step_length *
                                              grad_cost)[..., measured_pixels]

        else:

            farplane[..., measured_pixels] = -getattr(
                tike.operators, f'{exitwave_options.noise_model}_grad')(
                    data,
                    farplane,
                    intensity,
                )[..., measured_pixels]

        unmeasured_pixels = cp.logical_not(measured_pixels)
        farplane[..., unmeasured_pixels] *= (
            exitwave_options.unmeasured_pixels_scaling - 1.0)

        farplane = op.propagation.adj(farplane, overwrite=True)

        pad, end = op.diffraction.pad, op.diffraction.end
        bchi[blo:bhi] = farplane[..., pad:end, pad:end]

        # Get update directions for each scan positions
        if recover_psi:
            # (24b)
            object_update_proj = cp.conj(bunique_probe[blo:bhi]) * bchi[blo:bhi]
            # (25b) Common object gradient.
            object_upd_sum = op.diffraction.patch.adj(
                patches=object_update_proj.reshape(
                    len(scan[lo:hi]) * bchi.shape[-3], *bchi.shape[-2:]),
                images=object_upd_sum,
                positions=scan[lo:hi],
                nrepeat=bchi.shape[-3],
            )
        else:
            object_upd_sum = None

        if recover_probe:
            bpatches[blo:bhi] = op.diffraction.patch.fwd(
                patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]),
                images=psi,
                positions=scan[lo:hi],
            )[..., None, None, :, :]
            # (24a)
            bprobe_update[blo:bhi] = cp.conj(bpatches[blo:bhi]) * bchi[blo:bhi]
            # (25a) Common probe gradient. Use simple average instead of
            # division as described in publication because that's what
            # ptychoshelves does
            m_probe_update += cp.sum(
                bprobe_update[blo:bhi],
                axis=-5,
                keepdims=True,
            )
        else:
            bprobe_update = None
            m_probe_update = None
            bpatches = None

        if position_options:
            m = 0
            # TODO: Try adjusting gradient sigma property
            grad_x, grad_y = tike.ptycho.position.gaussian_gradient(
                bpatches[blo:bhi])

            # start section to compute position certainty metric
            crop = probe.shape[-1] // 4
            total_illumination = op.diffraction.patch.fwd(
                images=object_preconditioner,
                positions=scan[lo:hi],
                patch_width=probe.shape[-1],
            )[:, crop:-crop, crop:-crop].real
            power = cp.abs(probe[0, 0, 0, crop:-crop, crop:-crop])**2
            dX = cp.mean(
                cp.abs(grad_x[:, 0, 0, crop:-crop, crop:-crop]).real *
                total_illumination * power,
                axis=(-2, -1),
                keepdims=False,
            )
            dY = cp.mean(
                cp.abs(grad_y[:, 0, 0, crop:-crop, crop:-crop]).real *
                total_illumination * power,
                axis=(-2, -1),
                keepdims=False,
            )
            total_variation = cp.sqrt(cp.stack(
                [dX, dY],
                axis=1,
            ))
            mean_variation = (cp.mean(
                total_variation**4,
                axis=0,
            ) + 1e-6)
            position_options.confidence[
                lo:hi] = total_variation**4 / mean_variation
            # end section to compute position certainty metric

            position_update_numerator[lo:hi, ..., 0] = cp.sum(
                cp.real(
                    cp.conj(grad_x[..., crop:-crop, crop:-crop] *
                            bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop,
                                          crop:-crop]) *
                    bchi[blo:bhi, ..., m:m + 1, crop:-crop, crop:-crop]),
                axis=(-4, -3, -2, -1),
            )
            position_update_denominator[lo:hi, ..., 0] = cp.sum(
                cp.abs(grad_x[..., crop:-crop, crop:-crop] *
                       bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop,
                                     crop:-crop])**2,
                axis=(-4, -3, -2, -1),
            )
            position_update_numerator[lo:hi, ..., 1] = cp.sum(
                cp.real(
                    cp.conj(grad_y[..., crop:-crop, crop:-crop] *
                            bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop,
                                          crop:-crop]) *
                    bchi[blo:bhi, ..., m:m + 1, crop:-crop, crop:-crop]),
                axis=(-4, -3, -2, -1),
            )
            position_update_denominator[lo:hi, ..., 1] = cp.sum(
                cp.abs(grad_y[..., crop:-crop, crop:-crop] *
                       bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop,
                                     crop:-crop])**2,
                axis=(-4, -3, -2, -1),
            )

    tike.communicators.stream.stream_and_modify2(
        f=keep_some_args_constant,
        ind_args=[
            data,
        ],
        streams=streams,
        lo=batch_start,
        hi=batch_start + batch_size,
    )

    return (
        bchi,
        bunique_probe,
        bprobe_update,
        object_upd_sum,
        m_probe_update / num_batch if m_probe_update is not None else None,
        bcosts,
        bpatches,
        position_update_numerator,
        position_update_denominator,
        position_options,
    )


def _precondition_object_update(
    object_upd_sum: npt.NDArray[cp.csingle],
    psi_update_denominator: npt.NDArray[cp.csingle],
    alpha: float = 0.05,
) -> npt.NDArray[cp.csingle]:
    """Scale the object gradient by a regularized illumination denominator."""
    return object_upd_sum / cp.sqrt(
        cp.square((1 - alpha) * psi_update_denominator) +
        cp.square(alpha * cp.amax(
            psi_update_denominator,
            axis=(-2, -1),
            keepdims=True,
        )))


def _precondition_nearplane_gradients(
    nearplane,
    scan,
    unique_probe,
    probe,
    object_upd_sum,
    m_probe_update,
    psi_update_denominator,
    patches,
    batches,
    *,
    batch_index: int,
    op,
    m,
    recover_psi,
    recover_probe,
    alpha=0.05,
    probe_options,
):
    """Precondition gradients and assemble the least-squares step terms."""
    lo = batches[batch_index][0]
    hi = lo + len(batches[batch_index])

    eps = op.xp.float32(1e-9) / (nearplane.shape[-2] * nearplane.shape[-1])

    A1 = None
    A2 = None
    A4 = None
    b1 = None
    b2 = None
    dOP = None
    dPO = None
    object_update_proj = None

    if recover_psi:
        object_update_precond = _precondition_object_update(
            object_upd_sum,
            psi_update_denominator,
        )

        object_update_proj = op.diffraction.patch.fwd(
            patches=cp.zeros_like(nearplane[..., 0, 0, :, :]),
            images=object_update_precond,
            positions=scan[lo:hi],
        )
        dOP = object_update_proj[..., None,
                                 None, :, :] * unique_probe[..., m:m + 1, :, :]

        A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))

    if recover_probe:
        # b0 = tike.ptycho.probe.finite_probe_support(
        #     unique_probe[..., m:m+1, :, :],
        #     p=probe_options.probe_support,
        #     radius=probe_options.probe_support_radius,
        #     degree=probe_options.probe_support_degree,
        # )

        # b1 = probe_options.additional_probe_penalty * cp.linspace(
        #     0,
        #     1,
        #     probe[0].shape[-3],
        #     dtype=tike.precision.floating,
        # )[..., m:m+1, None, None]

        # m_probe_update = (m_probe_update -
        #                   (b0 + b1) * probe[..., m:m+1, :, :]) / (
        #                       (1 - alpha) * probe_update_denominator +
        #                       alpha * probe_update_denominator.max(
        #                           axis=(-2, -1),
        #                           keepdims=True,
        #                       ) + b0 + b1)

        dPO = m_probe_update[..., m:m + 1, :, :] * patches
        A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))

    if recover_psi and recover_probe:
        b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
                    axis=(-2, -1))
        b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real,
                    axis=(-2, -1))
        A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1))
    elif recover_psi:
        b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
                    axis=(-2, -1))
    elif recover_probe:
        b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real,
                    axis=(-2, -1))

    return (
        object_update_precond,
        A1,
        A2,
        A4,
        b1,
        b2,
    )


def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi,
                         recover_probe, m):
    """Solve the 2x2 least-squares system for object and probe step sizes."""
    if recover_psi:
        A1 += 0.5 * A1_delta
    if recover_probe:
        A4 += 0.5 * A4_delta

    # (22) Use least-squares to find the optimal step sizes simultaneously
    if recover_psi and recover_probe:
        A3 = A2.conj()
        determinant = A1 * A4 - A2 * A3
        x1 = -cp.conj(A2 * b2 - A4 * b1) / determinant
        x2 = cp.conj(A1 * b2 - A3 * b1) / determinant
    elif recover_psi:
        x1 = b1 / A1
    elif recover_probe:
        x2 = b2 / A4
    else:
        x1 = None
        x2 = None

    if recover_psi:
        step = 0.9 * cp.maximum(0, x1[..., None, None].real)
        # (27b) Object update
        beta_object = cp.mean(step, keepdims=True, axis=-5)
    else:
        beta_object = None

    if recover_probe:
        step = 0.9 * cp.maximum(0, x2[..., None, None].real)
        beta_probe = cp.mean(step, axis=-5, keepdims=True)
    else:
        beta_probe = None

    return beta_object, beta_probe


def _get_coefs_intensity(weights, xi, P, O, batches, *, batch_index, m):
    """
    Parameters
    ----------
    weights : (B, C, M)
    xi : (B, 1, M, H, W)
    P : (B, 1, M, H, W)
    O : (B, 1, 1, H, W)
    """
    lo = batches[batch_index][0]
    hi = lo + len(batches[batch_index])
    OP = O * P[:, :, m:m + 1, :, :]
    num = cp.sum(cp.real(cp.conj(OP) * xi[:, :, m:m + 1, :, :]), axis=(-1, -2))
    den = cp.sum(cp.abs(OP)**2, axis=(-1, -2))
    weights[lo:hi, 0:1, m:m + 1] += 0.1 * num / den
    return weights


def _get_residuals(grad_probe, grad_probe_mean, m):
    """
    Parameters
    ----------
    grad_probe : (B, 1, M, H, W)
    grad_probe_mean : (1, 1, M, H, W)
    """
    return grad_probe[..., m:m + 1, :, :] - grad_probe_mean[..., m:m + 1, :, :]


def _update_residuals(R, eigen_probe, batches, *, batch_index, axis, c, m):
    """
    Parameters
    ----------
    R : (B, 1, 1, H, W)
    eigen_probe : (1, C, M, H, W)
    """
    R -= tike.linalg.projection(
        R,
        eigen_probe[:, c:c + 1, m:m + 1, :, :],
        axis=axis,
    )
    return R


def _update_position(
    scan: npt.NDArray,
    position_options: PositionOptions,
    position_update_numerator: npt.NDArray,
    position_update_denominator: npt.NDArray,
    *,
    alpha=0.05,
    max_shift=1,
    epoch=0,
):
    """Apply the accumulated position gradient step to the scan positions."""
    if epoch < position_options.update_start:
        return scan, position_options

    step = (position_update_numerator) / (
        (1 - alpha) * position_update_denominator +
        alpha * max(position_update_denominator.max(), 1e-6))

    if position_options.update_magnitude_limit > 0:
        step = cp.clip(
            step,
            a_min=-position_options.update_magnitude_limit,
            a_max=position_options.update_magnitude_limit,
        )

    # Remove outliers and subtract the mean
    step = step - cupyx.scipy.stats.trim_mean(step, 0.05)

    if position_options.use_adaptive_moment:
        (
            step,
            position_options.v,
            position_options.m,
        ) = tike.opt.adam(
            step,
            position_options.v,
            position_options.m,
            vdecay=position_options.vdecay,
            mdecay=position_options.mdecay,
        )

    scan -= step

    return scan, position_options


def _momentum_checked(
    g: npt.NDArray,
    v: typing.Union[None, npt.NDArray],
    m: typing.Union[None, npt.NDArray],
    mdecay: float,
    errors: typing.List[float],
    beta: float = 1.0,
    memory_length: int = 3,
    vdecay=None,
) -> typing.Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
    """Momentum updates, but only if the cost function is trending downward.

    Parameters
    ----------
    previous_g (EPOCH, WIDTH, HEIGHT)
        The previous psi updates
    g (WIDTH, HEIGHT)
        The current psi update
    """
    m = np.zeros_like(g,) if m is None else m
    # The "velocity" slot stores the normalized history of recent updates.
    previous_g = np.zeros_like(
        g,
        shape=(memory_length, *g.shape),
    ) if v is None else v

    # Keep a running list of the update directions
    previous_g = np.roll(previous_g, shift=-1, axis=0)
    previous_g[-1] = g / tike.linalg.norm(g) * beta

    # Only apply momentum updates if the objective function is decreasing
    if (len(errors) > 2
            and max(errors[-3], errors[-2]) > min(errors[-2], errors[-1])):
        # Check that previous updates are moving in a similar direction
        previous_update_correlation = tike.linalg.inner(
            previous_g[:-1],
            previous_g[-1],
            axis=(-2, -1),
        ).real.flatten()
        if np.all(previous_update_correlation > 0):
            # Fit exponential decay of the correlation to set the friction.
            friction, _ = tike.opt.fit_line_least_squares(
                x=np.arange(len(previous_update_correlation) + 1),
                y=[
                    0,
                ] + np.log(previous_update_correlation).tolist(),
            )
            friction = 0.5 * max(-friction, 0)
            m = (1 - friction) * m + g
            return mdecay * m, previous_g, m

    # Cost not decreasing: skip the update and decay the moment.
    return np.zeros_like(g), previous_g, m / 2