# Source code for diffuse.denoisers.cond.dps
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from diffuse.diffusion.sde import SDEState
from diffuse.denoisers.cond import CondDenoiser, CondDenoiserState
from diffuse.base_forward_model import MeasurementState
@dataclass
class DPSDenoiser(CondDenoiser):
    """Conditional denoiser using Diffusion Posterior Sampling (DPS).

    Implements DPS, which uses Tweedie's formula to estimate the clean sample
    and applies a measurement-consistency gradient correction after each
    unconditional sampling step.

    Args:
        integrator: Numerical integrator for solving the reverse SDE
        model: Diffusion model defining the forward process
        predictor: Predictor for computing score/noise/velocity
        forward_model: Forward measurement operator
        epsilon: Numerical stability constant added to the residual norm
            before it is used as a divisor (default: 1e-3)
        zeta: Gradient step size parameter (default: 1e-2)

    References:
        Chung, H., Kim, J., Mccann, M. T., Klasky, M. L., & Ye, J. C. (2022).
        Diffusion posterior sampling for general noisy inverse problems. arXiv:2209.14687
    """

    epsilon: float = 1e-3
    zeta: float = 1e-2

    def step(
        self,
        rng_key: PRNGKeyArray,
        state: CondDenoiserState,
        measurement_state: MeasurementState,
    ) -> CondDenoiserState:
        """Single step of DPS sampling.

        Implements the DPS algorithm:
        1. Compute the Tweedie estimate x̂_0(x_t) at the current position and
           the gradient of the measurement loss ||y - A(x̂_0)||² w.r.t. x_t
        2. Take an unconditional diffusion step with the integrator
        3. Subtract that gradient, scaled by
           zeta / (||y - A(x̂_0)|| + epsilon), from the integrator's output

        Because the correction is applied to the integrator's result rather
        than inside it, this approach works with second-order integrators
        (Heun, DPM++, etc.).

        The loss is computed as sum(Re(r * conj(r))), which equals sum(r**2)
        for real residuals and stays real-valued (as ``jax.grad`` requires)
        when the forward model or measurement is complex-valued.

        Args:
            rng_key: Random number generator key (not used by this step)
            state: Current conditional denoiser state
            measurement_state: Measurement information; ``.y`` holds the
                observed measurement

        Returns:
            Updated conditional denoiser state whose integrator position has
            been corrected towards measurement consistency
        """
        y_meas = measurement_state.y
        position_current = state.integrator_state.position
        t_current = self.integrator.timer(state.integrator_state.step)

        def measurement_loss(x: Array) -> Array:
            # Tweedie estimate of the clean sample from the noisy position x.
            denoised = self.model.tweedie(SDEState(x, t_current), self.predictor.score).position
            residual = y_meas - self.forward_model.apply(denoised, measurement_state)
            # Squared magnitude |r|^2 = Re(r * conj(r)): bit-identical to r**2
            # for real residuals, and real-valued for complex ones.
            return jnp.sum(jnp.real(residual * jnp.conj(residual)))

        # Gradient is taken w.r.t. the *current* position x_t (before the
        # integrator step), as in DPS.
        loss_val, gradient = jax.value_and_grad(measurement_loss)(position_current)
        # Normalise the step by the residual norm; epsilon guards division by 0.
        step_size = self.zeta / (jnp.sqrt(loss_val) + self.epsilon)

        integrator_state_uncond = self.integrator(state.integrator_state, self.predictor)
        position_corrected = integrator_state_uncond.position - step_size * gradient
        integrator_state_next = integrator_state_uncond._replace(position=position_corrected)
        return state._replace(integrator_state=integrator_state_next)