Denoiser
- class diffuse.denoisers.denoiser.Denoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, x0_shape: Tuple[int, ...])
Bases: BaseDenoiser
Denoiser for generating samples using reverse diffusion.
- Parameters:
integrator – The integrator to use for solving the reverse SDE
model – The diffusion model (SDE) defining the forward process
predictor – The predictor for computing the score/denoised estimate
x0_shape – Shape of the data samples (excluding batch dimension)
- generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], n_steps: int, n_particles: int, keep_history: bool = False, data_sharding: Any | None = None) → Tuple[DenoiserState, Array | None]
Generate denoised samples \(x_0\).
- Parameters:
rng_key – Random key for initialization
n_steps – Number of denoising steps to perform
n_particles – Number of samples to generate (batch size)
keep_history – If True, return the full trajectory of samples
data_sharding – Optional JAX sharding specification for distributed computation
- Returns:
Tuple of (final_state, history), where history is None if keep_history=False
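The control flow described above (initialize from noise, apply a step `n_steps` times, optionally record the trajectory) can be sketched in plain Python. This is only an illustration, not the library's implementation: `ToyState`, `toy_generate`, and `shrink_step` are hypothetical names, the state fields are assumed, the samples are scalars, and whether the real history includes the initial state is also an assumption.

```python
import random
from typing import Callable, List, NamedTuple, Optional, Tuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for DenoiserState (fields are assumed)."""
    position: List[float]  # one scalar sample per particle
    step: int              # number of denoising steps remaining


def toy_generate(
    seed: int,
    n_steps: int,
    n_particles: int,
    step_fn: Callable[[ToyState, random.Random], ToyState],
    keep_history: bool = False,
) -> Tuple[ToyState, Optional[List[List[float]]]]:
    rng = random.Random(seed)
    # Start every particle from standard Gaussian noise.
    state = ToyState([rng.gauss(0.0, 1.0) for _ in range(n_particles)], n_steps)
    # Record the trajectory only when requested; otherwise history stays None.
    history = [state.position] if keep_history else None
    for _ in range(n_steps):
        state = step_fn(state, rng)
        if history is not None:
            history.append(state.position)
    return state, history


def shrink_step(state: ToyState, rng: random.Random) -> ToyState:
    # Placeholder update: halve each sample and count down one step.
    return ToyState([0.5 * x for x in state.position], state.step - 1)


final_state, history = toy_generate(
    seed=0, n_steps=4, n_particles=3, step_fn=shrink_step, keep_history=True
)
```

With `keep_history=False` the second element of the returned tuple is `None`, mirroring the documented return value.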
- init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) → DenoiserState
Initialize the denoiser state.
- step(state: DenoiserState) → DenoiserState
Perform one denoising step.
Sample \(x_{t-1} \sim p(x_{t-1} \mid x_t)\).
- Parameters:
state – Current denoiser state
- Returns:
Updated denoiser state at the previous time step
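To make the sampling statement \(x_{t-1} \sim p(x_{t-1} \mid x_t)\) concrete, here is a minimal scalar sketch of one common choice of reverse update: DDPM-style ancestral sampling written in terms of a score function. It is an assumption for illustration only. In the library, the update is determined by the configured `integrator` and `predictor`, and may differ from this rule. The names `ancestral_step`, `betas`, and `score_fn` are hypothetical.

```python
import math
import random
from typing import Callable, Sequence


def ancestral_step(
    x_t: float,
    t: int,
    betas: Sequence[float],
    score_fn: Callable[[float, int], float],
    z: float,
) -> float:
    """One reverse step of a discrete variance-preserving diffusion.

    Draws x_{t-1} from a Gaussian with mean
    (x_t + beta_t * score(x_t, t)) / sqrt(1 - beta_t) and variance beta_t,
    where z is a standard normal draw supplied by the caller.
    """
    beta = betas[t]
    alpha = 1.0 - beta
    mean = (x_t + beta * score_fn(x_t, t)) / math.sqrt(alpha)
    return mean + math.sqrt(beta) * z


# Usage: if the marginal at every step is a standard normal,
# its score is simply -x.
rng = random.Random(0)
betas = [0.19, 0.19]
x_prev = ancestral_step(1.0, 1, betas, lambda x, t: -x, rng.gauss(0.0, 1.0))
```

Passing the noise draw `z` explicitly keeps the step a pure function, which is in the same spirit as threading an explicit random key through the denoiser state.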