Quick Start Guide#

This guide introduces the core components of Diffuse and shows how to build diffusion pipelines.

Core Components#

Diffuse follows a modular design with four main components that can be mixed and matched:

  1. SDE (Stochastic Differential Equation) - Defines the forward and reverse diffusion processes

  2. Timer - Controls time scheduling during sampling

  3. Integrator - Numerically solves the reverse SDE

  4. Denoiser - Orchestrates generation and handles conditional sampling

SDE: Forward and Reverse Processes#

The SDE defines how noise is added during training and removed during sampling. Diffusion models are described by a stochastic differential equation:

\[dx(t) = f(x, t)dt + g(t)dW(t)\]

Under this SDE, noise is added gradually, so the noised signal at time \(t\) can be written as:

\[x(t) = s(t)x(0) + \sigma(t)\varepsilon, \quad \varepsilon\sim\mathcal{N}(0,I)\]

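where, assuming the drift is linear in \(x\), i.e. \(f(x, t) = f(t)\,x\), the scale \(s(t)\) and the noise level \(\sigma(t)\) are given by:

\[s(t) = \exp\left(\int_0^t f(\xi) d\xi\right), \quad \sigma(t) = s(t)\left(\int_0^t \frac{g(\xi)^2}{s(\xi)^2} d\xi \right)^{1/2}\]

In Diffuse, an SDE is built from a noise schedule:
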
import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import LinearSchedule, SDE

# Create noise schedule
beta = LinearSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
sde = SDE(beta=beta)

# The SDE provides coefficients for the diffusion process
t = 0.5
coeffs = sde.coefficients(t)
print(f"At t={t}: drift={coeffs.drift:.3f}, diffusion={coeffs.diffusion:.3f}")
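
As a sanity check on the formulas for \(s(t)\) and \(\sigma(t)\), the sketch below evaluates them by numerical quadrature and compares the result with a closed form. It is plain Python and does not call Diffuse. The variance-preserving choice \(f(t) = -\beta(t)/2\), \(g(t) = \sqrt{\beta(t)}\), the linear form of \(\beta(t)\), and all helper names are assumptions made for illustration; whether the SDE class uses exactly this parameterization is not confirmed here.

```python
import math

B_MIN, B_MAX = 0.02, 7.0  # same values as the LinearSchedule example


def beta(t):
    """Linear schedule beta(t) = b_min + (b_max - b_min) * t (assumed form)."""
    return B_MIN + (B_MAX - B_MIN) * t


def f(t):
    """Drift coefficient of the assumed VP SDE, with f(x, t) = f(t) * x."""
    return -0.5 * beta(t)


def g(t):
    """Diffusion coefficient of the assumed VP SDE."""
    return math.sqrt(beta(t))


def trapezoid(fn, a, b, n=400):
    """Composite trapezoidal rule for the integral of fn over [a, b]."""
    h = (b - a) / n
    total = 0.5 * (fn(a) + fn(b))
    for i in range(1, n):
        total += fn(a + i * h)
    return total * h


def s(t):
    """s(t) = exp(int_0^t f(xi) d xi)."""
    return math.exp(trapezoid(f, 0.0, t))


def sigma(t):
    """sigma(t) = s(t) * sqrt(int_0^t g(xi)^2 / s(xi)^2 d xi)."""
    inner = trapezoid(lambda xi: g(xi) ** 2 / s(xi) ** 2, 0.0, t)
    return s(t) * math.sqrt(inner)


def closed_form(t):
    """VP closed form: s = exp(-B/2), sigma = sqrt(1 - exp(-B))."""
    big_b = B_MIN * t + 0.5 * (B_MAX - B_MIN) * t ** 2  # int_0^t beta
    return math.exp(-0.5 * big_b), math.sqrt(1.0 - math.exp(-big_b))


t = 0.5
s_num, sigma_num = s(t), sigma(t)
s_ref, sigma_ref = closed_form(t)
print(f"quadrature:  s={s_num:.4f}, sigma={sigma_num:.4f}")
print(f"closed form: s={s_ref:.4f}, sigma={sigma_ref:.4f}")
```

Under these assumptions \(s(t)^2 + \sigma(t)^2 = 1\), which is why this family is called variance preserving.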

Different schedules are available:

from diffuse.diffusion.sde import CosineSchedule

# Alternative: cosine schedule (often better for images)
cosine_beta = CosineSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)

Timer: Scheduling Integration Steps#

The timer maps discrete integration steps to continuous time \(t \in [0, T]\). It defines the time discretization used during the numerical integration of the reverse SDE:

[Figure: Time discretization strategies]

from diffuse.timer import VpTimer

# Create timer with 50 integration steps
timer = VpTimer(eps=1e-5, tf=1.0, n_steps=50)

# Timer maps step index to time
step = 25
time = timer(step)
print(f"Step {step} corresponds to time {time:.3f}")
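
To make the idea of a time discretization concrete, here is a minimal sketch of two possible step-to-time maps: an evenly spaced grid and a power-law grid that clusters steps near the small-time end. Both functions are hypothetical and written in plain Python. They are not the formula used by VpTimer, and the direction (step 0 at `tf`, last step at `eps`) is an assumption.

```python
def uniform_time(step, n_steps=50, eps=1e-5, tf=1.0):
    """Evenly spaced times from tf (step 0) down to eps (step n_steps)."""
    return tf + (eps - tf) * step / n_steps


def power_time(step, n_steps=50, eps=1e-5, tf=1.0, rho=7.0):
    """Power-law spacing: uniform in t**(1/rho), so steps cluster near eps."""
    a, b = tf ** (1.0 / rho), eps ** (1.0 / rho)
    return (a + (b - a) * step / n_steps) ** rho


uniform_grid = [uniform_time(k) for k in range(51)]
power_grid = [power_time(k) for k in range(51)]
print(f"step 25: uniform t={uniform_grid[25]:.3f}, power-law t={power_grid[25]:.3f}")
```

Both grids share the same endpoints. The power-law grid spends more of its steps at small \(t\), where the sample is nearly noise-free and fine details are resolved.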

Integrator: Numerical Solvers#

Integrators solve the reverse SDE numerically to perform denoising. The reverse SDE is given by:

\[dx = [f(x,t) - g(t)^2\nabla_x\log p_t(x)]dt + g(t)d\bar{W}(t)\]

Different integrators offer trade-offs between speed and quality:

from diffuse.integrator.deterministic import EulerIntegrator, DDIMIntegrator, DPMpp2sIntegrator
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator

# Fast but lower quality
euler = EulerIntegrator(model=sde, timer=timer)

# Good balance of speed and quality
ddim = DDIMIntegrator(model=sde, timer=timer)

# High quality, slower
dpm = DPMpp2sIntegrator(model=sde, timer=timer)

# Stochastic (adds randomness)
euler_maruyama = EulerMaruyamaIntegrator(model=sde, timer=timer)
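
The difference between deterministic and stochastic solvers can be illustrated on a toy problem where the score is known exactly. The sketch below is plain Python and does not use Diffuse or its integrator classes. It assumes a variance-preserving SDE with \(f(x,t) = -\beta(t)x/2\) and \(g(t)^2 = \beta(t)\), and one-dimensional Gaussian data. It integrates either the reverse SDE with Euler-Maruyama or the probability-flow ODE \(dx = [f(x,t) - \tfrac{1}{2}g(t)^2\nabla_x\log p_t(x)]dt\) with plain Euler steps. All names are hypothetical.

```python
import math
import random
import statistics

B_MIN, B_MAX = 0.02, 7.0   # linear schedule values from the examples above
MEAN, STD = 2.0, 0.5       # toy data distribution N(MEAN, STD^2)


def beta(t):
    return B_MIN + (B_MAX - B_MIN) * t


def marginal(t):
    """Mean and variance of p_t when the data are N(MEAN, STD^2)."""
    big_b = B_MIN * t + 0.5 * (B_MAX - B_MIN) * t ** 2  # int_0^t beta
    s2 = math.exp(-big_b)                                # s(t)^2
    return math.sqrt(s2) * MEAN, s2 * STD ** 2 + (1.0 - s2)


def sample(stochastic, n_particles=2000, n_steps=200, eps=1e-3, seed=0):
    """Integrate from t=1 down to t=eps with a fixed step size."""
    rng = random.Random(seed)
    h = (1.0 - eps) / n_steps
    mean_T, var_T = marginal(1.0)
    # Start from the exact terminal marginal so the remaining error comes
    # from discretization only (in practice one starts from N(0, I)).
    xs = [rng.gauss(mean_T, math.sqrt(var_T)) for _ in range(n_particles)]
    for k in range(n_steps):
        t = 1.0 - k * h
        b = beta(t)
        mean_t, var_t = marginal(t)
        for i, x in enumerate(xs):
            score = -(x - mean_t) / var_t  # exact Gaussian score
            if stochastic:
                # Euler-Maruyama on the reverse SDE: drift f*x - g^2 * score
                drift = -0.5 * b * x - b * score
                noise = math.sqrt(b * h) * rng.gauss(0.0, 1.0)
            else:
                # Euler on the probability-flow ODE: drift f*x - g^2/2 * score
                drift = -0.5 * b * x - 0.5 * b * score
                noise = 0.0
            xs[i] = x - h * drift + noise  # one step backward in time
    return xs


ode_samples = sample(stochastic=False)
sde_samples = sample(stochastic=True)
for name, xs in [("Euler (ODE)", ode_samples), ("Euler-Maruyama (SDE)", sde_samples)]:
    print(f"{name}: mean={statistics.mean(xs):.3f}, std={statistics.stdev(xs):.3f}")
```

Both runs should recover samples close to the toy data distribution. The deterministic run maps each initial point to a fixed output, while the stochastic run injects fresh noise at every step.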

Predictor and Network#

The score function \(\nabla_x\log p_t(x)\) is the gradient of the log-density of the noisy data distribution at time \(t\), and it is the quantity that drives the reverse diffusion process. In practice, a neural network is trained to predict one of several equivalent quantities: the score, the noise \(\varepsilon\), the velocity, or \(x_0\). The Predictor wraps the network with the chosen prediction_type and converts between targets internally:

from diffuse.predictor import Predictor

# Wrap a learned network. prediction_type is one of:
# "score", "noise", "velocity", "x0"
predictor = Predictor(model=sde, network=network_fn, prediction_type="score")
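
The equivalence between prediction targets follows from \(x(t) = s(t)x(0) + \sigma(t)\varepsilon\): knowing \(x(t)\), \(s(t)\), \(\sigma(t)\) and any one target determines the others. The sketch below writes these conversions out for scalars in plain Python. The function names are hypothetical and this is not the Predictor's internal code. The velocity convention \(v = s\,\varepsilon - \sigma\,x_0\) is an assumption, and its inverse shown here holds only when \(s^2 + \sigma^2 = 1\).

```python
def noise_to_score(eps_hat, sigma):
    """Score of the Gaussian perturbation kernel: -eps / sigma."""
    return -eps_hat / sigma


def score_to_noise(score, sigma):
    return -sigma * score


def noise_to_x0(x_t, eps_hat, s, sigma):
    """Invert x_t = s * x0 + sigma * eps for x0."""
    return (x_t - sigma * eps_hat) / s


def x0_to_noise(x_t, x0_hat, s, sigma):
    """Invert x_t = s * x0 + sigma * eps for eps."""
    return (x_t - s * x0_hat) / sigma


def noise_to_velocity(x_t, eps_hat, s, sigma):
    """Assumed convention: v = s * eps - sigma * x0."""
    return s * eps_hat - sigma * noise_to_x0(x_t, eps_hat, s, sigma)


def velocity_to_noise(x_t, v_hat, s, sigma):
    """Valid when s^2 + sigma^2 = 1: eps = s * v + sigma * x_t."""
    return s * v_hat + sigma * x_t


# Round trip on one noised scalar
s_t, sigma_t = 0.8, 0.6            # satisfies s^2 + sigma^2 = 1
x0, eps = 1.0, -0.5
x_t = s_t * x0 + sigma_t * eps     # forward noising

score = noise_to_score(eps, sigma_t)
v = noise_to_velocity(x_t, eps, s_t, sigma_t)
print(f"x_t={x_t:.3f}, score={score:.3f}, velocity={v:.3f}")
```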

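The network function network_fn passed to the Predictor can be built from a Flax nnx module:

from flax import nnx

# `model` is assumed to be an already-trained nnx.Module
graphdef, state = nnx.split(model)
def network_fn(x, t):
    # Rebuild the module from its static graph and state, then evaluate it
    model = nnx.merge(graphdef, state)
    return model(x, t).output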

Unconditional Generation#

To generate new samples \(x_0\) from pure noise \(x_T\), we integrate the reverse SDE from \(t=T\) to \(t=0\). The Denoiser combines the components above into a single sampling pipeline:

from diffuse.denoisers.denoiser import Denoiser

# Create denoiser pipeline; data_dim is the dimension of one data sample
denoiser = Denoiser(
    integrator=ddim,
    model=sde,
    predictor=predictor,
    x0_shape=(data_dim,),  # Shape of data samples
)

# Generate samples
key = jax.random.PRNGKey(42)
n_particles = 100
n_steps = 50

final_state, history = denoiser.generate(
    key, n_steps, n_particles, keep_history=True
)

samples = final_state.integrator_state.position
print(f"Generated {samples.shape[0]} samples of dimension {samples.shape[1]}")

Conditional Generation#

For conditional sampling \(x_0 \sim p(x_0|y)\) given measurements \(y\), use conditional denoisers that incorporate the measurement information during the reverse process:

from diffuse.denoisers.cond import (
    DPSDenoiser, FPSDenoiser, TMPDenoiser,
    DAPSDenoiser, PiGDMDenoiser, PnPDMDenoiser,
    DPSGSGDenoiser, EnKGDenoiser, DiffPIRDenoiser,
)
from diffuse.base_forward_model import MeasurementState
from diffuse.examples.gaussian_mixtures.forward_models.matrix_product import MatrixProduct

# Create measurement
A = jnp.array([[1.0, 0.0]])  # Observe first coordinate
y_observed = jnp.array([1.5])
forward_model = MatrixProduct(A, std=0.1)

measurement_state = MeasurementState(y=y_observed, mask_history=A)

# Create conditional denoiser
fps_denoiser = FPSDenoiser(
    integrator=ddim,
    model=sde,
    predictor=predictor,
    forward_model=forward_model,
    x0_shape=(data_dim,),
)

# Generate conditional samples
cond_state, cond_history = fps_denoiser.generate(
    key, measurement_state, n_steps, n_particles, keep_history=True
)

conditional_samples = cond_state.integrator_state.position
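
By Bayes' rule, the conditional score splits as \(\nabla_x\log p_t(x|y) = \nabla_x\log p_t(x) + \nabla_x\log p_t(y|x)\), so conditional samplers need some form of measurement-consistency term. As a minimal sketch for the linear measurement above, the plain-Python code below computes the likelihood gradient for \(y = Ax + \text{noise}\) and, as a reference value, the exact posterior of the observed coordinate under a Gaussian prior. The \(\mathcal{N}(0, 1)\) prior and the function names are assumptions for illustration. How each Diffuse denoiser approximates the likelihood term is not covered here.

```python
Y_OBS, NOISE_STD = 1.5, 0.1   # measurement and noise level from the example
A_ROW = [1.0, 0.0]            # observe the first coordinate only


def likelihood_score(x):
    """grad_x log p(y | x) = A^T (y - A x) / std^2 for y = A x + noise."""
    residual = Y_OBS - sum(a * xi for a, xi in zip(A_ROW, x))
    return [a * residual / NOISE_STD ** 2 for a in A_ROW]


def gaussian_posterior(prior_mean=0.0, prior_var=1.0):
    """Posterior of the observed coordinate under an assumed Gaussian prior."""
    post_var = 1.0 / (1.0 / prior_var + 1.0 / NOISE_STD ** 2)
    post_mean = post_var * (prior_mean / prior_var + Y_OBS / NOISE_STD ** 2)
    return post_mean, post_var


grad = likelihood_score([1.0, 3.0])
post_mean, post_var = gaussian_posterior()
print(f"likelihood gradient at (1, 3): {grad}")
print(f"posterior of x[0]: mean={post_mean:.4f}, var={post_var:.4f}")
```

The gradient is zero along the unobserved coordinate, so the measurement only pulls the first coordinate toward the observed value. A comparison of this kind can serve as a sanity check for conditional samples in cases where the prior is known.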

Complete Pipeline Example#

Here’s a minimal end-to-end example. It assumes that network_fn and data_dim are defined as in the previous sections:

import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import LinearSchedule, SDE
from diffuse.timer import VpTimer
from diffuse.integrator.deterministic import DDIMIntegrator
from diffuse.predictor import Predictor
from diffuse.denoisers.denoiser import Denoiser

# 1. Define components
beta = LinearSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
sde = SDE(beta=beta)
timer = VpTimer(eps=1e-5, tf=1.0, n_steps=50)
integrator = DDIMIntegrator(model=sde, timer=timer)
predictor = Predictor(model=sde, network=network_fn, prediction_type="score")

# 2. Create pipeline
denoiser = Denoiser(
    integrator=integrator,
    model=sde,
    predictor=predictor,
    x0_shape=(data_dim,),  # Shape of data samples
)

# 3. Generate samples
key = jax.random.PRNGKey(0)
final_state, _ = denoiser.generate(key, n_steps=50, n_particles=100)
samples = final_state.integrator_state.position

print(f"✓ Generated samples with shape {samples.shape}")

Pytest#

This package comes with an extensive test suite that can be run using pytest. To visualize the results, add --plot, and use pytest -k to select the desired Denoiser and Integrator combinations:

pytest --plot -k "DDIMIntegrator and DPSDenoiser"