Integrators
Deterministic Integrators
- class diffuse.integrator.EulerIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)
Bases: ChurnedIntegrator
Euler integrator for probability flow ODEs in diffusion models.
Implements the basic Euler method for numerical integration:
\[dx = v(x,t) \cdot dt\]
where \(v(x,t)\) is the velocity field from the probability flow ODE. Works with all diffusion models (SDE, Flow, EDM) using the velocity parameterization.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)
churn_min – Minimum time threshold for churning (default: 0.0)
churn_max – Maximum time threshold for churning (default: 0.0)
noise_inflation_factor – Factor to scale injected noise (default: 1.0)
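The Euler update above can be illustrated with a minimal plain-Python sketch. It does not use the library's API: `euler_step`, `integrate_euler`, and `toy_velocity` are hypothetical names, and the velocity field is a toy linear ODE rather than a learned diffusion model.

```python
import math


def euler_step(v, x, t, dt):
    """One explicit Euler step: x(t + dt) ~ x(t) + v(x, t) * dt."""
    return x + v(x, t) * dt


def integrate_euler(v, x0, t0, t1, n_steps):
    """Integrate dx/dt = v(x, t) from t0 to t1 using n_steps equal steps."""
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = euler_step(v, x, t, dt)
        t += dt
    return x


def toy_velocity(x, t):
    """Hypothetical velocity field dx/dt = -x, with exact solution x0 * exp(-t)."""
    return -x


approx = integrate_euler(toy_velocity, x0=1.0, t0=0.0, t1=1.0, n_steps=1000)
exact = math.exp(-1.0)
print(f"Euler: {approx:.6f}, exact: {exact:.6f}")
```

Because the method is first order, halving the step size roughly halves the error, which is why many steps are needed for an accurate trajectory.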
- class diffuse.integrator.HeunIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)
Bases: ChurnedIntegrator
Heun’s method integrator for probability flow ODEs in diffusion models.
Implements a second-order Runge-Kutta method (Heun’s method) that uses an intermediate Euler step to improve accuracy:
\[x_{n+1} = x_n + \frac{v_1 + v_2}{2} \cdot dt\]
where:
\(v_1 = v(x_n, t_n)\)
\(v_2 = v(x_n + v_1 \cdot dt, t_{n+1})\)
Works with all diffusion models (SDE, Flow, EDM) using the velocity parameterization.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)
churn_min – Minimum time threshold for churning (default: 0.0)
churn_max – Maximum time threshold for churning (default: 0.0)
noise_inflation_factor – Factor to scale injected noise (default: 1.0)
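As a rough illustration of the predictor-corrector update above, the following plain-Python sketch compares Heun's method with plain Euler on a toy ODE. The function names and the velocity field are hypothetical and independent of the library's API.

```python
import math


def euler_step(v, x, t, dt):
    """First-order reference step, used here only for comparison."""
    return x + v(x, t) * dt


def heun_step(v, x, t, dt):
    """One Heun step: average the slope at the start and at the Euler-predicted end."""
    v1 = v(x, t)                    # slope at (x_n, t_n)
    x_pred = x + v1 * dt            # intermediate Euler prediction
    v2 = v(x_pred, t + dt)          # slope at the predicted point, time t_{n+1}
    return x + 0.5 * (v1 + v2) * dt


def integrate(step, v, x0, t0, t1, n_steps):
    """Apply `step` repeatedly over n_steps equal steps from t0 to t1."""
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        x = step(v, x, t, dt)
        t += dt
    return x


def toy_velocity(x, t):
    """Hypothetical velocity field dx/dt = -x, with exact solution x0 * exp(-t)."""
    return -x


exact = math.exp(-1.0)
euler_err = abs(integrate(euler_step, toy_velocity, 1.0, 0.0, 1.0, 10) - exact)
heun_err = abs(integrate(heun_step, toy_velocity, 1.0, 0.0, 1.0, 10) - exact)
print(f"Euler error: {euler_err:.2e}, Heun error: {heun_err:.2e}")
```

Each Heun step costs two velocity evaluations instead of one, so the comparison that matters in practice is accuracy per model call rather than per step.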
- class diffuse.integrator.DPMpp2sIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)
Bases: ChurnedIntegrator
DPM-Solver++ (2S) integrator for reverse-time diffusion processes.
Implements the second-order single-step DPM-Solver++ algorithm, which uses a midpoint prediction step and dynamic thresholding. This method is more stable and more accurate than basic Euler integration.
The method uses log-space computations and midpoint predictions to better handle the diffusion process dynamics.
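The log-space midpoint idea can be sketched as follows. This is an illustration under several assumptions, not the library's implementation: it uses an EDM-style parameterization in which the state is `x = x0 + sigma * noise`, it takes a data-prediction (denoiser) function, and it omits dynamic thresholding. The names `dpmpp_2s_step` and `point_mass_denoiser` are hypothetical.

```python
import math


def dpmpp_2s_step(denoise, x, sigma, sigma_next, r=0.5):
    """One second-order single-step update in log-sigma time (assumed EDM-style).

    `denoise(x, sigma)` is assumed to return an estimate of the clean sample.
    Requires sigma > sigma_next > 0.
    """
    t, t_next = -math.log(sigma), -math.log(sigma_next)   # log-space time
    h = t_next - t                                         # step size in log-space
    sigma_mid = math.exp(-(t + r * h))                     # midpoint noise level
    # Midpoint prediction from the denoiser output at the current point.
    x_mid = (sigma_mid / sigma) * x - math.expm1(-r * h) * denoise(x, sigma)
    # Full step using the denoiser output at the midpoint.
    return (sigma_next / sigma) * x - math.expm1(-h) * denoise(x_mid, sigma_mid)


# Toy check: if all data sits at a single point, the ideal denoiser is constant.
DATA_POINT = 2.0


def point_mass_denoiser(x, sigma):
    """Ideal denoiser for a hypothetical point-mass data distribution."""
    return DATA_POINT


noise = 0.7
x_start = DATA_POINT + 10.0 * noise                        # sample at sigma = 10
x_end = dpmpp_2s_step(point_mass_denoiser, x_start, sigma=10.0, sigma_next=1.0)
print(x_end)
```

With a constant denoiser the update is exact: the result equals the data point plus the same noise rescaled to the new noise level, which makes this a convenient sanity check.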
- class diffuse.integrator.DDIMIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)
Bases: ChurnedIntegrator
Denoising Diffusion Implicit Models (DDIM) integrator.
DDIM assumes the same latent noise \(\varepsilon\) along the entire path. The update rule is:
\[x_s = \frac{\alpha_s}{\alpha_t} x_t - \left(\frac{\alpha_s \sigma_t}{\alpha_t} - \sigma_s\right) \varepsilon_\theta(x_t, t)\]
where \(s < t\), and \(\varepsilon_\theta(x_t, t)\) is the predicted noise.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)
churn_min – Minimum time threshold for churning (default: 0.0)
churn_max – Maximum time threshold for churning (default: 0.0)
noise_inflation_factor – Factor to scale injected noise (default: 1.0)
References
Song, J., Meng, C., Ermon, S. (2020). “Denoising Diffusion Implicit Models” arXiv:2010.02502
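The DDIM update rule can be checked numerically with a small plain-Python sketch. The cosine schedule, the point-mass data distribution, and the names `schedule`, `exact_eps`, and `ddim_step` are assumptions made for illustration; they are not part of the library.

```python
import math


def schedule(t):
    """Hypothetical variance-preserving schedule: alpha^2 + sigma^2 = 1."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


X0 = 1.5  # toy data distribution: a point mass at X0


def exact_eps(x_t, t):
    """Exact noise prediction when x_t = alpha_t * X0 + sigma_t * eps."""
    alpha_t, sigma_t = schedule(t)
    return (x_t - alpha_t * X0) / sigma_t


def ddim_step(eps_model, x_t, t, s):
    """Deterministic DDIM step from time t to an earlier time s < t."""
    alpha_t, sigma_t = schedule(t)
    alpha_s, sigma_s = schedule(s)
    eps = eps_model(x_t, t)
    return (alpha_s / alpha_t) * x_t - (alpha_s * sigma_t / alpha_t - sigma_s) * eps


eps_true = -0.3
t, s = 0.8, 0.4
alpha_t, sigma_t = schedule(t)
alpha_s, sigma_s = schedule(s)
x_t = alpha_t * X0 + sigma_t * eps_true
x_s = ddim_step(exact_eps, x_t, t, s)
expected = alpha_s * X0 + sigma_s * eps_true   # same latent noise at time s
print(x_s, expected)
```

With an exact noise predictor, the step lands on the point that shares the same latent noise at time \(s\), which is the "same \(\varepsilon\) along the entire path" property stated above.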
Stochastic Integrators
- class diffuse.integrator.EulerMaruyamaIntegrator(model: DiffusionModel, timer: Timer)
Bases: Integrator
Euler-Maruyama stochastic integrator for Stochastic Differential Equations (SDEs).
Implements the Euler-Maruyama method for numerical integration of SDEs of the form:
\[dX(t) = \mu(X,t)dt + \sigma(X,t)dW(t)\]
where:
\(\mu(X,t)\) is the drift term: \(\beta(t) \cdot (0.5 X + \nabla_X \log p(X|t))\)
\(\sigma(X,t)\) is the diffusion term: \(\sqrt{\beta(t)}\)
\(dW(t)\) is the Wiener process increment
\(\beta(t)\) is the noise schedule
The method advances the solution using the discrete approximation:
\[X(t + dt) = X(t) + \mu(X,t)dt + \sigma(X,t)\sqrt{dt} \cdot \mathcal{N}(0,1)\]
This is the simplest stochastic integration scheme, with strong convergence of order 0.5 for general SDEs.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
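The discrete update above can be sketched in plain Python. The example is a toy under stated assumptions: the data distribution is a standard normal, so the score is \(-x\) at every time, the noise schedule is a constant `BETA`, and Python's `random.Random` stands in for a JAX random key. None of the names below come from the library.

```python
import math
import random

BETA = 2.0  # hypothetical constant noise schedule beta(t)


def score(x, t):
    """Score of a standard normal, assumed here as the marginal at every t."""
    return -x


def drift(x, t):
    """Drift term beta(t) * (0.5 * x + score), following the form given above."""
    return BETA * (0.5 * x + score(x, t))


def diffusion(x, t):
    """Diffusion term sqrt(beta(t))."""
    return math.sqrt(BETA)


def euler_maruyama_step(mu, sigma, x, t, dt, rng):
    """One Euler-Maruyama step with a standard normal increment scaled by sqrt(dt)."""
    noise = rng.gauss(0.0, 1.0)
    return x + mu(x, t) * dt + sigma(x, t) * math.sqrt(dt) * noise


def simulate(x0, n_steps, dt, seed):
    """Run n_steps of the scheme from x0 with a seeded generator."""
    rng = random.Random(seed)
    x, t = x0, 0.0
    for _ in range(n_steps):
        x = euler_maruyama_step(drift, diffusion, x, t, dt, rng)
        t += dt
    return x


sample = simulate(x0=3.0, n_steps=200, dt=0.01, seed=0)
print(sample)
```

Setting the diffusion term to zero reduces the step to the deterministic Euler update, which is a quick way to test the drift in isolation.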
Base Classes
- class diffuse.integrator.IntegratorState(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], step: int = 0)
State container for numerical integrators.
- rng_key
JAX random number generator key
- Type:
jaxtyping.Key[Array, ''] | jaxtyping.UInt32[Array, '2']
- class diffuse.integrator.Integrator(model: DiffusionModel, timer: Timer)
Bases: object
Base class for numerical integrators of diffusion processes.
Provides the basic interface for implementing various numerical integration schemes for both deterministic and stochastic differential equations.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
- init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) → IntegratorState
Initialize the integrator state.
- Parameters:
position – Initial state vector/tensor
rng_key – JAX random number generator key
- Returns:
Initial IntegratorState
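The state-passing pattern implied by `init` can be sketched in plain Python. This is only an analogy: `SketchState` mirrors the documented field names (`position`, `rng_key`, `step`), an integer stands in for the JAX key, and `advance` is a hypothetical helper that is not documented here.

```python
from typing import Callable, NamedTuple


class SketchState(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields of IntegratorState."""

    position: float
    rng_key: int      # stand-in for a JAX random key
    step: int = 0     # number of integration steps taken so far


def init(position: float, rng_key: int) -> SketchState:
    """Build the initial state; the step counter starts at its default of 0."""
    return SketchState(position=position, rng_key=rng_key)


def advance(state: SketchState, velocity: Callable[[float], float], dt: float) -> SketchState:
    """Hypothetical update: return a new state instead of mutating the old one."""
    new_position = state.position + velocity(state.position) * dt
    return state._replace(position=new_position, step=state.step + 1)


state0 = init(position=1.0, rng_key=0)
state1 = advance(state0, lambda x: -x, dt=0.1)
print(state0, state1)
```

Returning a fresh immutable state from each update keeps the functions pure, which is the style that JAX transformations generally expect.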
- class diffuse.integrator.ChurnedIntegrator(model: DiffusionModel, timer: Timer, stochastic_churn_rate: float = 0.0, churn_min: float = 0.0, churn_max: float = 0.0, noise_inflation_factor: float = 1.0)
Bases: Integrator
Integrator with stochastic churning for improved sampling.
Implements the stochastic churning mechanism that can help improve sampling quality by occasionally injecting controlled noise into the process.
- Parameters:
model – Diffusion model defining the diffusion process
timer – Timer object managing the discretization of the time interval
stochastic_churn_rate – Rate of applying stochastic churning (default: 0.0)
churn_min – Minimum time threshold for churning (default: 0.0)
churn_max – Maximum time threshold for churning (default: 0.0)
noise_inflation_factor – Factor to scale injected noise (default: 1.0)
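One common form of stochastic churning is the noise-injection step of the stochastic sampler in Karras et al. (2022), "Elucidating the Design Space of Diffusion-Based Generative Models". The sketch below follows that recipe and assumes, without confirmation from this page, that the four parameters above play the roles of the churn amount, the lower and upper bounds of the churn window, and the noise scale, and that the window is expressed in noise level. The function `churn` and its `n_steps` argument are hypothetical.

```python
import math
import random


def churn(x, sigma, n_steps, rng,
          stochastic_churn_rate=0.0, churn_min=0.0, churn_max=0.0,
          noise_inflation_factor=1.0):
    """Optionally raise the noise level from sigma to sigma_hat before a solver step.

    Returns the (possibly perturbed) sample and the (possibly raised) noise level.
    """
    inside_window = churn_min <= sigma <= churn_max
    if stochastic_churn_rate <= 0.0 or not inside_window:
        return x, sigma                                   # churning disabled
    # Relative increase in noise level, capped as in the assumed recipe.
    gamma = min(stochastic_churn_rate / n_steps, math.sqrt(2.0) - 1.0)
    sigma_hat = sigma * (1.0 + gamma)
    # Extra noise needed to move from noise level sigma to sigma_hat.
    extra_std = math.sqrt(sigma_hat ** 2 - sigma ** 2)
    x_hat = x + noise_inflation_factor * extra_std * rng.gauss(0.0, 1.0)
    return x_hat, sigma_hat


rng = random.Random(0)
unchanged = churn(1.0, 1.0, 20, rng)                      # defaults: no churning
x_hat, sigma_hat = churn(1.0, 1.0, 20, rng, stochastic_churn_rate=4.0,
                         churn_min=0.1, churn_max=10.0)
print(unchanged, x_hat, sigma_hat)
```

With the documented defaults (rate 0.0 and an empty window), the sketch leaves the sample untouched, so the integrator stays fully deterministic unless churning is explicitly enabled.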