Source code for diffuse.denoisers.denoiser
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Any
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
from diffuse.integrator.base import Integrator
from diffuse.diffusion.sde import DiffusionModel
from diffuse.predictor import Predictor
from diffuse.denoisers.base import DenoiserState, BaseDenoiser
[docs]
@dataclass
class Denoiser(BaseDenoiser):
    """
    Sample generator that repeatedly applies an integrator to a batch of states.

    Args:
        integrator: The integrator to use for solving the reverse SDE
        model: The diffusion model (SDE) defining the forward process
        predictor: The predictor for computing the score/denoised estimate
        x0_shape: Shape of the data samples (excluding batch dimension)
    """

    integrator: Integrator
    model: DiffusionModel
    predictor: Predictor
    x0_shape: Tuple[int, ...]

    def init(self, position: Array, rng_key: PRNGKeyArray) -> DenoiserState:
        """
        Build the initial denoiser state for a single sample.

        Args:
            position: Starting position handed to ``integrator.init``
            rng_key: Random key handed to ``integrator.init``

        Returns:
            A ``DenoiserState`` wrapping the freshly initialised integrator state
        """
        return DenoiserState(self.integrator.init(position, rng_key))

    def step(
        self,
        state: DenoiserState,
    ) -> DenoiserState:
        r"""
        Advance the state by one denoising step.

        Sample :math:`x_{t-1} \sim p(x_{t-1} | x_t)`

        Args:
            state: Current denoiser state

        Returns:
            Updated denoiser state at the previous time step
        """
        # The integrator is called directly with the wrapped state and the
        # predictor; its result is re-wrapped in a new DenoiserState.
        return DenoiserState(self.integrator(state.integrator_state, self.predictor))

    def generate(
        self,
        rng_key: PRNGKeyArray,
        n_steps: int,
        n_particles: int,
        keep_history: bool = False,
        data_sharding: Optional[Any] = None,
    ) -> Tuple[DenoiserState, Union[Array, None]]:
        r"""
        Generate denoised samples :math:`x_0`.

        Standard-normal noise of shape ``(n_particles, *x0_shape)`` is drawn,
        one state per particle is initialised, and ``step`` is applied
        ``n_steps`` times inside ``jax.lax.scan``.

        Args:
            rng_key: Random key for initialization
            n_steps: Number of denoising steps to perform
            n_particles: Number of samples to generate (batch size)
            keep_history: If True, return the full trajectory of samples
            data_sharding: Optional JAX sharding specification for distributed computation

        Returns:
            Tuple of (final_state, history), where history is None if keep_history=False.
            When requested, history holds the positions recorded after every
            step, stacked by ``jax.lax.scan`` along a new leading axis.
        """
        # One sub-key drives the starting noise, the other is fanned out into
        # one key per particle.
        particle_key, noise_key = jax.random.split(rng_key)
        initial_noise = jax.random.normal(
            noise_key, shape=(n_particles, *self.x0_shape)
        )
        particle_keys = jax.random.split(particle_key, n_particles)

        # When a sharding is supplied, place both the noise and the keys with it.
        if data_sharding is not None:
            initial_noise = jax.device_put(initial_noise, data_sharding)
            particle_keys = jax.device_put(particle_keys, data_sharding)

        # vmap maps over the leading axis of every argument by default.
        initial_state = jax.vmap(self.init)(initial_noise, particle_keys)
        batched_step = jax.vmap(self.step)

        def _advance(carry, _):
            advanced = batched_step(carry)
            # keep_history is a plain Python bool, so this branch is resolved
            # once at trace time rather than inside the compiled loop.
            if keep_history:
                return advanced, advanced.integrator_state.position
            return advanced, None

        step_indices = jnp.arange(n_steps)
        return jax.lax.scan(_advance, initial_state, step_indices)