Quick Start Guide
This guide introduces the core components of Diffuse and shows how to build diffusion pipelines.
Core Components
Diffuse follows a modular design with four main components that can be mixed and matched:
SDE (Stochastic Differential Equation) - Defines the forward and reverse diffusion processes
Timer - Controls time scheduling during sampling
Integrator - Numerically solves the reverse SDE
Denoiser - Orchestrates generation and handles conditional sampling
SDE: Forward and Reverse Processes
The SDE defines how noise is added during training and removed during sampling. Diffusion models are described by a stochastic differential equation:
This corresponds to slowly adding noise such that the noised signal can be written as:
where \(s(t)\) and \(\sigma(t)\) are given by:
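For a variance-preserving (VP) SDE driven by a noise schedule \(\beta(t)\), such as the LinearSchedule used below, these relations have a standard form. They are written out here as a reference sketch, not as Diffuse's exact definitions. The sketch assumes that \(\sigma(t)\) is the total noise standard deviation, so that \(s(t)^2 + \sigma(t)^2 = 1\). Diffuse may parameterize the same process differently, for example as \(x_t = s(t)\,(x_0 + \sigma(t)\,\varepsilon)\) with a correspondingly rescaled \(\sigma(t)\).

\[
\mathrm{d}x_t = -\tfrac{1}{2}\beta(t)\,x_t\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}W_t,
\qquad
x_t = s(t)\,x_0 + \sigma(t)\,\varepsilon,\quad \varepsilon \sim \mathcal{N}(0, I),
\]
\[
s(t) = \exp\!\Big(-\tfrac{1}{2}\int_0^t \beta(u)\,\mathrm{d}u\Big),
\qquad
\sigma(t) = \sqrt{1 - s(t)^2}.
\]

For a linear schedule \(\beta(t) = b_{\min} + (b_{\max} - b_{\min})\,t/T\), the integral has the closed form

\[
\int_0^t \beta(u)\,\mathrm{d}u = b_{\min}\,t + \frac{(b_{\max} - b_{\min})\,t^2}{2T}.
\]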
import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import LinearSchedule, SDE
# Create noise schedule
beta = LinearSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
sde = SDE(beta=beta)
# The SDE provides coefficients for the diffusion process
t = 0.5
coeffs = sde.coefficients(t)
print(f"At t={t}: drift={coeffs.drift:.3f}, diffusion={coeffs.diffusion:.3f}")
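The next block is a standalone JAX sketch that does not call Diffuse. It shows concretely what a linear schedule implies by evaluating \(s(t)\) and \(\sigma(t)\) in closed form and noising a batch of data. It assumes a linear \(\beta(t)\) on [0, T] and the VP convention \(\sigma(t)^2 = 1 - s(t)^2\). The helpers signal_noise_scales and add_noise are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

B_MIN, B_MAX, T = 0.02, 7.0, 1.0  # same values as the LinearSchedule above

def signal_noise_scales(t):
    # Integral of beta(u) = B_MIN + (B_MAX - B_MIN) * u / T from 0 to t
    integral = B_MIN * t + 0.5 * (B_MAX - B_MIN) * t**2 / T
    s = jnp.exp(-0.5 * integral)
    sigma = jnp.sqrt(1.0 - s**2)  # assumed VP convention: s^2 + sigma^2 = 1
    return s, sigma

def add_noise(key, x0, t):
    # Sample x_t = s(t) * x0 + sigma(t) * eps with eps ~ N(0, I)
    s, sigma = signal_noise_scales(t)
    eps = jax.random.normal(key, x0.shape)
    return s * x0 + sigma * eps, eps

for t in (0.0, 0.5, 1.0):
    s, sigma = signal_noise_scales(t)
    print(f"t={t}: s={float(s):.3f}, sigma={float(sigma):.3f}")
```

At \(t=0\) the signal is untouched (\(s=1\), \(\sigma=0\)). With b_max=7.0, \(s(T)=e^{-1.755}\approx 0.17\) at \(t=T\), so a small fraction of the signal survives at the end of the forward process.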
Different schedules are available:
from diffuse.diffusion.sde import CosineSchedule
# Alternative: cosine schedule (often better for images)
cosine_beta = CosineSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
Timer: Scheduling Integration Steps
The timer maps discrete integration steps to continuous time \(t \in [0, T]\). It defines the time discretization used during the numerical integration of the reverse SDE:

Figure: Time discretization strategies
from diffuse.timer import VpTimer
# Create timer with 50 integration steps
timer = VpTimer(eps=1e-5, tf=1.0, n_steps=50)
# Timer maps step index to time
step = 25
time = timer(step)
print(f"Step {step} corresponds to time {time:.3f}")
Integrator: Numerical Solvers
Integrators solve the reverse SDE numerically to perform denoising. The reverse SDE is given by:
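Write the forward SDE generically as \(\mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + g(t)\,\mathrm{d}W_t\). In the VP case, \(f(x_t,t) = -\tfrac{1}{2}\beta(t)\,x_t\) and \(g(t) = \sqrt{\beta(t)}\). The standard time reversal is then integrated from \(t=T\) down to \(t=0\), with \(\bar W_t\) a reverse-time Brownian motion:

\[
\mathrm{d}x_t = \big[f(x_t,t) - g(t)^2\,\nabla_x \log p_t(x_t)\big]\,\mathrm{d}t + g(t)\,\mathrm{d}\bar W_t.
\]

The deterministic probability-flow ODE has the same marginals \(p_t\):

\[
\frac{\mathrm{d}x_t}{\mathrm{d}t} = f(x_t,t) - \tfrac{1}{2}\,g(t)^2\,\nabla_x \log p_t(x_t).
\]

The integrators in diffuse.integrator.deterministic presumably target this ODE form, but that is an assumption rather than something stated here.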
Different integrators offer trade-offs between speed and quality:
from diffuse.integrator.deterministic import EulerIntegrator, DDIMIntegrator, DPMpp2sIntegrator
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
# Fast but lower quality
euler = EulerIntegrator(model=sde, timer=timer)
# Good balance of speed and quality
ddim = DDIMIntegrator(model=sde, timer=timer)
# High quality, slower
dpm = DPMpp2sIntegrator(model=sde, timer=timer)
# Stochastic (adds randomness)
euler_maruyama = EulerMaruyamaIntegrator(model=sde, timer=timer)
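The block below is a standalone sketch, independent of Diffuse's integrator classes, that makes the deterministic/stochastic distinction concrete. It applies a plain Euler step to the probability-flow ODE and an Euler-Maruyama step to the reverse SDE. To avoid needing a network, it uses a hypothetical 1-D toy problem, data \(x_0 \sim \mathcal{N}(0, 3^2)\). Under the assumed VP convention \(\sigma(t)^2 = 1 - s(t)^2\), the score of this toy problem is known exactly.

```python
import jax
import jax.numpy as jnp

B_MIN, B_MAX = 0.02, 7.0  # linear beta schedule on [0, 1]
C = 3.0                   # hypothetical toy data: x_0 ~ N(0, C^2)

def beta(t):
    return B_MIN + (B_MAX - B_MIN) * t

def s(t):
    # s(t) = exp(-1/2 * integral of beta from 0 to t)
    return jnp.exp(-0.5 * (B_MIN * t + 0.5 * (B_MAX - B_MIN) * t**2))

def var_t(t):
    # Variance of x_t = s x_0 + sigma eps when x_0 ~ N(0, C^2) and sigma^2 = 1 - s^2
    return s(t) ** 2 * C**2 + 1.0 - s(t) ** 2

def score(x, t):
    # Exact score of N(0, var_t(t)); a trained network would approximate this
    return -x / var_t(t)

def euler_ode(x_T, T=1.0, eps=1e-3, n_steps=1000):
    # Deterministic: integrate dx/dt = f - 1/2 g^2 score backward from T to eps
    h = (T - eps) / n_steps
    def step(i, x):
        t = T - i * h
        drift = -0.5 * beta(t) * x - 0.5 * beta(t) * score(x, t)
        return x - h * drift
    return jax.lax.fori_loop(0, n_steps, step, x_T)

def euler_maruyama(key, x_T, T=1.0, eps=1e-3, n_steps=1000):
    # Stochastic: integrate the reverse SDE, injecting fresh Gaussian noise each step
    h = (T - eps) / n_steps
    def step(i, carry):
        x, key = carry
        key, sub = jax.random.split(key)
        t = T - i * h
        drift = -0.5 * beta(t) * x - beta(t) * score(x, t)
        z = jax.random.normal(sub, x.shape)
        return x - h * drift + jnp.sqrt(beta(t) * h) * z, key
    x, _ = jax.lax.fori_loop(0, n_steps, step, (x_T, key))
    return x

key_T, key_sde = jax.random.split(jax.random.PRNGKey(0))
x_T = jnp.sqrt(var_t(1.0)) * jax.random.normal(key_T, (10_000,))
x0_ode = euler_ode(x_T)
x0_sde = euler_maruyama(key_sde, x_T)
print(jnp.std(x0_ode), jnp.std(x0_sde))  # both should be close to C = 3
```

The ODE maps each starting point deterministically, so rerunning it on the same x_T reproduces the same samples. The SDE draws fresh noise at every step, which is the extra randomness that the stochastic integrators add.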
Predictor and Network#
The score function \(\nabla_x\log p_t(x)\) is the gradient of the log-density of the noisy data distribution at time \(t\). It is the key quantity that drives the reverse diffusion process. In practice, a neural network is trained to predict one of several equivalent quantities: the score, the noise \(\varepsilon\), the velocity, or \(x_0\). The Predictor wraps the network with the chosen prediction_type and converts between these targets internally:
from diffuse.predictor import Predictor
# Wrap a learned network network_fn(x, t) (see the Flax nnx example below).
# prediction_type is one of: "score", "noise", "velocity", "x0"
predictor = Predictor(model=sde, network=network_fn, prediction_type="score")
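The next block illustrates why these targets are interchangeable. It converts each prediction type into a noise estimate and checks the round trip on random data. The block is a hypothetical helper, not Diffuse's API. It assumes \(x_t = s\,x_0 + \sigma\,\varepsilon\) with \(s^2 + \sigma^2 = 1\), and it takes velocity to mean \(v = s\,\varepsilon - \sigma\,x_0\), the "v-prediction" convention. Diffuse's Predictor may define velocity differently (for example as a flow-matching velocity).

```python
import jax
import jax.numpy as jnp

def to_noise(pred, x_t, s, sigma, prediction_type):
    # Convert a network output into a noise estimate eps_hat (assumed VP convention)
    if prediction_type == "noise":
        return pred
    if prediction_type == "score":     # score = -eps / sigma
        return -sigma * pred
    if prediction_type == "x0":        # eps = (x_t - s * x0) / sigma
        return (x_t - s * pred) / sigma
    if prediction_type == "velocity":  # v = s * eps - sigma * x0  =>  eps = sigma * x_t + s * v
        return sigma * x_t + s * pred
    raise ValueError(f"unknown prediction_type: {prediction_type!r}")

def to_x0(eps, x_t, s, sigma):
    # Invert x_t = s * x0 + sigma * eps
    return (x_t - sigma * eps) / s

def to_score(eps, sigma):
    return -eps / sigma

# Round trip: every target maps back to the same eps
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(k1, (4, 2))
eps = jax.random.normal(k2, (4, 2))
s, sigma = 0.6, 0.8  # satisfies s^2 + sigma^2 = 1
x_t = s * x0 + sigma * eps
targets = {
    "noise": eps,
    "score": -eps / sigma,
    "x0": x0,
    "velocity": s * eps - sigma * x0,
}
for name, target in targets.items():
    print(name, bool(jnp.allclose(to_noise(target, x_t, s, sigma, name), eps, atol=1e-5)))
```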
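The network itself can be loaded from a Flax nnx module. Define network_fn before constructing the Predictor:
from flax import nnx
# model: your trained nnx.Module, assumed to return an object with an .output field
graphdef, state = nnx.split(model)
def network_fn(x, t):
    # Rebuild the module from its graph definition and state, then evaluate it
    model = nnx.merge(graphdef, state)
    return model(x, t).output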
Unconditional Generation
To generate new samples \(x_0\) from pure noise \(x_T\), we integrate the reverse SDE from \(t=T\) to \(t=0\). Combining the components above gives an unconditional sampling pipeline:
from diffuse.denoisers.denoiser import Denoiser
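data_dim = 2  # example value: dimension of one data sample (2-D, matching the conditional example below)
# Create denoiser pipeline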
denoiser = Denoiser(
integrator=ddim,
model=sde,
predictor=predictor,
x0_shape=(data_dim,), # Shape of data samples
)
# Generate samples
key = jax.random.PRNGKey(42)
n_particles = 100
n_steps = 50
final_state, history = denoiser.generate(
key, n_steps, n_particles, keep_history=True
)
samples = final_state.integrator_state.position
print(f"Generated {samples.shape[0]} samples of dimension {samples.shape[1]}")
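The next block is a standalone sketch of a sampling loop like the one behind generate. It starts a batch of particles at \(t=T\), applies a deterministic DDIM update at each step, and stacks the intermediate states, which is roughly what keep_history=True returns. It does not use Diffuse. The score comes from a hypothetical 2-D Gaussian toy problem, \(x_0 \sim \mathcal{N}(0, 3^2 I)\), under the assumed VP convention \(\sigma^2 = 1 - s^2\). The generate function here is an illustrative name, not Denoiser.generate.

```python
import jax
import jax.numpy as jnp

B_MIN, B_MAX, C = 0.02, 7.0, 3.0  # linear schedule; hypothetical data x_0 ~ N(0, C^2 I)

def s_sigma(t):
    s = jnp.exp(-0.5 * (B_MIN * t + 0.5 * (B_MAX - B_MIN) * t**2))
    return s, jnp.sqrt(1.0 - s**2)  # assumed VP convention

def score(x, t):
    # Exact score of the toy marginal N(0, (s^2 C^2 + sigma^2) I), standing in for a network
    s, sigma = s_sigma(t)
    return -x / (s**2 * C**2 + sigma**2)

def ddim_step(x, t, t_next):
    # Deterministic DDIM update: estimate eps and x0 at t, then re-noise to t_next
    s, sigma = s_sigma(t)
    s_next, sigma_next = s_sigma(t_next)
    eps_hat = -sigma * score(x, t)
    x0_hat = (x - sigma * eps_hat) / s
    return s_next * x0_hat + sigma_next * eps_hat

def generate(key, n_steps, n_particles, x0_shape, T=1.0, eps=1e-3):
    times = jnp.linspace(T, eps, n_steps + 1)
    s_T, sigma_T = s_sigma(T)
    # Start from the toy marginal at T; it approaches N(0, I) as s(T) -> 0
    std_T = jnp.sqrt(s_T**2 * C**2 + sigma_T**2)
    x_T = std_T * jax.random.normal(key, (n_particles, *x0_shape))

    def body(x, ts):
        x_next = ddim_step(x, ts[0], ts[1])
        return x_next, x_next  # carry the state forward and record it in the history

    pairs = jnp.stack([times[:-1], times[1:]], axis=1)  # (n_steps, 2) of (t, t_next)
    final, history = jax.lax.scan(body, x_T, pairs)
    return final, history

final, history = generate(jax.random.PRNGKey(0), n_steps=50, n_particles=10_000, x0_shape=(2,))
print(final.shape, history.shape)  # (10000, 2) (50, 10000, 2)
```

jax.lax.scan keeps the whole loop inside one compiled program and stacks the per-step outputs along a leading time axis. This is one straightforward way to implement an optional history.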
Conditional Generation
For conditional sampling \(x_0 \sim p(x_0|y)\) given measurements \(y\), use conditional denoisers that incorporate the measurement information during the reverse process:
from diffuse.denoisers.cond import (
DPSDenoiser, FPSDenoiser, TMPDenoiser,
DAPSDenoiser, PiGDMDenoiser, PnPDMDenoiser,
DPSGSGDenoiser, EnKGDenoiser, DiffPIRDenoiser,
)
from diffuse.base_forward_model import MeasurementState
from diffuse.examples.gaussian_mixtures.forward_models.matrix_product import MatrixProduct
# Create measurement
A = jnp.array([[1.0, 0.0]]) # Observe first coordinate
y_observed = jnp.array([1.5])
forward_model = MatrixProduct(A, std=0.1)
measurement_state = MeasurementState(y=y_observed, mask_history=A)
# Create conditional denoiser
fps_denoiser = FPSDenoiser(
integrator=ddim,
model=sde,
predictor=predictor,
forward_model=forward_model,
x0_shape=(data_dim,),
)
# Generate conditional samples
cond_state, cond_history = fps_denoiser.generate(
key, measurement_state, n_steps, n_particles, keep_history=True
)
conditional_samples = cond_state.integrator_state.position
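The conditional denoisers differ in how they fold the measurement into the reverse process. The block below is a standalone sketch of the idea behind one of them, DPS; it is not Diffuse's DPSDenoiser. At each reverse step it adds the gradient of \(\log \mathcal{N}(y;\,A\,\hat x_0(x_t),\,\sigma_y^2 I)\) to the prior score, where \(\hat x_0\) is Tweedie's estimate. This is a simplified variant: the published DPS method scales this gradient with a normalized step size instead. A hypothetical 2-D Gaussian prior with a known score replaces a trained network. The measurement setup matches the example above: the first coordinate is observed as y = 1.5 with noise std 0.1.

```python
import jax
import jax.numpy as jnp

B_MIN, B_MAX, C = 0.02, 7.0, 3.0  # linear schedule; hypothetical prior x_0 ~ N(0, C^2 I)

def beta(t):
    return B_MIN + (B_MAX - B_MIN) * t

def s_sigma(t):
    s = jnp.exp(-0.5 * (B_MIN * t + 0.5 * (B_MAX - B_MIN) * t**2))
    return s, jnp.sqrt(1.0 - s**2)  # assumed VP convention

def prior_score(x, t):
    # Exact score of the toy prior marginal; a trained network would approximate it
    s, sigma = s_sigma(t)
    return -x / (s**2 * C**2 + sigma**2)

def guidance(x, t, y, A, std_y):
    # Gradient w.r.t. x_t of log N(y; A x0_hat(x_t), std_y^2 I). Particles are
    # independent, so differentiating the summed log-likelihood yields each
    # particle's own gradient.
    def log_lik(x):
        s, sigma = s_sigma(t)
        x0_hat = (x + sigma**2 * prior_score(x, t)) / s  # Tweedie's estimate
        resid = y - x0_hat @ A.T
        return -0.5 * jnp.sum(resid**2) / std_y**2
    return jax.grad(log_lik)(x)

def dps_sample(key, y, A, std_y, n_particles, dim, T=1.0, eps=1e-3, n_steps=1000):
    h = (T - eps) / n_steps
    key, sub = jax.random.split(key)
    s_T, sigma_T = s_sigma(T)
    x = jnp.sqrt(s_T**2 * C**2 + sigma_T**2) * jax.random.normal(sub, (n_particles, dim))

    def step(i, carry):
        x, key = carry
        key, sub = jax.random.split(key)
        t = T - i * h
        total_score = prior_score(x, t) + guidance(x, t, y, A, std_y)
        drift = -0.5 * beta(t) * x - beta(t) * total_score
        z = jax.random.normal(sub, x.shape)
        return x - h * drift + jnp.sqrt(beta(t) * h) * z, key

    x, _ = jax.lax.fori_loop(0, n_steps, step, (x, key))
    return x

A = jnp.array([[1.0, 0.0]])  # observe the first coordinate
y = jnp.array([1.5])
samples = dps_sample(jax.random.PRNGKey(0), y, A, std_y=0.1, n_particles=5_000, dim=2)
print(samples.mean(axis=0), samples.std(axis=0))
```

The observed coordinate concentrates near y. The unobserved coordinate receives no guidance gradient, because the corresponding column of A is zero, so it keeps the prior's spread.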
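Complete Pipeline Example
Here’s a minimal end-to-end example. It assumes that network_fn (your trained network, as in the Flax nnx example above) and data_dim (the dimension of one data sample) are already defined: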
import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import LinearSchedule, SDE
from diffuse.timer import VpTimer
from diffuse.integrator.deterministic import DDIMIntegrator
from diffuse.predictor import Predictor
from diffuse.denoisers.denoiser import Denoiser
# 1. Define components
beta = LinearSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
sde = SDE(beta=beta)
timer = VpTimer(eps=1e-5, tf=1.0, n_steps=50)
integrator = DDIMIntegrator(model=sde, timer=timer)
predictor = Predictor(model=sde, network=network_fn, prediction_type="score")
# 2. Create pipeline
denoiser = Denoiser(
integrator=integrator,
model=sde,
predictor=predictor,
x0_shape=(data_dim,), # Shape of data samples
)
# 3. Generate samples
key = jax.random.PRNGKey(0)
final_state, _ = denoiser.generate(key, n_steps=50, n_particles=100)
samples = final_state.integrator_state.position
print(f"✓ Generated {samples.shape} samples")
Pytest#
This package comes with an extensive test suite that can be run with pytest. To visualize the results, add --plot and use pytest -k to select the desired denoiser and integrator combinations:
pytest --plot -k "DDIMIntegrator and DPSDenoiser"