Integrators

Deterministic Integrators

class diffuse.integrator.EulerIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)

Bases: ChurnedIntegrator

Euler integrator for probability flow ODEs in diffusion models.

Implements the explicit (first-order) Euler method for the probability flow ODE \(dx = v(x,t)\,dt\), advancing the state with:

\[x_{n+1} = x_n + v(x_n, t_n) \cdot dt\]

where \(v(x,t)\) is the velocity field and \(dt = t_{n+1} - t_n\) is the step size. Works with all diffusion models (SDE, Flow, EDM) through the velocity parameterization.

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval

  • stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)

  • churn_min – Minimum time threshold for churning (default: 0.0)

  • churn_max – Maximum time threshold for churning (default: 0.0)

  • noise_inflation_factor – Factor to scale injected noise (default: 1.0)
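To see the update rule in action, the sketch below applies Euler steps to a toy probability flow ODE that has a closed-form solution. It assumes data distributed as \(\mathcal{N}(0, s^2)\) (`S2` in the code) and an EDM-style noise level equal to time, so the ideal denoiser gives the velocity \(v(x,t) = t\,x/(s^2 + t^2)\). The helpers `velocity`, `euler_step`, and `solve` are hypothetical and not part of `diffuse`.

```python
import math

S2 = 1.0  # assumed variance s^2 of the toy data distribution N(0, s^2)


def velocity(x, t):
    # PF-ODE velocity (x - D(x, t)) / t with the ideal denoiser
    # D(x, t) = S2 / (S2 + t^2) * x, simplified so that t = 0 is safe.
    return t * x / (S2 + t * t)


def euler_step(x, t, t_next):
    # x_{n+1} = x_n + v(x_n, t_n) * dt, where dt = t_next - t is negative
    # when integrating from high noise toward t = 0.
    return x + velocity(x, t) * (t_next - t)


def solve(step, x, ts):
    # Apply `step` across consecutive points of the time grid `ts`.
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = step(x, t, t_next)
    return x


t_max, n = 2.0, 100
ts = [t_max * (1 - i / n) for i in range(n + 1)]  # uniform grid from t_max down to 0
x_start = math.sqrt(S2 + t_max ** 2)              # sample at t_max
x_end = solve(euler_step, x_start, ts)
# Closed form: x(t) is proportional to sqrt(S2 + t^2) along the ODE path.
exact = x_start * math.sqrt(S2) / math.sqrt(S2 + t_max ** 2)
print(x_end, exact)  # Euler result vs. closed form; they differ by O(dt)
```

Halving the step size should roughly halve the gap, as expected from a first-order method.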

class diffuse.integrator.HeunIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)

Bases: ChurnedIntegrator

Heun’s method integrator for probability flow ODEs in diffusion models.

Implements a second-order Runge-Kutta method (Heun’s method) that uses an intermediate Euler step to improve accuracy:

\[x_{n+1} = x_n + \frac{v_1 + v_2}{2} \cdot dt\]

where:

  • \(v_1 = v(x_n, t_n)\)

  • \(v_2 = v(x_n + v_1 \cdot dt, t_{n+1})\)

Works with all diffusion models (SDE, Flow, EDM) using the velocity parameterization.

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval

  • stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)

  • churn_min – Minimum time threshold for churning (default: 0.0)

  • churn_max – Maximum time threshold for churning (default: 0.0)

  • noise_inflation_factor – Factor to scale injected noise (default: 1.0)
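The benefit of the corrector stage can be checked on a toy problem with a known answer. This sketch assumes Gaussian data \(\mathcal{N}(0, s^2)\) (`S2` in the code) with an EDM-style noise level equal to time, so the ideal-denoiser velocity is \(v(x,t) = t\,x/(s^2 + t^2)\); the helper names are hypothetical, and only `heun_step` mirrors the formula above.

```python
import math

S2 = 1.0  # assumed variance s^2 of the toy data distribution N(0, s^2)


def velocity(x, t):
    # Ideal-denoiser PF-ODE velocity for N(0, S2) data, noise level equal to time
    return t * x / (S2 + t * t)


def euler_step(x, t, t_next):
    return x + velocity(x, t) * (t_next - t)


def heun_step(x, t, t_next):
    dt = t_next - t
    v1 = velocity(x, t)                 # slope at the start of the step
    v2 = velocity(x + v1 * dt, t_next)  # slope at the Euler-predicted endpoint
    return x + 0.5 * (v1 + v2) * dt     # average of the two slopes


def solve(step, x, ts):
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = step(x, t, t_next)
    return x


t_max, n = 2.0, 100
ts = [t_max * (1 - i / n) for i in range(n + 1)]  # uniform grid from t_max down to 0
x_start = math.sqrt(S2 + t_max ** 2)
exact = math.sqrt(S2)  # closed form at t = 0, since x(t) is proportional to sqrt(S2 + t^2)
err_euler = abs(solve(euler_step, x_start, ts) - exact)
err_heun = abs(solve(heun_step, x_start, ts) - exact)
print(err_euler, err_heun)  # second-order Heun is far more accurate on the same grid
```

Heun costs two velocity evaluations per step, so a fair comparison would give Euler twice as many steps; the second-order error still wins on smooth problems like this one.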

class diffuse.integrator.DPMpp2sIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)

Bases: ChurnedIntegrator

DPM-Solver++ (2S) integrator for reverse-time diffusion processes.

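Implements the second-order, single-step DPM-Solver++ (2S) algorithm, which uses a midpoint prediction step and dynamic thresholding. It typically offers better stability and accuracy than basic Euler integration for the same number of steps.

Like DPM-Solver++ in general, the method works in log space: each step is parameterized by the change in the half-log signal-to-noise ratio \(\lambda_t = \log(\alpha_t/\sigma_t)\), so the linear part of the ODE is integrated exactly and only the model prediction is approximated, using a midpoint evaluation.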
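A minimal sketch of one DPM-Solver++ (2S) step (Lu et al., 2022) in data-prediction form, with the midpoint placed halfway in \(\lambda\). It assumes a variance-exploding schedule (\(\alpha_t = 1\), \(\sigma_t = t\)) and the ideal denoiser for \(\mathcal{N}(0, s^2)\) data, and it omits dynamic thresholding. All names are hypothetical; this is not the `diffuse` implementation.

```python
import math

S2 = 1.0  # assumed variance s^2 of the toy data distribution N(0, s^2)


# Assumed VE schedule: x_t = alpha(t) * x_0 + sigma(t) * eps with alpha = 1, sigma = t.
def alpha(t):
    return 1.0


def sigma(t):
    return t


def lam(t):
    return math.log(alpha(t) / sigma(t))  # half-log-SNR lambda_t


def t_of_lam(l):
    return math.exp(-l)  # inverse of lam() for this particular schedule


def denoiser(x, t):
    # Ideal data prediction E[x_0 | x_t] for x_0 ~ N(0, S2)
    a, s = alpha(t), sigma(t)
    return a * S2 / (a * a * S2 + s * s) * x


def dpmpp_2s_step(x, t, t_next, r=0.5):
    # One step from t to t_next < t; h > 0 because lambda increases as noise decreases.
    h = lam(t_next) - lam(t)
    s = t_of_lam(lam(t) + r * h)  # intermediate time, a fraction r of the way in lambda
    d0 = denoiser(x, t)
    u = sigma(s) / sigma(t) * x - alpha(s) * math.expm1(-r * h) * d0
    d1 = denoiser(u, s)
    d = (1 - 1 / (2 * r)) * d0 + 1 / (2 * r) * d1  # equals d1 when r = 0.5
    return sigma(t_next) / sigma(t) * x - alpha(t_next) * math.expm1(-h) * d


n = 20
ts = [10.0 * 0.001 ** (i / n) for i in range(n + 1)]  # geometric grid from 10 to 0.01
x = math.sqrt(S2 + ts[0] ** 2)                        # sample at the highest noise level
for t, t_next in zip(ts[:-1], ts[1:]):
    x = dpmpp_2s_step(x, t, t_next)
exact = math.sqrt(S2 + ts[-1] ** 2)  # PF-ODE keeps x proportional to sqrt(S2 + sigma^2)
rel_err = abs(x - exact) / exact
print(x, exact, rel_err)
```

With `d = d0` instead of the midpoint combination, the step reduces to the first-order update, which is the DDIM rule written in data-prediction form.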

class diffuse.integrator.DDIMIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)

Bases: ChurnedIntegrator

Denoising Diffusion Implicit Models (DDIM) integrator.

DDIM assumes the same latent noise \(\varepsilon\) along the entire path. The update rule is:

\[x_s = \frac{\alpha_s}{\alpha_t} x_t - \left(\frac{\alpha_s \sigma_t}{\alpha_t} - \sigma_s\right) \varepsilon_\theta(x_t, t)\]

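where \(s < t\), \(\alpha_t\) and \(\sigma_t\) are the signal and noise scales of the forward process \(x_t = \alpha_t x_0 + \sigma_t \varepsilon\), and \(\varepsilon_\theta(x_t, t)\) is the predicted noise.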

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval

  • stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)

  • churn_min – Minimum time threshold for churning (default: 0.0)

  • churn_max – Maximum time threshold for churning (default: 0.0)

  • noise_inflation_factor – Factor to scale injected noise (default: 1.0)

References

Song, J., Meng, C., Ermon, S. (2020). “Denoising Diffusion Implicit Models” arXiv:2010.02502
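The "same latent noise along the entire path" property can be checked directly on a degenerate dataset. In this sketch the data is a single point \(x_0 = C\) (a hypothetical choice), so the ideal noise prediction recovers the true \(\varepsilon\) exactly and repeated DDIM steps land back on \(C\). The cosine schedule and all helper names are assumptions for illustration.

```python
import math

C = 3.0  # hypothetical dataset: a single point x_0 = C


def alpha(t):
    return math.cos(0.5 * math.pi * t)  # assumed cosine-style signal scale


def sigma(t):
    return math.sin(0.5 * math.pi * t)  # matching noise scale, alpha^2 + sigma^2 = 1


def eps_model(x, t):
    # Ideal noise prediction for point-mass data: invert x_t = alpha_t * C + sigma_t * eps
    return (x - alpha(t) * C) / sigma(t)


def ddim_step(x_t, t, s):
    # x_s = (alpha_s / alpha_t) x_t - (alpha_s sigma_t / alpha_t - sigma_s) eps_theta(x_t, t), s < t
    a_t, s_t, a_s, s_s = alpha(t), sigma(t), alpha(s), sigma(s)
    return a_s / a_t * x_t - (a_s * s_t / a_t - s_s) * eps_model(x_t, t)


eps = 0.7                                       # fixed latent noise
ts = [0.99 * (1 - i / 10) for i in range(11)]   # 0.99 down to 0.0
x = alpha(ts[0]) * C + sigma(ts[0]) * eps       # noisy starting point
for t, s in zip(ts[:-1], ts[1:]):
    x = ddim_step(x, t, s)
print(x)  # recovers C up to floating-point error
```

Each step maps \(\alpha_t C + \sigma_t \varepsilon\) to \(\alpha_s C + \sigma_s \varepsilon\), so the noise component is carried unchanged from one time to the next.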

Stochastic Integrators

class diffuse.integrator.EulerMaruyamaIntegrator(model: DiffusionModel, timer: Timer)

Bases: Integrator

Euler-Maruyama integrator for stochastic differential equations (SDEs).

Implements the Euler-Maruyama method for numerical integration of SDEs of the form:

\[dX(t) = \mu(X,t)dt + \sigma(X,t)dW(t)\]

where:

  • \(\mu(X,t)\) is the drift term: \(\beta(t) \cdot (0.5 X + \nabla_X \log p(X|t))\)

  • \(\sigma(X,t)\) is the diffusion term: \(\sqrt{\beta(t)}\)

  • \(dW(t)\) is the Wiener process increment

  • \(\beta(t)\) is the noise schedule

The method advances the solution using the discrete approximation:

\[X(t + dt) = X(t) + \mu(X,t)\,dt + \sigma(X,t)\sqrt{dt}\, Z, \qquad Z \sim \mathcal{N}(0, I)\]

This is the simplest stochastic integration scheme with strong order 0.5 convergence for general SDEs.

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval
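The drift and diffusion terms above can be exercised on a case where the score is known exactly. This sketch assumes a constant schedule \(\beta(t) = 1\) and scalar data \(\mathcal{N}(0, s^2)\) (`S2` in the code), for which the forward marginal is Gaussian with variance \(s^2 e^{-t} + 1 - e^{-t}\). Starting from \(\mathcal{N}(0, 1)\) and stepping backward in time with the stated drift, the samples should end up with variance close to \(s^2\). The helper names are hypothetical.

```python
import math
import random

S2 = 4.0    # assumed variance s^2 of the toy data distribution N(0, s^2)
BETA = 1.0  # assumed constant noise schedule beta(t) = 1


def marginal_var(t):
    # Variance at time t of the forward SDE dX = -0.5*beta*X dt + sqrt(beta) dW from N(0, S2)
    decay = math.exp(-BETA * t)
    return S2 * decay + 1.0 - decay


def score(x, t):
    # Exact score of the Gaussian marginal N(0, marginal_var(t))
    return -x / marginal_var(t)


def em_step(x, t, dt, rng):
    # X <- X + mu dt + sigma sqrt(dt) Z with mu = beta * (0.5 x + score), sigma = sqrt(beta);
    # dt > 0 is the step size in reverse time (t decreases from T toward 0).
    drift = BETA * (0.5 * x + score(x, t))
    return x + drift * dt + math.sqrt(BETA) * math.sqrt(dt) * rng.gauss(0.0, 1.0)


rng = random.Random(0)
T, n_steps, n_samples = 4.0, 200, 2000
dt = T / n_steps
xs = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]  # prior N(0, 1) at t = T
for k in range(n_steps):
    t = T - k * dt  # current time, moving backward toward 0
    xs = [em_step(x, t, dt, rng) for x in xs]

mean = sum(xs) / n_samples
var = sum((x - mean) ** 2 for x in xs) / (n_samples - 1)
print(mean, var)  # close to 0 and S2, up to sampling noise and O(dt) bias
```

Because each step draws fresh Gaussian noise, two runs with different seeds give different samples but matching statistics.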

Base Classes

class diffuse.integrator.IntegratorState(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], step: int = 0)

State container for numerical integrators.

position

Current state vector/tensor

Type:

jax.Array

rng_key

JAX random number generator key

Type:

jaxtyping.Key[Array, ''] | jaxtyping.UInt32[Array, '2']

step

Current integration step counter (default: 0)

Type:

int

class diffuse.integrator.Integrator(model: DiffusionModel, timer: Timer)

Bases: object

Base class for numerical integrators of diffusion processes.

Provides the basic interface for implementing various numerical integration schemes for both deterministic and stochastic differential equations.

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval

init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) → IntegratorState

Initialize the integrator state.

Parameters:
  • position – Initial state vector/tensor

  • rng_key – JAX random number generator key

Returns:

Initial IntegratorState
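`init` packs the starting position and RNG key into an `IntegratorState` with the step counter at 0. The pure-Python stand-in below illustrates that state-passing pattern under explicit assumptions: `IntegratorState` is mirrored as a `NamedTuple`, the RNG key is a plain integer rather than a JAX PRNG key, and `ToyIntegrator.step` (a hypothetical Euler step on a fixed time list) is not part of the documented interface.

```python
from typing import Any, NamedTuple


class IntegratorState(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields."""
    position: Any   # current state (a jax.Array in the library)
    rng_key: Any    # RNG key (a JAX PRNG key in the library; an int here)
    step: int = 0   # integration step counter


class ToyIntegrator:
    """Illustrative only: shows the init / step state-passing pattern."""

    def __init__(self, velocity, ts):
        self.velocity = velocity  # hypothetical stand-in for `model`
        self.ts = ts              # hypothetical stand-in for `timer`

    def init(self, position, rng_key):
        return IntegratorState(position, rng_key, step=0)

    def step(self, state):
        # Euler update between the grid points selected by the step counter;
        # returns a new state instead of mutating the old one.
        t, t_next = self.ts[state.step], self.ts[state.step + 1]
        x = state.position + self.velocity(state.position, t) * (t_next - t)
        return state._replace(position=x, step=state.step + 1)


integ = ToyIntegrator(lambda x, t: -x, [0.0, 0.5, 1.0])
s0 = integ.init(1.0, rng_key=0)
s1 = integ.step(s0)
s2 = integ.step(s1)
print(s0.step, s1.step, s2.step, s2.position)
```

Returning a fresh state on every step keeps the update a pure function of its inputs, the functional style that JAX transformations such as `jit` and `lax.scan` favor.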

class diffuse.integrator.ChurnedIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)

Bases: Integrator

Integrator with stochastic churning for improved sampling.

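Implements stochastic churning, which can improve sampling quality by injecting controlled noise into the sample at steps whose time lies between churn_min and churn_max. The amount of noise is governed by stochastic_churn_rate and scaled by noise_inflation_factor.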

Parameters:
  • model – Diffusion model defining the diffusion process

  • timer – Timer object managing the discretization of the time interval

  • stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)

  • churn_min – Minimum time threshold for churning (default: 0.0)

  • churn_max – Maximum time threshold for churning (default: 0.0)

  • noise_inflation_factor – Factor to scale injected noise (default: 1.0)
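The parameter names resemble the \(S_\text{churn}\), \(S_\text{tmin}\), \(S_\text{tmax}\), and \(S_\text{noise}\) settings of the EDM stochastic sampler (Karras et al., 2022). Assuming churning here follows that scheme, which this page does not confirm, a churn step might look like the sketch below: inside the time window the noise level is raised from \(t\) to \(\hat t = t(1+\gamma)\) with \(\gamma = \min(\text{rate}/N, \sqrt{2}-1)\), and just enough fresh noise is added to reach \(\hat t\). The function `churn` is hypothetical and assumes \(\sigma(t) = t\).

```python
import math
import random


def churn(x, t, n_steps, rng, churn_rate=0.0, churn_min=0.0, churn_max=0.0,
          noise_inflation=1.0):
    """EDM-style churn, assuming sigma(t) = t; returns (x_hat, t_hat) with t_hat >= t."""
    if churn_min <= t <= churn_max:
        gamma = min(churn_rate / n_steps, math.sqrt(2.0) - 1.0)
    else:
        gamma = 0.0  # outside the window: no churn
    t_hat = t * (1.0 + gamma)
    # Fresh noise with variance t_hat^2 - t^2 moves x from noise level t to t_hat
    noise_std = math.sqrt(t_hat ** 2 - t ** 2) * noise_inflation
    return x + noise_std * rng.gauss(0.0, 1.0), t_hat


rng = random.Random(0)
x_same, t_same = churn(1.5, 5.0, 18, rng)  # defaults: churning disabled
x_hat, t_hat = churn(1.5, 5.0, 18, rng, churn_rate=40.0, churn_min=0.05, churn_max=50.0)
x_out, t_out = churn(1.5, 80.0, 18, rng, churn_rate=40.0, churn_min=0.05, churn_max=50.0)
print(t_same, t_hat, t_out)  # t_hat is capped at t * sqrt(2)
```

In such a scheme the deterministic Euler or Heun step then runs from \((\hat x, \hat t)\) to the next grid time, so churning trades a little extra noise for the chance to correct errors accumulated in earlier steps.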