Denoisers

Unconditional Denoisers

Base Denoiser

class diffuse.denoisers.denoiser.Denoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, x0_shape: Tuple[int, ...])

Bases: BaseDenoiser

Denoiser for generating samples using reverse diffusion.

Parameters:
  • integrator – The integrator to use for solving the reverse SDE

  • model – The diffusion model (SDE) defining the forward process

  • predictor – The predictor for computing the score/denoised estimate

  • x0_shape – Shape of the data samples (excluding batch dimension)

generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], n_steps: int, n_particles: int, keep_history: bool = False, data_sharding: Any | None = None) -> Tuple[DenoiserState, Array | None]

Generate denoised samples \(x_0\).

Parameters:
  • rng_key – Random key for initialization

  • n_steps – Number of denoising steps to perform

  • n_particles – Number of samples to generate (batch size)

  • keep_history – If True, return the full trajectory of samples

  • data_sharding – Optional JAX sharding specification for distributed computation

Returns:

Tuple of (final_state, history), where history is None if keep_history=False
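The relationship between `init`, `step`, `generate`, and `keep_history` can be illustrated with a toy loop. The sketch below is plain Python and does not use the diffuse library: `ToyState`, `toy_step`, and `toy_generate` are hypothetical names, and the update rule is a placeholder rather than a real reverse-diffusion step. It only shows the shape of the return value described above.

```python
import random
from typing import List, NamedTuple, Optional, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for DenoiserState: positions plus a step counter."""
    position: List[float]
    step: int


def toy_step(state: ToyState, rng: random.Random) -> ToyState:
    # Placeholder update: shrink each particle and add a little noise.
    # A real denoiser would call its integrator and predictor here.
    new_pos = [0.9 * x + 0.1 * rng.gauss(0.0, 1.0) for x in state.position]
    return ToyState(position=new_pos, step=state.step + 1)


def toy_generate(seed: int, n_steps: int, n_particles: int,
                 keep_history: bool = False
                 ) -> Tuple[ToyState, Optional[List[List[float]]]]:
    rng = random.Random(seed)
    # Start from pure noise, one scalar value per particle.
    state = ToyState([rng.gauss(0.0, 1.0) for _ in range(n_particles)], step=0)
    history = [] if keep_history else None
    for _ in range(n_steps):
        state = toy_step(state, rng)
        if keep_history:
            # Store a copy of the positions after every step.
            history.append(list(state.position))
    return state, history
```

With `keep_history=False` the second element of the returned tuple is `None`; with `keep_history=True` it holds one entry per step, the last of which matches the final state's positions.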

init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) -> DenoiserState

Initialize denoiser state.

integrator: Integrator

model: DiffusionModel

predictor: Predictor

step(state: DenoiserState) -> DenoiserState

Perform one denoising step.

Sample \(x_{t-1} \sim p(x_{t-1} | x_t)\)

Parameters:

state – Current denoiser state

Returns:

Updated denoiser state at the previous time step
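One common way to realise such a step is an Euler-Maruyama discretisation of the reverse-time SDE. This is a generic sketch and not necessarily what the `Integrator` passed to this class does. The symbols \(f\) (drift), \(g\) (diffusion coefficient), \(s_\theta\) (learned score), and \(\Delta t > 0\) (step size) are assumptions introduced for illustration. For a forward process \(dx = f(x, t)\,dt + g(t)\,dw\), the reverse-time dynamics and one discrete step read:

\[
dx = \left[ f(x, t) - g(t)^2 \, \nabla_x \log p_t(x) \right] dt + g(t)\, d\bar{w},
\]

\[
x_{t-\Delta t} = x_t - \left[ f(x_t, t) - g(t)^2 \, s_\theta(x_t, t) \right] \Delta t + g(t) \sqrt{\Delta t}\; z, \qquad z \sim \mathcal{N}(0, I).
\]

Here \(x_{t-\Delta t}\) plays the role of \(x_{t-1}\) in the notation above.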

x0_shape: Tuple[int, ...]

Plug-and-Play Denoiser

class diffuse.denoisers.pnp_denoiser.PnPDenoiser(integrator: Integrator, sde: SDE, score: Callable[[Array, float], Array], forward_model: ForwardModel, _resample: bool = False)

Bases: object

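Conditional denoiser that combines an SDE integrator and a score function with a forward (measurement) model, so that samples are conditioned on measurements.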

batch_step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: PnPDenoiserState, score: Callable[[Array, float], Array], measurement_state: MeasurementState) -> PnPDenoiserState

Batch update step.

forward_model: ForwardModel

init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], dt: float) -> PnPDenoiserState

Initialize denoiser state.

integrator: Integrator

pooled_posterior_logpdf(rng_key: Key[Array, ''] | UInt32[Array, '2'], t: float, y_cntrst: Array, y_past: Array, design: Array, mask_history: Array)

Compute pooled posterior log probability density.

posterior_logpdf(rng_key: Key[Array, ''] | UInt32[Array, '2'], t: float, y_meas: Array, design_mask: Array)

Compute posterior log probability density.

score: Callable[[Array, float], Array]

sde: SDE

step(state: PnPDenoiserState, score: Callable[[Array, float], Array]) -> PnPDenoiserState

Single step update.

y_noiser(mask: Array, key: Key[Array, ''] | UInt32[Array, '2'], state: SDEState, ts: float) -> SDEState

Add noise to measurements.

class diffuse.denoisers.pnp_denoiser.PnPDenoiserState(position, auxiliary, log_weights)

Bases: NamedTuple

auxiliary: Array

Alias for field number 1

log_weights: Array

Alias for field number 2

position: Array

Alias for field number 0
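The "Alias for field number N" entries mean that each field is also reachable by its positional index. The sketch below mirrors that layout with a hypothetical `ToyPnPState` that uses plain lists instead of JAX arrays. It is an illustration of `NamedTuple` behaviour, not the library's class.

```python
from typing import List, NamedTuple


class ToyPnPState(NamedTuple):
    """Hypothetical mirror of PnPDenoiserState with list-valued fields."""
    position: List[float]      # field number 0
    auxiliary: List[float]     # field number 1
    log_weights: List[float]   # field number 2


state = ToyPnPState(position=[0.1, -0.4],
                    auxiliary=[0.0, 0.0],
                    log_weights=[-0.7, -0.7])

# Fields can be read by name or by their positional index.
same_object = state.position is state[0]

# NamedTuples are immutable; _replace returns an updated copy
# and leaves the untouched fields shared with the original.
updated = state._replace(log_weights=[-0.2, -1.6])
```

Because the state is immutable, an update step would typically return a new state object rather than modify the existing one in place.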

Base Classes

class diffuse.denoisers.base.BaseDenoiser

Bases: ABC

abstractmethod generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], measurement_state, n_steps: int, n_particles: int)

Generate samples.

abstractmethod init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], dt: float) -> DenoiserState

Initialize denoiser state.

abstractmethod step(state: DenoiserState, score: Callable[[Array, float], Array]) -> DenoiserState

Perform a single denoising step.

class diffuse.denoisers.base.DenoiserState(integrator_state: IntegratorState)

Bases: NamedTuple

Base state for all denoisers.

integrator_state: IntegratorState

Alias for field number 0

Utilities

diffuse.denoisers.utils.ess(log_weights: Array) -> float

diffuse.denoisers.utils.log_ess(log_weights: Array) -> float

diffuse.denoisers.utils.normalize_log_weights(log_weights: Array) -> Array

diffuse.denoisers.utils.resample_particles(position: Array, log_weights: Array, rng_key: Array, ess_low: float = 0.2, ess_high: float = 0.5) -> Tuple[Array, Array]

Internal function to perform the actual resampling given the weights.

Parameters:
  • position – Current particle positions

  • log_weights – Log weights of the particles

  • rng_key – Random number generator key

  • ess_low – Lower threshold for ESS (default: 0.2)

  • ess_high – Upper threshold for ESS (default: 0.5)

Returns:

Tuple of (resampled_position, normalized_log_weights)
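The ESS thresholds refer to the effective sample size of the weighted particles. The sketch below shows the standard definition in plain Python; the `toy_*` functions are hypothetical and are not the library's `ess`, `log_ess`, or `normalize_log_weights`. Whether the library reports ESS as an absolute count or as a fraction of the particle count is not stated in this reference. The default thresholds of 0.2 and 0.5 suggest a fraction, but that is an inference.

```python
import math
from typing import List


def toy_normalize_log_weights(log_weights: List[float]) -> List[float]:
    """Shift log-weights so the weights sum to one (log-sum-exp trick)."""
    m = max(log_weights)
    log_total = m + math.log(sum(math.exp(lw - m) for lw in log_weights))
    return [lw - log_total for lw in log_weights]


def toy_ess(log_weights: List[float]) -> float:
    """Effective sample size: 1 / sum(w_i^2) for normalized weights w_i."""
    w = [math.exp(lw) for lw in toy_normalize_log_weights(log_weights)]
    return 1.0 / sum(wi * wi for wi in w)


def toy_ess_fraction(log_weights: List[float]) -> float:
    """ESS divided by the number of particles, a value in (0, 1]."""
    return toy_ess(log_weights) / len(log_weights)
```

Equal weights give an ESS equal to the number of particles, while a single dominant weight drives the ESS toward 1. Comparing the fraction against a threshold is a common trigger for resampling.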

diffuse.denoisers.utils.stratified_resampling(key, w)
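This function is listed without a description. As background, stratified resampling in general splits \([0, 1)\) into as many equal strata as there are particles, draws one uniform point per stratum, and maps each point through the cumulative weights to an ancestor index. The sketch below is a plain-Python illustration under those assumptions: `toy_stratified_resampling` is a hypothetical name, it takes a `random.Random` instead of a JAX key, and it may differ from the library's implementation.

```python
import random
from typing import List


def toy_stratified_resampling(rng: random.Random, w: List[float]) -> List[int]:
    """Return len(w) ancestor indices drawn by stratified resampling.

    w must be non-negative and sum to one.
    """
    n = len(w)
    # One uniform draw inside each of the n equal-width strata of [0, 1).
    points = [(i + rng.random()) / n for i in range(n)]
    indices = []
    cumulative = w[0]
    j = 0
    for u in points:
        # Advance to the first particle whose cumulative weight exceeds u.
        # The points are increasing, so j never needs to move backwards.
        while u >= cumulative and j < n - 1:
            j += 1
            cumulative += w[j]
        indices.append(j)
    return indices
```

Compared with independent multinomial draws, the one-point-per-stratum construction reduces the variance of the resampled counts. With equal weights every particle is selected once.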
diffuse.denoisers.utils.weights_tweedie(state_next, measurement_state: MeasurementState, rng_key: Array, sde, score_fn, forward_model, ess_low: float = 0.2, ess_high: float = 0.5) -> Tuple[Array, Array]

Compute particle weights using Tweedie’s formula, then resample the particles.

Parameters:
  • state_next – The current state of the particles

  • measurement_state – The measurement state containing observations

  • rng_key – Random number generator key

  • sde – The SDE object

  • score_fn – The score function

  • forward_model – The forward model

  • ess_low – Lower threshold for ESS (default: 0.2)

  • ess_high – Upper threshold for ESS (default: 0.5)

Returns:

Tuple of (position, log_weights)
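For reference, Tweedie's formula gives the posterior mean of the clean sample from a noisy one via the score. The statement below assumes a Gaussian forward marginal \(x_t = \alpha_t x_0 + \sigma_t \varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\). The symbols \(\alpha_t\) and \(\sigma_t\) are introduced here for illustration and are not taken from the library.

\[
\hat{x}_0(x_t) = \mathbb{E}\left[ x_0 \mid x_t \right] = \frac{x_t + \sigma_t^2 \, \nabla_{x_t} \log p_t(x_t)}{\alpha_t}.
\]

In practice the score \(\nabla_{x_t} \log p_t(x_t)\) is replaced by the learned `score_fn`. How `weights_tweedie` combines this estimate with `forward_model` and `measurement_state` to form the log weights is not specified in this reference.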