Conditional Denoisers
DPS (Diffusion Posterior Sampling)
- class diffuse.denoisers.cond.dps.DPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, epsilon: float = 0.001, zeta: float = 0.01)
Bases: CondDenoiser
Conditional denoiser using Diffusion Posterior Sampling (DPS).
Implements DPS, which uses Tweedie’s formula to estimate the clean sample and applies a measurement-consistency gradient correction at each sampling step.
- Parameters:
integrator – Numerical integrator for solving the reverse SDE
model – Diffusion model defining the forward process
predictor – Predictor for computing score/noise/velocity
forward_model – Forward measurement operator
epsilon – Numerical stability parameter (default: 1e-3)
zeta – Gradient step size parameter (default: 1e-2)
References
Chung, H., Kim, J., McCann, M. T., Klasky, M. L., & Ye, J. C. (2022). Diffusion posterior sampling for general noisy inverse problems. arXiv:2209.14687
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of DPS sampling.
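Implements the DPS algorithm:
1. Compute the Tweedie estimate at the current position.
2. Take an unconditional diffusion step with the integrator.
3. Apply the measurement-consistency gradient correction.
Because the correction is applied after the integrator step, the integrator only sees the unconditional score/velocity, so this approach works correctly with second-order integrators (Heun, DPM++, etc.).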
- Parameters:
rng_key – Random number generator key
state – Current conditional denoiser state
measurement_state – Measurement information
- Returns:
Updated conditional denoiser state
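The three-step structure above can be illustrated with a small NumPy sketch. It is not the library's implementation: it assumes a linear forward model A, a hypothetical linear Tweedie map x̂₀ = D·x_t (so the Jacobian is simply D and no autodiff is needed), and that epsilon guards the division by the residual norm. The integrator step is stood in for by a precomputed proposal x_next_uncond, and dps_residual_grad and dps_step are illustrative names.

```python
import numpy as np


def dps_residual_grad(x_t, D, A, y, epsilon=1e-3):
    """Gradient w.r.t. x_t of ||y - A x0_hat(x_t)|| for a toy linear Tweedie map.

    Hypothetical setting: x0_hat(x_t) = D @ x_t, so the Jacobian is D.
    epsilon is assumed to stabilize the division by the residual norm.
    """
    x0_hat = D @ x_t                       # 1. Tweedie estimate at the current position
    r = y - A @ x0_hat                     # measurement residual
    return -(D.T @ (A.T @ r)) / (np.linalg.norm(r) + epsilon)


def dps_step(x_t, x_next_uncond, D, A, y, zeta=1e-2, epsilon=1e-3):
    """2. x_next_uncond is the integrator's unconditional proposal;
    3. subtract the measurement-consistency gradient, scaled by zeta."""
    return x_next_uncond - zeta * dps_residual_grad(x_t, D, A, y, epsilon)


# Toy usage with random data
rng = np.random.default_rng(0)
d, m = 4, 3
D = 0.9 * np.eye(d)                        # hypothetical linear denoiser
A = rng.standard_normal((m, d))            # linear forward model
y = rng.standard_normal(m)
x_t = rng.standard_normal(d)
x_next = 0.95 * x_t                        # stand-in for an integrator step
x_corr = dps_step(x_t, x_next, D, A, y)
```

With a real model, D·x_t would be replaced by the network's Tweedie estimate and the gradient would come from automatic differentiation (e.g. jax.grad) through that estimate.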
DPS-GSG (Gaussian Smoothed Gradient)
- class diffuse.denoisers.cond.dps_gsg.DPSGSGDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, num_queries: int = 64, mu: float = 0.01, zeta: float = 0.01, epsilon: float = 0.001, central_diff: bool = True)
Bases: CondDenoiser
Zero-order DPS using Gaussian Smoothed Gradient (GSG) estimation.
Derivative-free DPS for black-box or non-differentiable forward models. Estimates the measurement-consistency gradient via finite differences along random Gaussian perturbation directions.
- Parameters:
num_queries – Number of perturbation samples for gradient estimation
mu – Perturbation scale (smoothing factor)
zeta – Gradient step size
epsilon – Numerical stability constant
central_diff – Use central differences (True) or forward differences (False)
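The Gaussian smoothed gradient idea behind these parameters can be sketched in NumPy as follows. The function name and signature are hypothetical; f stands for any scalar black-box objective (for DPS-GSG, something like the residual norm ‖y − A(x̂₀(x_t))‖), and only evaluations of f are used.

```python
import numpy as np


def gsg_gradient(f, x, rng, num_queries=64, mu=0.01, central_diff=True):
    """Gaussian-smoothed gradient estimate of a scalar black-box function f at x.

    Only evaluations of f are used; f is never differentiated.
    """
    u = rng.standard_normal((num_queries,) + x.shape)   # random Gaussian directions
    if central_diff:
        # Central differences along each direction: (f(x + mu u) - f(x - mu u)) / (2 mu)
        diffs = np.array([(f(x + mu * ui) - f(x - mu * ui)) / (2.0 * mu) for ui in u])
    else:
        # Forward differences reuse a single evaluation f(x)
        fx = f(x)
        diffs = np.array([(f(x + mu * ui) - fx) / mu for ui in u])
    # Monte Carlo average of (directional difference) * direction
    return np.tensordot(diffs, u, axes=1) / num_queries


# Usage: a quadratic data-fit term treated as a black box
rng = np.random.default_rng(0)
A = np.array([[1.0, 2.0, 0.0], [0.0, -1.0, 3.0]])
y = np.array([1.0, 0.5])
f = lambda x: 0.5 * np.sum((A @ x - y) ** 2)
x = np.array([0.3, -0.2, 0.1])
g_est = gsg_gradient(f, x, rng, num_queries=20000)
```

Central differences cost two evaluations of f per query and cancel the even-order error terms; forward differences cost one evaluation per query plus one shared evaluation of f(x). The estimate's variance shrinks as num_queries grows, so small query budgets trade accuracy for fewer forward-model calls.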
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of zero-order DPS sampling.
Performs a DPS-style update in which the measurement-consistency gradient is estimated from forward-model evaluations (Gaussian smoothed gradient) rather than by differentiating the forward model.
- Parameters:
rng_key – Random number generator key for stochastic operations
state – Current state of the denoiser containing position and weights
measurement_state – Current measurement state containing observations
- Returns:
Updated state after performing the denoising step
- Return type:
CondDenoiserState
FPS (Filtering Posterior Sampling)
- class diffuse.denoisers.cond.fps.FPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5)
Bases: CondDenoiser
Filtering Posterior Sampling (FPS) Denoiser.
Implements a continuous-time SDE version of FPS for conditional generation, with particle filtering and resampling.
- Parameters:
integrator – Numerical integrator for solving the reverse SDE
model – Diffusion model defining the forward process
predictor – Predictor for computing score/noise/velocity
forward_model – Forward measurement operator
References
Dou, Z., & Song, Y. (2024). Diffusion Posterior Sampling for Linear Inverse Problem Solving: A Filtering Perspective. arXiv:2407.03981
- resampler(state_next: CondDenoiserState, measurement_state: MeasurementState, rng_key: Key[Array, ''] | UInt32[Array, '2']) → CondDenoiserState
Resample particles based on the current state and measurement.
This method resamples the particles when the effective sample size (ESS) falls below the thresholds set by ess_low and ess_high, which limits weight degeneracy in the particle set.
- Parameters:
state_next – Next state of the denoiser. Shape: (n_particles, …)
measurement_state – Current measurement state.
rng_key – Random number generator key.
- Returns:
Updated state after resampling.
- Return type:
CondDenoiserState
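The effective sample size used in this criterion is commonly computed as ESS = 1/Σᵢ wᵢ² for normalized weights wᵢ, so it ranges from 1 (one particle holds all the weight) to n_particles (uniform weights). A minimal NumPy sketch of ESS-triggered multinomial resampling follows. It assumes log-weights and a single threshold on ESS/n_particles; how ess_low and ess_high are combined in the library is not documented here, so the single ess_threshold argument and the helper names are simplifications.

```python
import numpy as np


def normalized_weights(log_weights):
    w = np.exp(log_weights - np.max(log_weights))   # subtract max for numerical stability
    return w / w.sum()


def effective_sample_size(log_weights):
    """ESS = 1 / sum_i w_i^2 for normalized weights w."""
    w = normalized_weights(log_weights)
    return 1.0 / np.sum(w**2)


def maybe_resample(rng, particles, log_weights, ess_threshold=0.5):
    """Multinomial resampling when ESS / n_particles drops below ess_threshold (simplified criterion)."""
    n = particles.shape[0]
    if effective_sample_size(log_weights) / n >= ess_threshold:
        return particles, log_weights
    idx = rng.choice(n, size=n, p=normalized_weights(log_weights))
    return particles[idx], np.full(n, -np.log(n))   # equal weights after resampling


# Usage: one particle carries essentially all the weight
rng = np.random.default_rng(0)
particles = np.arange(4.0).reshape(4, 1)
log_w = np.array([0.0, -1000.0, -1000.0, -1000.0])
new_particles, new_log_w = maybe_resample(rng, particles, log_w)
```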
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of continuous-time FPS sampling.
Implements FPS by separating:
1. Computation of the guidance term at the current position.
2. An unconditional diffusion step with the integrator.
3. The guidance correction applied to the result.
This approach works correctly with second-order integrators (Heun, DPM++, etc.) because the integrator sees the true unconditional score/velocity.
- Parameters:
rng_key – Random number generator key
state – Current conditional denoiser state
measurement_state – Measurement information
- Returns:
Updated conditional denoiser state
- y_noiser(key: Key[Array, ''] | UInt32[Array, '2'], t: float, measurement_state: MeasurementState) → SDEState
Generate noisy measurement at time t.
Computes \(y^{(t)} = \sqrt{\bar{\alpha}_t} y + \sqrt{1-\bar{\alpha}_t} A_\xi \epsilon\)
- Parameters:
key – Random number generator key
t – Current time
measurement_state – Measurement information
- Returns:
SDEState containing the noised measurement
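Pushing the noise through the operator keeps the noised measurement consistent with the forward diffusion: if y = Ax₀ and x_t = √ᾱ_t x₀ + √(1−ᾱ_t) ε, then Ax_t = √ᾱ_t y + √(1−ᾱ_t) Aε. The NumPy sketch below assumes a linear operator given as a fixed matrix A (standing in for A_ξ), takes ᾱ_t directly instead of t, and uses a NumPy generator instead of a JAX key; the signature is hypothetical.

```python
import numpy as np


def y_noiser(rng, alpha_bar_t, y, A):
    """Noised measurement y_t = sqrt(abar_t) * y + sqrt(1 - abar_t) * A @ eps, eps ~ N(0, I).

    eps lives in signal space and is pushed through the (assumed linear) operator A.
    """
    eps = rng.standard_normal(A.shape[1])
    return np.sqrt(alpha_bar_t) * y + np.sqrt(1.0 - alpha_bar_t) * (A @ eps)


# Usage: noiseless measurement of a random signal
rng = np.random.default_rng(1)
A = rng.standard_normal((2, 5))
x0 = rng.standard_normal(5)
y = A @ x0
y_t = y_noiser(rng, 0.3, y, A)
```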
TMP (Tweedie Moment Projection)
- class diffuse.denoisers.cond.tmp.TMPDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5)
Bases: CondDenoiser
Conditional denoiser using Tweedie’s Moment Projection (TMP).
Implements TMP, which modifies the score function to incorporate measurement information through Tweedie’s formula and moment matching.
- Parameters:
integrator – Numerical integrator for solving the reverse SDE
model – Diffusion model defining the forward process
predictor – Predictor for computing score/noise/velocity
forward_model – Forward measurement operator
References
Boys, B., Girolami, M., Pidstrigach, J., Reich, S., Mosca, A., & Akyildiz, Ö. D. (2023). Tweedie moment projected diffusions for inverse problems. arXiv:2310.06721
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of TMP sampling.
Adds the measurement term to the score and passes the modified score to the integrator for the update.
- Parameters:
rng_key – Random number generator key
state – Current conditional denoiser state
measurement_state – Measurement information
- Returns:
Updated conditional denoiser state
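The moment projection is easiest to see in a toy case where everything is Gaussian. Assume a VP process x_t = √ᾱ x₀ + √(1−ᾱ) ε, a hypothetical prior x₀ ~ N(0, diag(s2)), a linear A and a measurement noise std sigma_y (none of these names come from the library). TMP replaces p(x₀ | x_t) by N(m, C), where m is the Tweedie mean and C = (1−ᾱ)/√ᾱ · ∂m/∂x_t is the Tweedie covariance, which gives p(y | x_t) ≈ N(y; Am, σ_y²I + ACAᵀ). The sketch adds the score of that Gaussian (holding C fixed) to the prior score.

```python
import numpy as np


def tweedie_moments(x_t, alpha_bar, s2):
    """Tweedie mean, covariance and Jacobian of x0 | x_t for the toy prior x0 ~ N(0, diag(s2))."""
    var_t = alpha_bar * s2 + (1.0 - alpha_bar)           # per-coordinate Var[x_t]
    jac = np.sqrt(alpha_bar) * s2 / var_t                # diagonal of dm/dx_t
    mean = jac * x_t                                     # Tweedie mean m
    cov = (1.0 - alpha_bar) / np.sqrt(alpha_bar) * jac   # Tweedie covariance C (diagonal)
    return mean, cov, jac, var_t


def tmp_score(x_t, alpha_bar, s2, A, y, sigma_y):
    """Prior score plus grad log N(y; A m, sigma_y^2 I + A C A^T) with C held fixed."""
    mean, cov, jac, var_t = tweedie_moments(x_t, alpha_bar, s2)
    prior_score = -x_t / var_t
    S = sigma_y**2 * np.eye(A.shape[0]) + (A * cov) @ A.T       # A C A^T with diagonal C
    lik_score = jac * (A.T @ np.linalg.solve(S, y - A @ mean))  # J^T A^T S^{-1} (y - A m)
    return prior_score + lik_score


# Toy usage
rng = np.random.default_rng(0)
d, m = 3, 2
s2 = np.array([1.0, 0.5, 2.0])
A = rng.standard_normal((m, d))
y = rng.standard_normal(m)
x_t = rng.standard_normal(d)
score = tmp_score(x_t, 0.6, s2, A, y, sigma_y=0.1)
```

Because the prior here is Gaussian, C does not depend on x_t and the projection is exact, which makes the result easy to check against finite differences; with a learned denoiser, m and its Jacobian would come from the network.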
PiGDM (Pseudoinverse-Guided Diffusion)
- class diffuse.denoisers.cond.pigdm.PiGDMDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, zeta: float = 0.01, epsilon: float = 0.1, cg_maxiter: int = 0, cg_reg: float = 0.01)
Bases: CondDenoiser
Pseudoinverse-Guided Diffusion Models (ΠiGDM).
Guidance: grad = ∇_{x_t} [(A†r)ᵀ · x̂₀], where r = y − Ax̂₀ is held fixed (no gradient flows through r).
- Parameters:
zeta – Step size scaling factor
epsilon – Numerical stability constant
cg_maxiter – Number of conjugate-gradient (CG) iterations used to apply the pseudo-inverse A† (0 = use Aᵀ instead)
cg_reg – Tikhonov regularization for CG solve
Reference: Song et al., “Pseudoinverse-Guided Diffusion Models for Inverse Problems”
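The guidance formula amounts to a vector-Jacobian product: with J = ∂x̂₀/∂x_t and r held fixed, ∇_{x_t}[(A†r)ᵀx̂₀] = Jᵀ A†r. The NumPy sketch below assumes a linear A given as a matrix, an explicit Jacobian matrix jac (a real implementation would use autodiff), and that A†r is approximated by running CG on the Tikhonov-regularized normal equations (AᵀA + cg_reg·I)v = Aᵀr; whether the library solves exactly this system is an assumption, and cg_solve and pigdm_guidance are illustrative names.

```python
import numpy as np


def cg_solve(matvec, b, maxiter, tol=1e-10):
    """Conjugate gradients for a symmetric positive-definite linear operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def pigdm_guidance(x0_hat, jac, A, y, cg_maxiter=0, cg_reg=0.01):
    """J^T A^dagger r with r = y - A x0_hat held fixed (no gradient through r)."""
    r = y - A @ x0_hat
    if cg_maxiter == 0:
        v = A.T @ r                                    # cheap fallback: A^T in place of A^dagger
    else:
        matvec = lambda z: A.T @ (A @ z) + cg_reg * z  # (A^T A + cg_reg I) z
        v = cg_solve(matvec, A.T @ r, cg_maxiter)      # v approximates A^dagger r
    return jac.T @ v                                   # vector-Jacobian product J^T v


# Usage with random toy data
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
y = rng.standard_normal(3)
x0_hat = rng.standard_normal(4)
jac = 0.8 * np.eye(4)                                  # hypothetical Jacobian of x0_hat w.r.t. x_t
g_cg = pigdm_guidance(x0_hat, jac, A, y, cg_maxiter=50, cg_reg=0.01)
g_t = pigdm_guidance(x0_hat, jac, A, y, cg_maxiter=0)
```

Using Aᵀ instead of A† avoids the inner solve but ignores the conditioning of A, which is why a few regularized CG iterations can be worth their cost.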
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of ΠiGDM sampling.
Applies the pseudoinverse-based guidance to the current state and uses the integrator for the update.
- Parameters:
rng_key – Random number generator key for stochastic operations
state – Current state of the denoiser containing position and weights
measurement_state – Current measurement state containing observations
- Returns:
Updated state after performing the denoising step
- Return type:
CondDenoiserState
DAPS (Decoupled Annealing Posterior Sampling)
- class diffuse.denoisers.cond.daps.DAPSDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, langevin_steps: int = 100, langevin_lr: float = 0.0001, langevin_lr_min_ratio: float = 0.01, tau: float = 0.01, diffusion_steps: int = 50)
Bases: CondDenoiser
Decoupled Annealing Posterior Sampling (DAPS).
At each annealing step:
1. Reverse diffusion: ODE solve from σ_k to σ ≈ 0 to get x̂₀.
2. Langevin dynamics: sample x₀ from p(x₀ | x̂₀, y) ∝ exp(−‖Ax₀ − y‖²/(2τ²) − ‖x₀ − x̂₀‖²/(2σ²)).
3. Forward diffusion: add noise to x₀ to reach σ_{k+1}.
- Parameters:
langevin_steps – Number of Langevin MCMC steps per annealing level
langevin_lr – Langevin step size
langevin_lr_min_ratio – Final/initial step size ratio for annealing
tau – Likelihood temperature (measurement noise scale)
diffusion_steps – ODE steps for reverse diffusion
Reference: Zhang et al., “Improving Diffusion Inverse Problem Solving with Decoupled Noise Annealing”
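With a linear A, step 2 targets a Gaussian in x₀, so it is easy to sketch. The NumPy code below runs unadjusted Langevin dynamics on a batch of chains, starting from x̂₀, with gradient ∇U(x₀) = Aᵀ(Ax₀ − y)/τ² + (x₀ − x̂₀)/σ². The linear decay of the step size from langevin_lr to langevin_lr·langevin_lr_min_ratio is an assumption (the library may use a different schedule), and daps_langevin is an illustrative name.

```python
import numpy as np


def daps_langevin(rng, x0_hat, A, y, sigma, tau=0.01, steps=100, lr=1e-4, lr_min_ratio=0.01):
    """Unadjusted Langevin on U(x) = ||A x - y||^2/(2 tau^2) + ||x - x0_hat||^2/(2 sigma^2).

    x0_hat has shape (n_chains, d); each row is an independent chain.
    """
    x = x0_hat.copy()                                        # start at the ODE estimate
    for lr_k in np.linspace(lr, lr * lr_min_ratio, steps):   # assumed linear step-size decay
        grad_u = (x @ A.T - y) @ A / tau**2 + (x - x0_hat) / sigma**2
        x = x - lr_k * grad_u + np.sqrt(2.0 * lr_k) * rng.standard_normal(x.shape)
    return x


# Usage: 4000 chains sharing the same x0_hat
rng = np.random.default_rng(0)
A = np.array([[1.0, 0.5, 0.0], [0.0, 1.0, -1.0]])
y = np.array([1.0, -0.5])
x0_hat = np.tile([0.2, 0.0, 0.3], (4000, 1))
samples = daps_langevin(rng, x0_hat, A, y, sigma=1.0, tau=0.5, steps=500, lr=1e-2, lr_min_ratio=0.1)
```

Because the target is Gaussian, the sample mean can be checked against the closed-form posterior mean (AᵀA/τ² + I/σ²)⁻¹(Aᵀy/τ² + x̂₀/σ²). The usage above uses a larger step size and more steps than the defaults so that this toy example mixes quickly.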
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single annealing step of DAPS.
Runs reverse diffusion to obtain x̂₀, Langevin dynamics on x₀, then forward diffusion to the next noise level.
- Parameters:
rng_key – Random number generator key for stochastic operations
state – Current state of the denoiser containing position and weights
measurement_state – Current measurement state containing observations
- Returns:
Updated state after performing the denoising step
- Return type:
CondDenoiserState
PnPDM (Plug-and-Play Diffusion Model)
- class diffuse.denoisers.cond.pnpdm.PnPDMDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, sigma_max: float = 1.0, sigma_min: float = 0.01, rho: float = 0.95, langevin_steps: int = 10, langevin_lr: float = 0.01, tau: float = 0.5, diffusion_steps: int = 50, annealing_steps: int = 50)
Bases: CondDenoiser
Plug-and-Play Diffusion Models (PnPDM).
Alternates between Langevin dynamics for data fitting and full reverse diffusion for denoising, with σ annealed by exponential decay.
At each annealing step:
1. Run Langevin dynamics on the potential ||Ax-y||²/τ² + ||x-x₀||²/σ².
2. Add noise at level σ and run full reverse diffusion from σ to eps.
Reference: “Principled Probabilistic Imaging using Diffusion Models as Plug-and-Play Priors” (Wu et al.)
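The annealing loop can be sketched as below. The schedule σ_k = max(σ_max·ρᵏ, σ_min) is an assumed reading of the exponential decay driven by sigma_max, sigma_min and rho; the noise injection z + σ·ε assumes a variance-exploding parameterization; and langevin_fn and reverse_fn are hypothetical placeholders for langevin_sampling and reverse_diffuse.

```python
import numpy as np


def pnpdm_sigmas(sigma_max=1.0, sigma_min=0.01, rho=0.95, annealing_steps=50):
    """Assumed schedule: exponential decay from sigma_max, floored at sigma_min."""
    k = np.arange(annealing_steps)
    return np.maximum(sigma_max * rho**k, sigma_min)


def pnpdm_anneal(rng, x, sigmas, langevin_fn, reverse_fn):
    """Outer PnPDM loop: Langevin data fitting, noise injection, reverse diffusion."""
    for sigma in sigmas:
        z = langevin_fn(rng, x, sigma)                 # data-fitting step around x
        z = z + sigma * rng.standard_normal(z.shape)   # noise at level sigma (VE-style)
        x = reverse_fn(rng, z, sigma)                  # denoise from sigma down to eps
    return x


# Usage with trivial stand-ins for the two inner samplers
sigmas = pnpdm_sigmas(rho=0.5, annealing_steps=10)
x_out = pnpdm_anneal(
    np.random.default_rng(0),
    np.ones(3),
    sigmas,
    langevin_fn=lambda rng, x, s: x,
    reverse_fn=lambda rng, z, s: 0.5 * z,
)
```

Note that with the default rho = 0.95 and 50 annealing steps, σ_max·ρ⁴⁹ ≈ 0.08, so under this assumed schedule the σ_min floor of 0.01 is never reached.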
- langevin_sampling(rng_key: Key[Array, ''] | UInt32[Array, '2'], x0_hat: Array, sigma: Array, measurement_state: MeasurementState) → Array
Run Langevin dynamics on the potential ||Ax - y||²/τ² + ||x - x0_hat||²/σ².
- reverse_diffuse(rng_key: Key[Array, ''] | UInt32[Array, '2'], z: Array, sigma: Array) → Array
Run full reverse diffusion from sigma level to eps.
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single annealing step: Langevin + noise + reverse diffusion.
DiffPIR (Diffusion-based Plug-and-Play Image Restoration)
- class diffuse.denoisers.cond.diffpir.DiffPIRDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, sigma_n: float = 0.05, lamb: float = 1.0, xi: float = 0.5, linear: bool = False, cg_tol: float = 1e-05, cg_maxiter: int = 32)
Bases: CondDenoiser
Plug-and-Play Diffusion for Image Restoration (DiffPIR).
Alternates: denoise → data-fidelity → add noise. Regularization: ρ = 2λσₙ² / σₜ²
- Parameters:
sigma_n – Measurement noise std
lamb – Regularization strength (prior vs likelihood balance)
xi – Noise injection ratio (0=deterministic, 1=stochastic)
linear – Use a proximal data-fidelity step solved with CG (True, for linear forward models) or a gradient step (False)
cg_tol – CG solver tolerance (linear mode)
cg_maxiter – CG max iterations (linear mode)
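For a linear forward model, the proximal data step is a regularized least-squares problem, x₀ = argminₓ ‖y − Ax‖² + ρ‖x − x̂₀‖², whose normal equations (AᵀA + ρI)x = Aᵀy + ρx̂₀ can be solved with CG. The NumPy sketch below shows that step and a noise injection that mixes the noise implied by x₀ with fresh noise according to xi, assuming a VP parameterization with cumulative coefficients ᾱ. The function names, this form of the objective, and the use of an absolute CG tolerance are assumptions, not the library's code.

```python
import numpy as np


def cg(matvec, b, tol=1e-5, maxiter=32):
    """Plain conjugate gradients for a symmetric positive-definite operator."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def diffpir_data_step(x0_hat, A, y, sigma_n, sigma_t, lamb=1.0, cg_tol=1e-5, cg_maxiter=32):
    """argmin_x ||y - A x||^2 + rho ||x - x0_hat||^2 with rho = 2 lamb sigma_n^2 / sigma_t^2."""
    rho = 2.0 * lamb * sigma_n**2 / sigma_t**2
    matvec = lambda v: A.T @ (A @ v) + rho * v          # (A^T A + rho I) v
    return cg(matvec, A.T @ y + rho * x0_hat, tol=cg_tol, maxiter=cg_maxiter)


def diffpir_renoise(rng, x_t, x0, abar_t, abar_prev, xi=0.5):
    """Move to the next noise level; xi=0 reuses the implied noise, xi=1 uses fresh noise."""
    eps_hat = (x_t - np.sqrt(abar_t) * x0) / np.sqrt(1.0 - abar_t)   # noise implied by x0
    eps = rng.standard_normal(x_t.shape)
    noise = np.sqrt(1.0 - xi) * eps_hat + np.sqrt(xi) * eps
    return np.sqrt(abar_prev) * x0 + np.sqrt(1.0 - abar_prev) * noise


# Usage with random toy data
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))
y = rng.standard_normal(3)
x0_hat = rng.standard_normal(5)
x0 = diffpir_data_step(x0_hat, A, y, sigma_n=0.05, sigma_t=0.5, cg_tol=1e-10, cg_maxiter=100)
x_t = rng.standard_normal(5)
x_prev = diffpir_renoise(rng, x_t, x0, abar_t=0.4, abar_prev=0.6, xi=0.5)
```

As σₜ shrinks, ρ grows and the data step stays closer to the denoiser output x̂₀; early in sampling, when σₜ is large, the measurement term dominates.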
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of DiffPIR sampling.
Denoises the current state, applies the data-fidelity update, then injects noise to reach the next noise level.
- Parameters:
rng_key – Random number generator key for stochastic operations
state – Current state of the denoiser containing position and weights
measurement_state – Current measurement state containing observations
- Returns:
Updated state after performing the denoising step
- Return type:
CondDenoiserState
EnKG (Ensemble Kalman Guidance)
- class diffuse.denoisers.cond.enkg.EnKGDenoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, forward_model: ForwardModel, x0_shape: Tuple[int, ...], resample: bool | None = False, ess_low: float | None = 0.2, ess_high: float | None = 0.5, guidance_scale: float = 0.5, lr_min_ratio: float = 0.0, denoising_steps: int = 15)
Bases: CondDenoiser
Ensemble Kalman Guidance (EnKG) conditional denoiser.
Derivative-free method using ensemble Kalman updates. Does NOT require gradients of the forward model.
- Parameters:
guidance_scale – Base learning rate (γ) for Kalman updates
lr_min_ratio – Min LR ratio (r) for decay: γ(1-r)(N-i)/N + r
denoising_steps – Euler ODE steps for x̂₀ estimation (1 = Tweedie)
Reference: devzhk/enkg-pytorch
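The core ensemble Kalman update can be sketched without any forward-model gradients: each particle moves along the ensemble cross-covariance between particle states and their predicted measurements, weighted by its own data misfit. In the NumPy sketch below, g holds forward-model outputs A(x̂₀) for each particle and step plays the role of the (possibly decayed) learning rate γ; which state the library updates and how x̂₀ is obtained are not shown, and enkg_update is a hypothetical name.

```python
import numpy as np


def enkg_update(x, g, y, step):
    """One derivative-free ensemble Kalman step on N particles.

    x: (N, d) particle states; g: (N, m) forward-model outputs for each particle;
    y: (m,) measurement; step: learning rate (gamma, possibly already decayed).
    """
    n = x.shape[0]
    dx = x - x.mean(axis=0)        # state deviations from the ensemble mean
    dg = g - g.mean(axis=0)        # prediction deviations from the ensemble mean
    misfit = g - y                 # each particle's data misfit
    # Row i: sum_j <misfit_i, dg_j> dx_j / N, i.e. C_xg @ misfit_i with C_xg = dx^T dg / N
    return x - step * (misfit @ dg.T) @ dx / n


# Usage with a linear forward model (only its outputs are used)
rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3))
y = rng.standard_normal(2)
x = rng.standard_normal((16, 3))
x_new = enkg_update(x, x @ A.T, y, step=0.1)
```

For a linear forward model the update reduces to a gradient step on ½‖Ax − y‖² preconditioned by the ensemble covariance, which is what the derivative-free estimate approximates in general.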
- generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], measurement_state: MeasurementState, n_steps: int, n_particles: int, keep_history: bool = False)
Generate n_particles samples conditioned on measurement_state using n_steps sampling steps.
- step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: CondDenoiserState, measurement_state: MeasurementState) → CondDenoiserState
Single step of EnKG sampling.
Guides the particles with ensemble Kalman updates in place of forward-model gradients.
- Parameters:
rng_key – Random number generator key for stochastic operations
state – Current state of the denoiser containing position and weights
measurement_state – Current measurement state containing observations
- Returns:
Updated state after performing the denoising step
- Return type:
CondDenoiserState