Diffuse: JAX-based Diffusion Models
Diffuse is a research-oriented Python package for diffusion-based generative modeling built on JAX and Flax. It provides modular, swappable components for building and experimenting with diffusion models.
⚡ JAX-Native
Built from the ground up with JAX for automatic differentiation, JIT compilation, and GPU acceleration.
🔧 Modular Design
Mix and match components: SDE + Timer + Integrator + Denoiser = Complete pipeline.
🧪 Research-Ready
Experiment with different noise schedules, integrators, and conditioning methods.
🎯 Conditional Generation
Built-in support for DPS, FPS, TMP, DAPS, PiGDM, PnPDM, DPS-GSG, EnKG, and DiffPIR.
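As an illustration of what conditional generation involves, here is a minimal sketch of a DPS-style guidance term written in plain JAX. It does not use Diffuse's API: the function names (`tweedie_x0`, `dps_guidance`), the toy score, the masking operator, and the step size `zeta` are all hypothetical and chosen only to make the example self-contained. It assumes a variance-preserving convention where `alpha_t**2 + sigma_t**2 = 1`.

```python
import jax
import jax.numpy as jnp


def tweedie_x0(x_t, alpha_t, sigma_t, score_fn):
    """Estimate x0 from x_t with Tweedie's formula (variance-preserving convention)."""
    return (x_t + sigma_t**2 * score_fn(x_t)) / alpha_t


def dps_guidance(x_t, y, alpha_t, sigma_t, score_fn, forward_op):
    """Gradient w.r.t. x_t of the squared measurement residual ||y - A(x0_hat)||^2."""

    def residual_sq(x):
        x0_hat = tweedie_x0(x, alpha_t, sigma_t, score_fn)
        return jnp.sum((y - forward_op(x0_hat)) ** 2)

    return jax.grad(residual_sq)(x_t)


# Toy setup (assumption): data ~ N(0, I), so every noisy marginal is N(0, I)
# and the exact score is simply -x.
toy_score = lambda x: -x
mask = jnp.array([1.0, 0.0, 1.0, 0.0])
forward_op = lambda x: mask * x  # observe coordinates 0 and 2 only

alpha_t, sigma_t = 0.8, 0.6  # satisfies alpha_t**2 + sigma_t**2 == 1
x_t = jnp.array([0.5, -1.0, 2.0, 0.3])
y = jnp.array([1.0, 0.0, -1.0, 0.0])

grad = dps_guidance(x_t, y, alpha_t, sigma_t, toy_score, forward_op)

zeta = 0.1  # hypothetical guidance step size
x_guided = x_t - zeta * grad  # nudge the sample toward the measurement
```

In a real sampler this correction would be applied after each unconditional integrator step; unobserved coordinates receive zero gradient in this toy case because the mask removes them from the residual.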
Quick Installation
For development:
git clone https://github.com/jcopo/diffuse.git
cd diffuse
pip install -e .
Quick Start
Here’s a minimal pipeline example:
import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import LinearSchedule, SDE
from diffuse.timer import VpTimer
from diffuse.integrator.deterministic import DDIMIntegrator
from diffuse.predictor import Predictor
from diffuse.denoisers.denoiser import Denoiser
# 1. Define components
beta = LinearSchedule(b_min=0.02, b_max=7.0, t0=0.0, T=1.0)
sde = SDE(beta=beta)
timer = VpTimer(eps=1e-5, tf=1.0, n_steps=50)
integrator = DDIMIntegrator(model=sde, timer=timer)
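# network_fn (your score network) and data_dim (the shape of one sample)
# are not defined in this snippet; supply your own before running it.
predictor = Predictor(model=sde, network=network_fn, prediction_type="score")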
# 2. Create pipeline
denoiser = Denoiser(
    integrator=integrator,
    model=sde,
    predictor=predictor,
    x0_shape=data_dim,  # Shape of data samples
)
# 3. Generate samples
key = jax.random.PRNGKey(0)
final_state, _ = denoiser.generate(key, n_steps=50, n_particles=100)
samples = final_state.integrator_state.position
print(f"✓ Generated {samples.shape} samples")
See the Quick Start Guide for a complete tutorial.
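To make the role of the noise schedule more concrete, the sketch below reimplements a linear beta schedule and the corresponding variance-preserving perturbation kernel in plain JAX, using the same `b_min`, `b_max`, `t0`, and `T` values as the example above. It is written under the assumption that `LinearSchedule` interpolates linearly between `b_min` and `b_max` and that the SDE is variance-preserving; the helper names (`beta`, `int_beta`, `alpha`, `sigma`, `perturb`) are hypothetical and are not part of Diffuse.

```python
import jax
import jax.numpy as jnp

# Same values as the Quick Start example above.
B_MIN, B_MAX, T0, T = 0.02, 7.0, 0.0, 1.0


def beta(t):
    """Linear schedule (assumed form): B_MIN at t = T0, B_MAX at t = T."""
    return B_MIN + (B_MAX - B_MIN) * (t - T0) / (T - T0)


def int_beta(t):
    """Closed-form integral of beta from T0 to t."""
    dt = t - T0
    return B_MIN * dt + 0.5 * (B_MAX - B_MIN) * dt**2 / (T - T0)


def alpha(t):
    """Signal scale of the variance-preserving kernel."""
    return jnp.exp(-0.5 * int_beta(t))


def sigma(t):
    """Noise scale, chosen so that alpha(t)**2 + sigma(t)**2 == 1."""
    return jnp.sqrt(1.0 - jnp.exp(-int_beta(t)))


def perturb(key, x0, t):
    """Sample x_t given x0: x_t = alpha(t) * x0 + sigma(t) * noise."""
    noise = jax.random.normal(key, x0.shape)
    return alpha(t) * x0 + sigma(t) * noise


key = jax.random.PRNGKey(0)
x0 = jnp.ones((100, 2))   # 100 toy samples in 2 dimensions
x_T = perturb(key, x0, T)  # mostly noise at the final time
```

Under these assumptions, `alpha(T0)` is 1 (no noise added) and `alpha(T)` is small, which is why sampling can start from approximately standard Gaussian noise at `t = T`.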
Citation
If you use Diffuse in your research, please cite the library:
@software{diffuse2024,
  title = {Diffuse: A modular diffusion model library},
  author = {Iollo, J. and Oudoumanessah, G.},
  year = {2025},
  url = {https://github.com/jcopo/diffuse}
}