Source code for diffuse.denoisers.cond.tmp

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass
import jax
from jaxtyping import Array, PRNGKeyArray

from diffuse.diffusion.sde import SDEState
from diffuse.denoisers.cond import CondDenoiser, CondDenoiserState
from diffuse.base_forward_model import MeasurementState
from diffuse.predictor import Predictor


@dataclass
class TMPDenoiser(CondDenoiser):
    """Tweedie's Moment Projection (TMP) conditional denoiser.

    The unconditional score is augmented with a measurement-driven guidance
    term. That term is built from the Tweedie estimate of the clean signal
    and Jacobian-vector products of that estimate, following the TMP method.

    Args:
        integrator: Numerical integrator for solving the reverse SDE
        model: Diffusion model defining the forward process
        predictor: Predictor for computing score/noise/velocity
        forward_model: Forward measurement operator

    References:
        Boys, B., Girolami, M., Pidstrigach, J., Reich, S., Mosca, A., &
        Akyildiz, Ö. D. (2023). Tweedie moment projected diffusions for
        inverse problems. arXiv:2310.06721
    """

    def step(
        self,
        rng_key: PRNGKeyArray,
        state: CondDenoiserState,
        measurement_state: MeasurementState,
    ) -> CondDenoiserState:
        """Advance the sampler by one TMP-guided integration step.

        A guided score function is built as a closure over the measurement,
        wrapped in a ``Predictor`` of kind ``"score"``, and handed to the
        integrator together with the current integrator state.

        Args:
            rng_key: Random number generator key. It is accepted as part of
                the signature but is not read inside this method.
            state: Current conditional denoiser state
            measurement_state: Measurement information; its ``y`` attribute
                is the observation and the object is forwarded to the
                forward model's ``apply`` / ``adjoint``.

        Returns:
            A new conditional denoiser state holding the integrator's next
            state and the incoming ``log_weights`` unchanged.
        """
        observed = measurement_state.y

        def guided_score(x: Array, t: Array) -> Array:
            """Return the predictor score at ``(x, t)`` plus the TMP guidance."""
            operator = self.forward_model
            # NOTE(review): this ratio weights the Jacobian term in the linear
            # system below; whether it should be noise/signal or noise**2/signal
            # depends on what model.noise_level returns — confirm.
            ratio = self.model.noise_level(t) / self.model.signal_level(t)

            def denoise(point):
                # Tweedie estimate at time t, using the predictor's score.
                sde_state = SDEState(point, t)
                return self.model.tweedie(sde_state, self.predictor.score).position

            def push_through_jacobian(direction):
                # jax.jvp returns (primal_out, tangent_out); only the tangent,
                # i.e. the Jacobian of `denoise` at x applied to `direction`,
                # is used here.
                return jax.jvp(denoise, (x,), (direction,))[1]

            def system_matvec(v):
                # Matrix-free operator passed to CG:
                #   v -> ratio * A(J(A^T v)) + std**2 * v
                # where A / A^T are forward_model.apply / adjoint and J is the
                # Jacobian of the Tweedie estimate at x.
                back_projected = operator.adjoint(v, measurement_state)
                projected = operator.apply(
                    push_through_jacobian(back_projected), measurement_state
                )
                return ratio * projected + operator.std**2 * v

            # Residual between the observation and the measured Tweedie estimate.
            residual = observed - operator.apply(denoise(x), measurement_state)

            # Approximate solve, capped at 3 conjugate-gradient iterations; the
            # solver's second return value is discarded.
            # NOTE(review): CG presumes the operator is symmetric
            # positive-definite — confirm for the forward models in use.
            solution, _ = jax.scipy.sparse.linalg.cg(
                system_matvec, residual, maxiter=3
            )

            # Map the solution back to signal space and through the Jacobian
            # to obtain the guidance added to the score.
            correction = push_through_jacobian(
                operator.adjoint(solution, measurement_state)
            )
            return self.predictor.score(x, t) + correction

        # Wrap the guided score so the integrator consumes it like any predictor.
        guided_predictor = Predictor(self.model, guided_score, "score")
        next_integrator_state = self.integrator(
            state.integrator_state, guided_predictor
        )
        return CondDenoiserState(next_integrator_state, state.log_weights)