Source code for diffuse.denoisers.cond.pigdm
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from diffuse.base_forward_model import MeasurementState
from diffuse.denoisers.cond import CondDenoiser, CondDenoiserState
from diffuse.diffusion.sde import SDEState
[docs]
@dataclass
class PiGDMDenoiser(CondDenoiser):
    """Conditional denoiser using Pseudoinverse-Guided Diffusion (ΠGDM).

    Each call to :meth:`step` performs one unconditional integrator update and
    then shifts the resulting position along

        grad = ∇_x ⟨A†r, x̂₀(x)⟩,   r = y - A x̂₀(x),

    where ``x̂₀`` is the Tweedie estimate at the current time and ``A†r`` is
    treated as a constant (wrapped in ``stop_gradient``) while differentiating.

    Attributes:
        zeta: Numerator of the guidance step size.
        epsilon: Constant added to the step-size denominator so it cannot
            reach zero.
        cg_maxiter: Maximum number of conjugate-gradient iterations used to
            approximate ``A†r``. When ``<= 0`` no solve is run and the forward
            model's adjoint applied to ``r`` is used instead.
        cg_reg: Coefficient of the identity term added to ``AᵀA`` in the CG
            solve (Tikhonov regularization).

    Reference:
        Song et al., "Pseudoinverse-Guided Diffusion Models for Inverse Problems"
    """

    zeta: float = 1e-2
    epsilon: float = 1e-1
    cg_maxiter: int = 0
    cg_reg: float = 1e-2

    def _pseudoinverse(self, residual: Array, measurement_state: MeasurementState) -> Array:
        """Map a measurement-space residual back to signal space.

        Args:
            residual: Difference ``y - A x̂₀`` in measurement space.
            measurement_state: Measurement description handed unchanged to the
                forward model's ``apply`` / ``adjoint``.

        Returns:
            ``adjoint(residual)`` when ``cg_maxiter <= 0``; otherwise the CG
            iterate for ``(AᵀA + cg_reg·I) d = Aᵀ residual`` after at most
            ``cg_maxiter`` iterations.
        """
        operator = self.forward_model

        if self.cg_maxiter > 0:

            def regularized_gram(vec):
                # (AᵀA + cg_reg·I) vec — the operator of the normal equations.
                gram_vec = operator.adjoint(operator.apply(vec, measurement_state), measurement_state)
                return gram_vec + self.cg_reg * vec

            back_projected = operator.adjoint(residual, measurement_state)
            # cg returns (solution, info); only the solution is needed here.
            solution, _ = jax.scipy.sparse.linalg.cg(
                regularized_gram,
                back_projected,
                maxiter=self.cg_maxiter,
            )
            return solution

        # No CG budget: use the plain adjoint as the pseudo-inverse surrogate.
        return operator.adjoint(residual, measurement_state)

    def step(
        self,
        rng_key: PRNGKeyArray,
        state: CondDenoiserState,
        measurement_state: MeasurementState,
    ) -> CondDenoiserState:
        """Run one guided denoising step.

        Args:
            rng_key: PRNG key; accepted for interface compatibility but not
                read by this method.
            state: Current denoiser state; only ``integrator_state`` is used.
            measurement_state: Supplies the observation ``y`` and is forwarded
                to the forward model.

        Returns:
            ``state`` with ``integrator_state`` replaced by the unconditional
            integrator output whose position has been shifted by
            ``step_size * grad``.
        """
        current = state.integrator_state
        noisy_sample = current.position
        time_now = self.integrator.timer(current.step)

        def inner_product_objective(x: Array) -> Array:
            denoised = self.model.tweedie(SDEState(x, time_now), self.predictor.score).position
            mismatch = measurement_state.y - self.forward_model.apply(denoised, measurement_state)
            # Freeze A†r so that differentiating ⟨A†r, x̂₀⟩ only goes through
            # the trailing x̂₀ factor (a vector-Jacobian product).
            frozen_direction = jax.lax.stop_gradient(
                self._pseudoinverse(mismatch, measurement_state)
            )
            return jnp.sum(frozen_direction * denoised)

        objective, guidance = jax.value_and_grad(inner_product_objective)(noisy_sample)

        # Step size shrinks as |objective| grows; epsilon bounds the denominator.
        scale = self.zeta / (jnp.sqrt(jnp.abs(objective)) + self.epsilon)

        # Unconditional update first, then add the guidance correction on top.
        unconditional = self.integrator(current, self.predictor)
        guided = unconditional._replace(position=unconditional.position + scale * guidance)
        return state._replace(integrator_state=guided)