Conditional Denoisers

DPS (Diffusion Posterior Sampling)

class diffuse.denoisers.cond.dps.DPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, epsilon: float = 0.001, zeta: float = 0.01)

Bases: CondDenoiser

Conditional denoiser using Diffusion Posterior Sampling (DPS).

Implements DPS, which uses Tweedie’s formula to estimate the clean sample x̂₀ from the current noisy state and applies a measurement-consistency gradient correction at each sampling step.

Parameters:
  • integrator – Numerical integrator for solving the reverse SDE

  • model – Diffusion model defining the forward process

  • predictor – Predictor for computing score/noise/velocity

  • forward_model – Forward measurement operator

  • epsilon – Numerical stability parameter (default: 1e-3)

  • zeta – Gradient step size parameter (default: 1e-2)

References

Chung, H., Kim, J., McCann, M. T., Klasky, M. L., & Ye, J. C. (2022). Diffusion posterior sampling for general noisy inverse problems. arXiv:2209.14687

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Single step of DPS sampling.

Implements the DPS algorithm:

  1. Compute the Tweedie estimate at the current position.

  2. Take an unconditional diffusion step with the integrator.

  3. Apply the measurement-consistency gradient correction.

Because the integrator only sees the unconditional score/velocity, this approach works correctly with second-order integrators (Heun, DPM++, etc.).

Parameters:
  • rng_key – Random number generator key

  • state – Current conditional denoiser state

  • measurement_state – Measurement information

Returns:

Updated conditional denoiser state
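As an illustration of step 3, the following NumPy sketch applies a DPS-style correction x ← x' - ζ ∇ₓₜ ||y - A x̂₀(xₜ)|| after a stand-in unconditional step. It assumes a toy N(0, I) prior under a variance-preserving process, so Tweedie's estimate is x̂₀ = √ᾱₜ xₜ with a constant Jacobian. The names `dps_correction`, `tweedie_gaussian`, `abar`, and the way `epsilon` stabilizes the normalization are hypothetical; the library differentiates through the learned model instead.

```python
import numpy as np


def tweedie_gaussian(x_t, abar):
    """Tweedie estimate E[x0 | x_t] for a toy N(0, I) prior under a VP process.

    With x_t = sqrt(abar) * x0 + sqrt(1 - abar) * noise and x0 ~ N(0, I),
    the posterior mean is sqrt(abar) * x_t and its Jacobian is sqrt(abar) * I.
    """
    return np.sqrt(abar) * x_t


def dps_correction(x_next, x_t, y, A, abar, zeta=1e-2, epsilon=1e-3):
    """Hypothetical DPS-style step: x_next - zeta * grad_{x_t} ||y - A x0_hat(x_t)||.

    `epsilon` is assumed to guard the division by the residual norm.
    """
    x0_hat = tweedie_gaussian(x_t, abar)
    residual = y - A @ x0_hat
    norm = np.linalg.norm(residual)
    # d||r|| / d x0_hat = -A^T r / ||r||, chained with the Jacobian sqrt(abar) * I
    grad = -np.sqrt(abar) * (A.T @ residual) / (norm + epsilon)
    return x_next - zeta * grad, grad


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 4))
y = rng.normal(size=2)
x_t = rng.normal(size=4)
x_next = 0.9 * x_t  # hypothetical stand-in for the integrator's unconditional step
x_corr, grad = dps_correction(x_next, x_t, y, A, abar=0.5)
```

The correction is computed at the current position xₜ but added to the integrator's output, which is what keeps the integrator itself unconditional.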

DPS-GSG (Gaussian Smoothed Gradient)

class diffuse.denoisers.cond.dps_gsg.DPSGSGDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, num_queries: int = 64, mu: float = 0.01, zeta: float = 0.01, epsilon: float = 0.001, central_diff: bool = True)

Bases: CondDenoiser

Zero-order DPS using Gaussian Smoothed Gradient (GSG) estimation.

A derivative-free variant of DPS for black-box or non-differentiable forward models. Gradients are estimated by finite differences along random Gaussian perturbation directions.

Parameters:
  • num_queries – Number of perturbation samples for gradient estimation

  • mu – Perturbation scale (smoothing factor)

  • zeta – Gradient step size

  • epsilon – Numerical stability constant

  • central_diff – Use central differences (True) or forward differences (False)
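A minimal NumPy sketch of a Gaussian-smoothed, zero-order gradient estimate, assuming the misfit is only available through forward evaluations. `gsg_gradient` is a hypothetical helper; `num_queries`, `mu`, and `central_diff` mirror the parameters above, but the exact estimator used by `DPSGSGDenoiser` may differ (for example in scaling or in which function is differentiated).

```python
import numpy as np


def gsg_gradient(f, x, rng, num_queries=64, mu=0.01, central_diff=True):
    """Zero-order (Gaussian-smoothed) estimate of grad f at x.

    Draws `num_queries` directions u ~ N(0, I) and averages finite-difference
    slopes along them. `f` must accept a batch of points with shape (q, d).
    """
    u = rng.standard_normal((num_queries, x.size))
    if central_diff:
        # (f(x + mu u) - f(x - mu u)) / (2 mu): even-order Taylor terms cancel
        slopes = (f(x + mu * u) - f(x - mu * u)) / (2.0 * mu)
    else:
        # (f(x + mu u) - f(x)) / mu: one evaluation per query plus one shared
        slopes = (f(x + mu * u) - f(x[None, :])) / mu
    # Monte Carlo average of slope * direction
    return (slopes[:, None] * u).mean(axis=0)


rng = np.random.default_rng(0)
A = np.array([[1.0, 0.0, 2.0], [0.0, -1.0, 1.0]])
y = np.array([1.0, 0.5])
misfit = lambda x: ((x @ A.T - y) ** 2).sum(axis=-1)  # treated as a black box
x = np.array([0.2, -0.3, 0.1])
g_est = gsg_gradient(misfit, x, rng, num_queries=200_000)
g_true = 2 * A.T @ (A @ x - y)  # analytic gradient, for comparison only
```

Central differences need roughly twice as many forward evaluations per query, but they remove the even-order bias terms, which usually pays off for smooth misfits.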

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Perform a single step of conditional denoising, updating the state based on the current measurement and random key.

Parameters:
  • rng_key – Random number generator key for stochastic operations

  • state – Current state of the denoiser containing position and weights

  • measurement_state – Current measurement state containing observations

Returns:

Updated state after performing the denoising step

Return type:

CondDenoiserState

FPS (Filtering Posterior Sampling)

class diffuse.denoisers.cond.fps.FPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5)

Bases: CondDenoiser

Filtering Posterior Sampling (FPS) Denoiser.

Implements a continuous-time SDE version of FPS for conditional generation, with particle filtering and resampling.

Parameters:
  • integrator – Numerical integrator for solving the reverse SDE

  • model – Diffusion model defining the forward process

  • predictor – Predictor for computing score/noise/velocity

  • forward_model – Forward measurement operator

resample

Whether to use particle resampling (set in __post_init__)

Type:

bool | None

ess_low

Low threshold for effective sample size (0.2)

Type:

float | None

ess_high

High threshold for effective sample size (0.5)

Type:

float | None

References

Dou, Z., & Song, Y. (2024). Diffusion Posterior Sampling for Linear Inverse Problem Solving: A Filtering Perspective. arXiv:2407.03981

resampler(state_next: CondDenoiserState, measurement_state: MeasurementState, rng_key: Key[Array, ''] | UInt32[Array, '2']) → CondDenoiserState

Resample particles based on the current state and measurement.

This method resamples particles if the Effective Sample Size (ESS) falls below the specified thresholds, which counteracts weight degeneracy in the particle set.

Parameters:
  • state_next – Next state of the denoiser. Shape: (n_particles, …)

  • measurement_state – Current measurement state.

  • rng_key – Random number generator key.

Returns:

Updated state after resampling.

Return type:

CondDenoiserState
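The ESS check can be sketched as follows, assuming self-normalized log-weights, an ESS normalized by the number of particles and compared against `ess_low`, and systematic resampling. The helper names are hypothetical, and the role of `ess_high` is not modeled here.

```python
import numpy as np


def effective_sample_size(log_weights):
    """ESS = 1 / sum(w_i^2) for self-normalised weights, computed in log space."""
    w = np.exp(log_weights - log_weights.max())
    w /= w.sum()
    return 1.0 / np.sum(w ** 2), w


def systematic_resample(weights, rng):
    """Return particle indices drawn by systematic resampling."""
    n = weights.size
    positions = (rng.random() + np.arange(n)) / n  # one uniform offset, n strata
    cumulative = np.cumsum(weights)
    cumulative[-1] = 1.0  # guard against round-off in the last bin
    return np.searchsorted(cumulative, positions, side="right")


def maybe_resample(particles, log_weights, rng, ess_low=0.2):
    """Resample only when ESS / n_particles drops below ess_low."""
    n = particles.shape[0]
    ess, w = effective_sample_size(log_weights)
    if ess / n >= ess_low:
        return particles, log_weights
    idx = systematic_resample(w, rng)
    # After resampling, every particle carries equal weight.
    return particles[idx], np.full(n, -np.log(n))


rng = np.random.default_rng(0)
particles = np.arange(8.0)[:, None]  # eight 1-D particles
log_w = np.full(8, -np.inf)
log_w[2] = 0.0  # all mass on particle 2, so ESS = 1
new_particles, new_log_w = maybe_resample(particles, log_w, rng)
```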

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Single step of continuous-time FPS sampling.

Implements FPS by separating:

  1. Computation of the guidance term at the current position.

  2. An unconditional diffusion step with the integrator.

  3. The guidance correction applied to the result.

This approach works correctly with second-order integrators (Heun, DPM++, etc.) because the integrator sees the true unconditional score/velocity.

Parameters:
  • rng_key – Random number generator key

  • state – Current conditional denoiser state

  • measurement_state – Measurement information

Returns:

Updated conditional denoiser state

y_noiser(key: Key[Array, ''] | UInt32[Array, '2'], t: float, measurement_state: MeasurementState) → SDEState

Generate noisy measurement at time t.

Computes \(y^{(t)} = \sqrt{\bar{\alpha}_t} y + \sqrt{1-\bar{\alpha}_t} A_\xi \epsilon\), where \(\epsilon \sim \mathcal{N}(0, I)\).

Parameters:
  • key – Random number generator key

  • t – Current time

  • measurement_state – Measurement information

Returns:

SDEState containing the noised measurement
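A NumPy sketch of the formula above, under the assumption that A_ξ is a linear forward operator represented by a matrix `A` and that ε is drawn in signal space. `y_noiser_sketch` and `abar_t` are hypothetical names, and the library returns an `SDEState` rather than a bare array.

```python
import numpy as np


def y_noiser_sketch(rng, y, A, abar_t):
    """Noised measurement y_t = sqrt(abar_t) * y + sqrt(1 - abar_t) * A @ eps.

    eps ~ N(0, I) lives in signal space, so the added measurement noise has
    covariance (1 - abar_t) * A A^T (A stands in for A_xi).
    """
    eps = rng.standard_normal(A.shape[1])
    return np.sqrt(abar_t) * y + np.sqrt(1.0 - abar_t) * (A @ eps)


rng = np.random.default_rng(0)
A = np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 1.0]])
y = np.array([0.5, -0.5])
y_t = y_noiser_sketch(rng, y, A, abar_t=0.7)
```

At ᾱₜ = 1 the measurement is returned unchanged, and as ᾱₜ decreases it is pulled toward noise shaped by the operator, mirroring how the signal itself is noised.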

TMP (Tweedie Moment Projection)

class diffuse.denoisers.cond.tmp.TMPDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5)

Bases: CondDenoiser

Conditional denoiser using Tweedie’s Moment Projection (TMP).

Implements TMP, which modifies the score function to incorporate measurement information through Tweedie’s formula and moment matching.

Parameters:
  • integrator – Numerical integrator for solving the reverse SDE

  • model – Diffusion model defining the forward process

  • predictor – Predictor for computing score/noise/velocity

  • forward_model – Forward measurement operator

References

Boys, B., Girolami, M., Pidstrigach, J., Reich, S., Mosca, A., & Akyildiz, Ö. D. (2023). Tweedie moment projected diffusions for inverse problems. arXiv:2310.06721

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Single step of TMP sampling.

Adds a measurement term to the score and uses the integrator for the update.

Parameters:
  • rng_key – Random number generator key

  • state – Current conditional denoiser state

  • measurement_state – Measurement information

Returns:

Updated conditional denoiser state
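To make the moment projection concrete, here is a NumPy sketch for a toy N(0, I) prior under a variance-preserving process, where Tweedie's mean and covariance are available in closed form and the Gaussian projection happens to be exact. The helper name, `abar`, and `sigma_y` are assumptions for illustration; with a learned model, the Jacobian would come from differentiating the denoiser.

```python
import numpy as np


def tmp_likelihood_score(x_t, y, A, abar, sigma_y):
    """TMP-style approximation of grad_{x_t} log p(y | x_t) for a toy N(0, I) prior.

    Tweedie mean:       x0_hat = sqrt(abar) * x_t        (Jacobian J = sqrt(abar) I)
    Tweedie covariance: C = (1 - abar) / sqrt(abar) * J   = (1 - abar) I
    Moment projection:  p(y | x_t) ~ N(y; A x0_hat, sigma_y^2 I + A C A^T)
    """
    d = x_t.size
    J = np.sqrt(abar) * np.eye(d)
    x0_hat = J @ x_t
    C = (1.0 - abar) / np.sqrt(abar) * J
    S = sigma_y ** 2 * np.eye(A.shape[0]) + A @ C @ A.T
    # Gradient of the Gaussian log-likelihood, treating C (hence S) as constant in x_t.
    return J.T @ A.T @ np.linalg.solve(S, y - A @ x0_hat)


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
y = rng.normal(size=2)
x_t = rng.normal(size=3)
abar, sigma_y = 0.6, 0.1
score_lik = tmp_likelihood_score(x_t, y, A, abar, sigma_y)
posterior_score = -x_t + score_lik  # toy prior score of the N(0, I) marginal plus likelihood term
```

The modified score is then handed to the integrator in place of the unconditional one.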

PiGDM (Pseudoinverse-Guided Diffusion)

class diffuse.denoisers.cond.pigdm.PiGDMDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, zeta: float = 0.01, epsilon: float = 0.1, cg_maxiter: int = 0, cg_reg: float = 0.01)

Bases: CondDenoiser

Pseudoinverse-Guided Diffusion Models (ΠiGDM).

Guidance: grad = ∇_x [(A†r)ᵀ · x̂₀], where r = y - Ax̂₀ is held constant during differentiation (a vector-Jacobian product through the denoiser).

Parameters:
  • zeta – Step size scaling factor

  • epsilon – Numerical stability constant

  • cg_maxiter – CG iterations for pseudo-inverse (0 = use Aᵀ instead)

  • cg_reg – Tikhonov regularization for CG solve
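The guidance can be sketched with a toy linear denoiser x̂₀ = √ᾱₜ xₜ, so the vector-Jacobian product is explicit. The sketch assumes that `cg_maxiter > 0` solves the Tikhonov-regularized system (AAᵀ + cg_reg·I) z = r with conjugate gradients and then applies Aᵀ to z; `cg_solve`, `apply_pinv`, and `pigdm_guidance` are hypothetical names.

```python
import numpy as np


def cg_solve(matvec, b, maxiter, tol=1e-10):
    """Conjugate gradients for a symmetric positive-definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def apply_pinv(A, r, cg_maxiter=0, cg_reg=1e-2):
    """Approximate A^dagger r.

    cg_maxiter == 0: fall back to the adjoint A^T r.
    otherwise: regularised pseudo-inverse A^T (A A^T + cg_reg I)^{-1} r via CG.
    """
    if cg_maxiter == 0:
        return A.T @ r
    z = cg_solve(lambda v: A @ (A.T @ v) + cg_reg * v, r, cg_maxiter)
    return A.T @ z


def pigdm_guidance(x_t, y, A, abar, **pinv_kwargs):
    """grad_x [(A^dagger r)^T x0_hat(x)] with r held fixed, for a toy linear denoiser."""
    J = np.sqrt(abar) * np.eye(x_t.size)  # Jacobian of x0_hat = sqrt(abar) * x_t
    x0_hat = J @ x_t
    r = y - A @ x0_hat  # residual, treated as a constant
    return J.T @ apply_pinv(A, r, **pinv_kwargs)


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 4))
y = rng.normal(size=2)
x_t = rng.normal(size=4)
g_adj = pigdm_guidance(x_t, y, A, 0.5)  # cg_maxiter=0: uses A^T
g_cg = pigdm_guidance(x_t, y, A, 0.5, cg_maxiter=20, cg_reg=1e-2)
```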

Reference: Song et al., “Pseudoinverse-Guided Diffusion Models for Inverse Problems”

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Perform a single step of conditional denoising, updating the state based on the current measurement and random key.

Parameters:
  • rng_key – Random number generator key for stochastic operations

  • state – Current state of the denoiser containing position and weights

  • measurement_state – Current measurement state containing observations

Returns:

Updated state after performing the denoising step

Return type:

CondDenoiserState

DAPS (Decoupled Annealing Posterior Sampling)

class diffuse.denoisers.cond.daps.DAPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, langevin_steps: int = 100, langevin_lr: float = 0.0001, langevin_lr_min_ratio: float = 0.01, tau: float = 0.01, diffusion_steps: int = 50)

Bases: CondDenoiser

Decoupled Annealing Posterior Sampling (DAPS).

At each annealing step:

  1. Reverse diffusion: solve the ODE from σ_k to ~0 to obtain x̂₀.

  2. Langevin dynamics: sample x₀ from p(x₀ | x̂₀, y) ∝ exp(-||Ax₀-y||²/(2τ²) - ||x₀-x̂₀||²/(2σ²)).

  3. Forward diffusion: add noise to reach σ_{k+1}.

Parameters:
  • langevin_steps – Number of Langevin MCMC steps per annealing level

  • langevin_lr – Langevin step size

  • langevin_lr_min_ratio – Final/initial step size ratio for annealing

  • tau – Likelihood temperature (measurement noise scale)

  • diffusion_steps – ODE steps for reverse diffusion
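A NumPy sketch of the Langevin stage for a linear operator, vectorized over chains. It runs unadjusted Langevin steps on the target from step 2 and, as an assumption, anneals the step size linearly from `langevin_lr` to `langevin_lr * langevin_lr_min_ratio`; the function name and the schedule shape are illustrative only.

```python
import numpy as np


def daps_langevin(rng, x0_hat, y, A, sigma, tau=0.01, langevin_steps=100,
                  langevin_lr=1e-4, langevin_lr_min_ratio=0.01):
    """Unadjusted Langevin targeting exp(-||A x - y||^2/(2 tau^2) - ||x - x0_hat||^2/(2 sigma^2)).

    x0_hat may hold a batch of chains with shape (n_chains, d). The step size is
    annealed linearly from langevin_lr to langevin_lr * langevin_lr_min_ratio
    (an assumed schedule).
    """
    x = x0_hat.copy()
    lrs = langevin_lr * np.linspace(1.0, langevin_lr_min_ratio, langevin_steps)
    for lr in lrs:
        # grad log p = -A^T (A x - y) / tau^2 - (x - x0_hat) / sigma^2, row-wise
        grad = -((x @ A.T - y) @ A) / tau ** 2 - (x - x0_hat) / sigma ** 2
        x = x + lr * grad + np.sqrt(2.0 * lr) * rng.standard_normal(x.shape)
    return x


rng = np.random.default_rng(0)
A = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]])
y = np.array([1.0, 0.0])
x0_ref = np.array([0.5, -0.2, 0.3])
x0_hat = np.tile(x0_ref, (4000, 1))  # 4000 independent chains
samples = daps_langevin(rng, x0_hat, y, A, sigma=1.0, tau=0.5,
                        langevin_steps=2000, langevin_lr=0.01, langevin_lr_min_ratio=0.1)
```

The usage values are chosen so that the toy problem stays stable (step size times the largest curvature stays well below 2); they are not recommended settings for the class.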

Reference: Zhang et al., “Improving Diffusion Inverse Problem Solving with Decoupled Noise Annealing”

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Perform a single step of conditional denoising, updating the state based on the current measurement and random key.

Parameters:
  • rng_key – Random number generator key for stochastic operations

  • state – Current state of the denoiser containing position and weights

  • measurement_state – Current measurement state containing observations

Returns:

Updated state after performing the denoising step

Return type:

CondDenoiserState

PnPDM (Plug-and-Play Diffusion Model)

class diffuse.denoisers.cond.pnpdm.PnPDMDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, sigma_max: float = 1.0, sigma_min: float = 0.01, rho: float = 0.95, langevin_steps: int = 10, langevin_lr: float = 0.01, tau: float = 0.5, diffusion_steps: int = 50, annealing_steps: int = 50)

Bases: CondDenoiser

Plug-and-Play Diffusion Models (PnPDM).

Alternates between Langevin dynamics for data fitting and full reverse diffusion for denoising, annealing σ with an exponential decay schedule.

At each annealing step:

  1. Run Langevin dynamics on the energy ||Ax-y||²/τ² + ||x-x̂₀||²/σ².

  2. Add Gaussian noise at level σ.

  3. Run full reverse diffusion from σ down to eps.
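One plausible reading of the exponential-decay annealing, sketched with NumPy: σₖ = max(σ_max·ρᵏ, σ_min) for k = 0, ..., annealing_steps - 1. This form is an assumption inferred from the parameter names `sigma_max`, `sigma_min`, `rho`, and `annealing_steps`; `pnpdm_sigma_schedule` is hypothetical.

```python
import numpy as np


def pnpdm_sigma_schedule(sigma_max=1.0, sigma_min=0.01, rho=0.95, annealing_steps=50):
    """Assumed schedule: sigma_k = max(sigma_max * rho**k, sigma_min)."""
    k = np.arange(annealing_steps)
    return np.maximum(sigma_max * rho ** k, sigma_min)


sigmas = pnpdm_sigma_schedule()
sigmas_fast = pnpdm_sigma_schedule(rho=0.5, annealing_steps=20)  # hits the floor early
```

Under this assumed form and the default arguments, the last level is 0.95⁴⁹ ≈ 0.081, so the `sigma_min` floor only takes effect with faster decay or more annealing steps.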

Reference: Wu et al., “Principled Probabilistic Imaging using Diffusion Models as Plug-and-Play Priors”

langevin_sampling(rng_key: Key[Array, ''] | UInt32[Array, '2'], x0_hat: Array, sigma: Array, measurement_state: MeasurementState) → Array

Run Langevin dynamics on the energy ||Ax - y||²/τ² + ||x - x0_hat||²/σ².

reverse_diffuse(rng_key: Key[Array, ''] | UInt32[Array, '2'], z: Array, sigma: Array) → Array

Run full reverse diffusion from noise level sigma down to eps.

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Single annealing step: Langevin + noise + reverse diffusion.

DiffPIR (Diffusion-based Plug-and-Play Image Restoration)

class diffuse.denoisers.cond.diffpir.DiffPIRDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, sigma_n: float = 0.05, lamb: float = 1.0, xi: float = 0.5, linear: bool = False, cg_tol: float = 1e-05, cg_maxiter: int = 32)

Bases: CondDenoiser

Plug-and-Play Diffusion for Image Restoration (DiffPIR).

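Alternates three substeps: denoise → data fidelity → add noise. The data-fidelity step uses the regularization weight ρ = 2λσₙ² / σₜ², where σₜ is the noise level at the current step.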

Parameters:
  • sigma_n – Measurement noise std

  • lamb – Regularization strength (prior vs likelihood balance)

  • xi – Noise injection ratio (0=deterministic, 1=stochastic)

  • linear – Use proximal step (True) or gradient step (False)

  • cg_tol – CG solver tolerance (linear mode)

  • cg_maxiter – CG max iterations (linear mode)
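In linear mode, the data-fidelity step can be written as a proximal problem. The sketch below assumes the objective ||Ax - y||² + ρ||x - x̂₀||² with ρ as defined above, whose minimizer solves (AᵀA + ρI)x = Aᵀy + ρx̂₀, and solves that system with conjugate gradients warm-started at x̂₀. The helper names are hypothetical, and the noise-injection substep controlled by `xi` is not shown.

```python
import numpy as np


def cg_solve(matvec, b, x0, tol=1e-5, maxiter=32):
    """Conjugate gradients for a symmetric positive-definite operator, started at x0."""
    x = x0.copy()
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def diffpir_data_step(x0_hat, y, A, sigma_t, sigma_n=0.05, lamb=1.0, cg_tol=1e-5, cg_maxiter=32):
    """Proximal data-fidelity step (assumed objective ||A x - y||^2 + rho ||x - x0_hat||^2)."""
    rho = 2.0 * lamb * sigma_n ** 2 / sigma_t ** 2
    b = A.T @ y + rho * x0_hat
    return cg_solve(lambda v: A.T @ (A @ v) + rho * v, b, x0=x0_hat, tol=cg_tol, maxiter=cg_maxiter)


rng = np.random.default_rng(0)
A = rng.normal(size=(3, 5))
y = rng.normal(size=3)
x0_hat = rng.normal(size=5)
x_prox = diffpir_data_step(x0_hat, y, A, sigma_t=0.5, cg_tol=1e-10)
```

As σₜ shrinks, ρ grows and the step stays closer to the denoiser output; at large σₜ the measurement term dominates.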

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Perform a single step of conditional denoising, updating the state based on the current measurement and random key.

Parameters:
  • rng_key – Random number generator key for stochastic operations

  • state – Current state of the denoiser containing position and weights

  • measurement_state – Current measurement state containing observations

Returns:

Updated state after performing the denoising step

Return type:

CondDenoiserState

EnKG (Ensemble Kalman Guidance)

class diffuse.denoisers.cond.enkg.EnKGDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, guidance_scale: float = 0.5, lr_min_ratio: float = 0.0, denoising_steps: int = 15)

Bases: CondDenoiser

Ensemble Kalman Guidance (EnKG) conditional denoiser.

Derivative-free method using ensemble Kalman updates. Does NOT require gradients of the forward model.

Parameters:
  • guidance_scale – Base learning rate (γ) for Kalman updates

  • lr_min_ratio – Minimum LR ratio (r) for the linear decay γ[(1-r)(N-i)/N + r] at step i of N

  • denoising_steps – Euler ODE steps for x̂₀ estimation (1 = Tweedie)
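A NumPy sketch of a derivative-free, ensemble-Kalman-style guidance update together with the learning-rate decay above. It assumes each particle is moved along the ensemble's state deviations, weighted by the covariance between its own measurement misfit and the spread of predicted measurements; this follows the general ensemble Kalman inversion pattern and is not taken from the library's code. `enkg_lr` and `enkg_update` are hypothetical, and the forward model is only evaluated, never differentiated.

```python
import numpy as np


def enkg_lr(i, n_steps, guidance_scale=0.5, lr_min_ratio=0.0):
    """Step size gamma * ((1 - r) * (n_steps - i) / n_steps + r)."""
    r = lr_min_ratio
    return guidance_scale * ((1.0 - r) * (n_steps - i) / n_steps + r)


def enkg_update(x, x0_hat, forward, y, lr):
    """Ensemble Kalman-style guidance update, no forward-model gradients needed.

    x, x0_hat: (n_particles, d) ensemble states and their clean-sample estimates.
    forward:   black-box measurement operator applied row-wise.
    """
    g = forward(x0_hat)          # (n, m) predicted measurements
    dx = x - x.mean(axis=0)      # (n, d) state deviations
    dg = g - g.mean(axis=0)      # (n, m) prediction deviations
    misfit = g - y               # (n, m) per-particle measurement misfit
    coeffs = misfit @ dg.T / x.shape[0]  # coeffs[i, j] = <misfit_i, dg_j> / n
    return x - lr * coeffs @ dx


rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
y = rng.normal(size=2)
x = rng.normal(size=(8, 3))  # ensemble of 8 particles
lr = enkg_lr(i=3, n_steps=10)
x_new = enkg_update(x, x, lambda z: z @ A.T, y, lr)  # toy case: x0_hat = x
```

For a linear forward model with x̂₀ = x, this update reduces to x ← x - γ C Aᵀ(Ax - y), where C is the empirical ensemble covariance, i.e. a gradient step preconditioned by C that is obtained without any derivatives.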

Reference: devzhk/enkg-pytorch

generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], measurement_state: MeasurementState, n_steps: int, n_particles: int, keep_history: bool = False)

Generate samples by running n_steps of EnKG sampling with an ensemble of n_particles particles.

step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState

Perform a single step of conditional denoising, updating the state based on the current measurement and random key.

Parameters:
  • rng_key – Random number generator key for stochastic operations

  • state – Current state of the denoiser containing position and weights

  • measurement_state – Current measurement state containing observations

Returns:

Updated state after performing the denoising step

Return type:

CondDenoiserState