Denoiser

class diffuse.denoisers.denoiser.Denoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, x0_shape: Tuple[int, ...])

Bases: BaseDenoiser

Denoiser for generating samples using reverse diffusion.

Parameters:
  • integrator – The integrator to use for solving the reverse SDE

  • model – The diffusion model (SDE) defining the forward process

  • predictor – The predictor for computing the score/denoised estimate

  • x0_shape – Shape of the data samples (excluding batch dimension)
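In the standard score-based formulation, the three constructor arguments meet in the reverse-time SDE: the model supplies the forward drift f and diffusion coefficient g, the predictor supplies the score (or a denoised estimate that converts to one), and the integrator discretizes the equation from t = T down to 0. This formulation is assumed here, not read from this library's source. The noise parameterization α_t, σ_t in the second formula is also an assumption.

```latex
% Reverse-time SDE for the forward process dx = f(x,t)\,dt + g(t)\,dw defined by the model
\mathrm{d}x = \left[ f(x,t) - g(t)^2 \, \nabla_x \log p_t(x) \right] \mathrm{d}t + g(t)\, \mathrm{d}\bar{w}

% If the marginals are x_t = \alpha_t x_0 + \sigma_t \varepsilon with \varepsilon \sim \mathcal{N}(0, I),
% a denoised estimate \hat{x}_0(x_t, t) \approx \mathbb{E}[x_0 \mid x_t] gives a score via Tweedie's formula:
\nabla_x \log p_t(x_t) \approx \frac{\alpha_t \, \hat{x}_0(x_t, t) - x_t}{\sigma_t^2}
```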

generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], n_steps: int, n_particles: int, keep_history: bool = False, data_sharding: Any | None = None) → Tuple[DenoiserState, Array | None]

Generate denoised samples \(x_0\).

Parameters:
  • rng_key – Random key for initialization

  • n_steps – Number of denoising steps to perform

  • n_particles – Number of samples to generate (batch size)

  • keep_history – If True, return the full trajectory of samples

  • data_sharding – Optional JAX sharding specification for distributed computation

Returns:

Tuple (final_state, history); history holds the sample trajectory if keep_history=True and is None otherwise.
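The generate loop and its (final_state, history) return convention can be sketched in pure Python on a toy problem. Everything here is hypothetical and does not use the library. The 1-D data are assumed to follow N(2, 0.5²), noised by the Ornstein-Uhlenbeck SDE dx = -x dt + √2 dW. A closed-form score stands in for the predictor, and an Euler-Maruyama update stands in for the integrator. The names toy_generate and ToyState are illustrative only. Whether the library's history includes the initial positions is not stated, so including them is an assumption.

```python
import math
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Toy setting (assumption): x0 ~ N(MU, S0^2), forward SDE dx = -x dt + sqrt(2) dW,
# whose marginal p_t is Gaussian, so its score is known exactly.
MU, S0 = 2.0, 0.5
T_MAX = 5.0  # by t = 5 the marginal is very close to the N(0, 1) prior


def exact_score(x: float, t: float) -> float:
    """Score of p_t for the toy problem; plays the role of a trained predictor."""
    mean = MU * math.exp(-t)
    var = S0**2 * math.exp(-2 * t) + (1.0 - math.exp(-2 * t))
    return -(x - mean) / var


@dataclass
class ToyState:  # hypothetical stand-in for DenoiserState
    position: List[float]
    t: float


def toy_generate(
    seed: int, n_steps: int, n_particles: int, keep_history: bool = False
) -> Tuple[ToyState, Optional[List[List[float]]]]:
    rng = random.Random(seed)
    h = T_MAX / n_steps
    # Initialize particles from the prior N(0, 1) at t = T_MAX.
    state = ToyState([rng.gauss(0.0, 1.0) for _ in range(n_particles)], T_MAX)
    history = [list(state.position)] if keep_history else None
    for _ in range(n_steps):
        # Reverse Euler-Maruyama step from t to t - h:
        # x <- x - h * (f(x) - g^2 * score) + g * sqrt(h) * z, with f(x) = -x, g^2 = 2.
        new_pos = [
            x + h * (x + 2.0 * exact_score(x, state.t))
            + math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)
            for x in state.position
        ]
        state = ToyState(new_pos, state.t - h)
        if history is not None:
            history.append(list(state.position))
    return state, history
```

For example, toy_generate(0, 400, 1000, keep_history=True) returns a final state whose particles have mean close to 2 and standard deviation close to 0.5, plus a history of 401 snapshots. With keep_history=False, the second element is None.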

init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) → DenoiserState

Initialize the denoiser state.

Parameters:
  • position – Initial sample positions

  • rng_key – Random key
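Because generate takes no position argument, it presumably draws initial positions itself before calling init. Since x0_shape excludes the batch dimension and n_particles is the batch size, a natural shape is (n_particles, *x0_shape). In JAX this would typically be jax.random.normal(key, (n_particles, *x0_shape)). The pure-Python sketch below assumes that shape convention. The names gaussian_array, shape_of, SketchState, and sketch_init are hypothetical, and DenoiserState's real fields are not documented on this page.

```python
import random
from dataclasses import dataclass
from typing import Any, Tuple


def gaussian_array(rng: random.Random, shape: Tuple[int, ...]) -> Any:
    """Nested lists of standard-normal draws with the given shape."""
    if not shape:
        return rng.gauss(0.0, 1.0)
    return [gaussian_array(rng, shape[1:]) for _ in range(shape[0])]


def shape_of(a: Any) -> Tuple[int, ...]:
    """Shape of a rectangular, non-empty nested-list array."""
    dims = []
    while isinstance(a, list):
        dims.append(len(a))
        a = a[0]
    return tuple(dims)


@dataclass
class SketchState:  # hypothetical; not the library's DenoiserState
    position: Any
    seed: int
    step: int = 0


def sketch_init(position: Any, seed: int) -> SketchState:
    # Store the starting positions and the randomness source; no steps taken yet.
    return SketchState(position=position, seed=seed)
```

With n_particles=8 and x0_shape=(3, 4), sketch_init(gaussian_array(rng, (8, 3, 4)), seed=0) yields a position of shape (8, 3, 4): one sample of shape x0_shape per particle.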

step(state: DenoiserState) → DenoiserState

Perform one denoising step.

Sample \(x_{t-1} \sim p(x_{t-1} \mid x_t)\).

Parameters:

state – Current denoiser state

Returns:

Updated denoiser state at the previous time step
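Which transition p(x_{t-1} | x_t) a step realizes depends on the integrator. The sketch below assumes the integrator is Euler-Maruyama applied to the reverse SDE. Under that assumption, the transition over a step of size h is Gaussian with mean x - h·(f(x, t) - g(t)²·score) and variance g(t)²·h. The toy forward SDE dx = -x dt + √2 dW and the function names are hypothetical. Other integrators (for example, deterministic ODE solvers) give different transitions.

```python
import math
from typing import Callable, Tuple

Score = Callable[[float, float], float]


def reverse_em_transition(x: float, t: float, h: float, score: Score) -> Tuple[float, float]:
    """Mean and variance of p(x_{t-h} | x_t) for one reverse Euler-Maruyama step
    of the toy SDE dx = -x dt + sqrt(2) dW (so f(x, t) = -x and g^2 = 2)."""
    drift = -x  # forward drift f(x, t)
    g2 = 2.0  # squared diffusion coefficient g(t)^2
    mean = x - h * (drift - g2 * score(x, t))
    var = g2 * h
    return mean, var


def reverse_em_step(x: float, t: float, h: float, score: Score, z: float) -> float:
    """Draw x_{t-h} given a standard-normal sample z supplied by the caller."""
    mean, var = reverse_em_transition(x, t, h, score)
    return mean + math.sqrt(var) * z
```

If p_t is standard normal, its score is -x. Then reverse_em_transition(1.0, 0.5, 0.1, lambda x, t: -x) gives mean 0.9 and variance 0.2, and with z = 0 the step lands exactly on the mean.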