# Source code for diffuse.denoisers.cond.dps_gsg
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
"""Zero-order DPS using Gaussian Smoothed Gradient estimation."""
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from diffuse.diffusion.sde import SDEState
from diffuse.denoisers.cond import CondDenoiser, CondDenoiserState
from diffuse.base_forward_model import MeasurementState
[docs]
@dataclass
class DPSGSGDenoiser(CondDenoiser):
    """Zero-order DPS using Gaussian Smoothed Gradient (GSG) estimation.

    Derivative-free DPS for black-box or non-differentiable forward models.
    Estimates gradients via finite differences with Gaussian perturbations.

    Args:
        num_queries: Number of perturbation samples for gradient estimation
        mu: Perturbation scale (smoothing factor)
        zeta: Gradient step size
        epsilon: Numerical stability constant
        central_diff: Use central differences (True) or forward differences (False)

    Raises:
        ValueError: If ``num_queries``, ``mu`` or ``zeta`` is not strictly
            positive, or if ``epsilon`` is negative.
    """

    num_queries: int = 64
    mu: float = 0.01
    zeta: float = 1e-2
    epsilon: float = 1e-3
    central_diff: bool = True

    # Floor applied to the signal level before dividing by it in ``step``.
    # Deliberately left un-annotated so that ``@dataclass`` does not turn it
    # into a constructor field.
    _MIN_SIGNAL_LEVEL = 0.1

    def __post_init__(self):
        """Validate the hyper-parameters right after dataclass construction."""
        # NOTE(review): no parent ``__post_init__`` is chained here, matching
        # the previous behaviour -- confirm ``CondDenoiser`` does not need it.
        if self.num_queries <= 0:
            raise ValueError("num_queries must be strictly positive.")
        if self.mu <= 0:
            raise ValueError("mu must be strictly positive.")
        if self.zeta <= 0:
            raise ValueError("zeta must be strictly positive.")
        # ``epsilon`` is added to sqrt(loss) in the step-size denominator; a
        # negative value could make that denominator zero or negative.
        if self.epsilon < 0:
            raise ValueError("epsilon must be non-negative.")

    def _data_loss(self, x: Array, measurement_state: MeasurementState) -> Array:
        """Return the sum of squared residuals between the measurement and ``x``.

        The residual is ``measurement_state.y - forward_model.apply(x, measurement_state)``.
        Only evaluations of the forward model are needed, no derivatives.
        """
        residual = measurement_state.y - self.forward_model.apply(x, measurement_state)
        return jnp.sum(residual**2)

    def _sample_perturbations(self, rng_key: PRNGKeyArray, x0_hat: Array) -> Array:
        """Draw ``num_queries`` standard-normal directions shaped like ``x0_hat``."""
        return jax.random.normal(rng_key, (self.num_queries, *x0_hat.shape))

    def _estimate_gradient_central(
        self,
        rng_key: PRNGKeyArray,
        x0_hat: Array,
        measurement_state: MeasurementState,
    ) -> tuple[Array, Array]:
        """Estimate the loss gradient at ``x0_hat`` with central differences.

        For every sampled direction ``u`` the loss is evaluated at
        ``x0_hat + mu * u`` and ``x0_hat - mu * u``; the symmetric difference
        quotient, multiplied by ``u``, is averaged over all queries.

        Args:
            rng_key: PRNG key used to sample the perturbation directions.
            x0_hat: Point at which the gradient is estimated.
            measurement_state: Measurement data handed to the forward model.

        Returns:
            Tuple ``(gradient, base_loss)`` where ``gradient`` has the shape of
            ``x0_hat`` and ``base_loss`` is the loss evaluated at ``x0_hat``.
        """
        perturbations = self._sample_perturbations(rng_key, x0_hat)
        # Not used by the central estimator itself; returned so the caller can
        # normalise its step size by the current residual.
        base_loss = self._data_loss(x0_hat, measurement_state)

        def single_query_gradient(u: Array) -> Array:
            loss_plus = self._data_loss(x0_hat + self.mu * u, measurement_state)
            loss_minus = self._data_loss(x0_hat - self.mu * u, measurement_state)
            return u * (loss_plus - loss_minus) / (2.0 * self.mu)

        gradients = jax.vmap(single_query_gradient)(perturbations)
        return jnp.mean(gradients, axis=0), base_loss

    def _estimate_gradient_forward(
        self,
        rng_key: PRNGKeyArray,
        x0_hat: Array,
        measurement_state: MeasurementState,
    ) -> tuple[Array, Array]:
        """Estimate the loss gradient at ``x0_hat`` with forward differences.

        For every sampled direction ``u`` the loss is evaluated at
        ``x0_hat + mu * u`` and compared with the loss at ``x0_hat``; the
        difference quotient, multiplied by ``u``, is averaged over all queries.
        This uses one perturbed loss evaluation per query instead of the two
        needed by the central variant.

        Args:
            rng_key: PRNG key used to sample the perturbation directions.
            x0_hat: Point at which the gradient is estimated.
            measurement_state: Measurement data handed to the forward model.

        Returns:
            Tuple ``(gradient, base_loss)`` where ``gradient`` has the shape of
            ``x0_hat`` and ``base_loss`` is the loss evaluated at ``x0_hat``.
        """
        perturbations = self._sample_perturbations(rng_key, x0_hat)
        base_loss = self._data_loss(x0_hat, measurement_state)

        def single_query_gradient(u: Array) -> Array:
            loss_perturbed = self._data_loss(x0_hat + self.mu * u, measurement_state)
            return u * (loss_perturbed - base_loss) / self.mu

        gradients = jax.vmap(single_query_gradient)(perturbations)
        return jnp.mean(gradients, axis=0), base_loss

    def step(
        self,
        rng_key: PRNGKeyArray,
        state: CondDenoiserState,
        measurement_state: MeasurementState,
    ) -> CondDenoiserState:
        """Advance the sampler by one step with a GSG-based measurement correction.

        The integrator performs its regular update, after which the position is
        shifted against the estimated gradient of the squared measurement
        residual evaluated at the Tweedie estimate of the clean sample.

        Args:
            rng_key: PRNG key; it is split and the first sub-key drives the
                gradient estimator.
            state: Current denoiser state holding the integrator state.
            measurement_state: Measurement data handed to the forward model.

        Returns:
            A copy of ``state`` whose integrator state carries the corrected
            position.
        """
        # Only the first sub-key is consumed; the split is kept so the random
        # stream stays identical to earlier versions of this method.
        key_gsg, _ = jax.random.split(rng_key)
        position_current = state.integrator_state.position
        t_current = self.integrator.timer(state.integrator_state.step)

        # estimate gradient via GSG
        x0_hat = self.model.tweedie(SDEState(position_current, t_current), self.predictor.score).position
        if self.central_diff:
            estimator = self._estimate_gradient_central
        else:
            estimator = self._estimate_gradient_forward
        gradient, loss_val = estimator(key_gsg, x0_hat, measurement_state)

        # scale gradient from x0 to xt space; the signal level is floored so
        # the division cannot blow up when it gets very small
        alpha_t = self.model.signal_level(t_current)
        alpha_t_clipped = jnp.maximum(alpha_t, self._MIN_SIGNAL_LEVEL)
        gradient_xt = gradient / alpha_t_clipped

        # step size normalised by the residual norm (sqrt of the squared loss)
        zeta = self.zeta / (jnp.sqrt(loss_val) + self.epsilon)

        # apply correction on top of the unconditional integrator update
        integrator_state_uncond = self.integrator(state.integrator_state, self.predictor)
        position_corrected = integrator_state_uncond.position - zeta * gradient_xt
        integrator_state_next = integrator_state_uncond._replace(position=position_corrected)
        return state._replace(integrator_state=integrator_state_next)
# Public API of this module: the name exported by a star-import.
__all__ = ["DPSGSGDenoiser"]