# Source code for diffuse.integrator.deterministic

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass

import jax
import jax.numpy as jnp

from diffuse.diffusion.sde import SDEState, DiffusionModel
from diffuse.integrator.base import IntegratorState, ChurnedIntegrator
from diffuse.predictor import Predictor


# Public integrators exported by this module.
__all__ = [
    "EulerIntegrator",
    "HeunIntegrator",
    "DPMpp2sIntegrator",
    "DDIMIntegrator",
]


@dataclass
class EulerIntegrator(ChurnedIntegrator):
    r"""Euler integrator for probability flow ODEs in diffusion models.

    Implements the basic Euler method for numerical integration:

    .. math::
        dx = v(x,t) \cdot dt

    where :math:`v(x,t)` is the velocity field from the probability flow ODE.
    Works with all diffusion models (SDE, Flow, EDM) using the velocity
    parameterization.

    Args:
        model: Diffusion model defining the diffusion process
        timer: Timer object managing the discretization of the time interval
        stochastic_churn_rate: Rate of applying stochastic churning (default: 0.0)
        churn_min: Minimum time threshold for churning (default: 0.0)
        churn_max: Maximum time threshold for churning (default: 0.0)
        noise_inflation_factor: Factor to scale injected noise (default: 1.0)
    """

    model: DiffusionModel

    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
        """Perform one Euler integration step.

        Args:
            integrator_state: Current state containing (position, rng_key, step)
            predictor: Predictor providing velocity field v(x,t)

        Returns:
            Updated IntegratorState with the next position, the RNG key derived
            from the current one, and the step counter incremented by one.
        """
        _, rng_key, step = integrator_state
        # A single split suffices: splitting the same key again with the same
        # arguments would just reproduce the same child key.
        _, rng_key_next = jax.random.split(rng_key)

        # Base-class churn hook: returns the (possibly churned) position and time
        # from which this step starts.
        position_churned, t_churned = self._churn_fn(integrator_state)
        t_next = self.timer(step + 1)
        dt = t_next - t_churned

        # Use velocity directly for probability flow ODE
        velocity = predictor.velocity(position_churned, t_churned)
        position_next = position_churned + velocity * dt

        return IntegratorState(position_next, rng_key_next, step + 1)
@dataclass
class HeunIntegrator(ChurnedIntegrator):
    r"""Second-order Heun (explicit trapezoidal) integrator for probability flow ODEs.

    A trial Euler step is taken first. The velocity is then re-evaluated at the
    trial point, and the two slopes are averaged:

    .. math::
        x_{n+1} = x_n + \frac{v_1 + v_2}{2} \cdot dt

    with

    - :math:`v_1 = v(x_n, t_n)`
    - :math:`v_2 = v(x_n + v_1 \cdot dt, t_{n+1})`

    Applies to any diffusion model (SDE, Flow, EDM) exposing the velocity
    parameterization. Each step costs two velocity evaluations.

    Args:
        model: Diffusion model defining the diffusion process
        timer: Timer object managing the discretization of the time interval
        stochastic_churn_rate: Rate of applying stochastic churning (default: 0.0)
        churn_min: Minimum time threshold for churning (default: 0.0)
        churn_max: Maximum time threshold for churning (default: 0.0)
        noise_inflation_factor: Factor to scale injected noise (default: 1.0)
    """

    model: DiffusionModel

    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
        """Advance the state by one Heun step.

        Args:
            integrator_state: Current (position, rng_key, step) triple.
            predictor: Source of the velocity field v(x, t).

        Returns:
            IntegratorState holding the Heun-corrected position, the RNG key
            derived from the current one, and the incremented step counter.
        """
        _, key, step = integrator_state
        key_next = jax.random.split(key)[1]

        # Start from the (possibly churned) position and time given by the base class.
        x, t = self._churn_fn(integrator_state)
        t_target = self.timer(step + 1)
        step_size = t_target - t

        # Predictor: plain Euler step along the initial slope.
        slope_start = predictor.velocity(x, t)
        x_trial = x + slope_start * step_size

        # Corrector: re-evaluate the slope at the trial point and average both slopes.
        slope_end = predictor.velocity(x_trial, t_target)
        x_new = x + (slope_start + slope_end) * step_size / 2

        return IntegratorState(x_new, key_next, step + 1)
@dataclass
class DPMpp2sIntegrator(ChurnedIntegrator):
    r"""DPM-Solver++ (2S) integrator for reverse-time diffusion processes.

    Implements the second-order single-step DPM-Solver++ algorithm.

    Let :math:`\lambda_t = \log(\alpha_t / \sigma_t)`, where :math:`\alpha_t` and
    :math:`\sigma_t` are the model's signal and noise levels. Each step evaluates
    the denoised prediction :math:`\hat{x}_0` (via the model's Tweedie estimate)
    at two points: the current time :math:`t`, and an intermediate time
    :math:`s`, taken here as the arithmetic midpoint between :math:`t` and the
    next time :math:`t'`. The two predictions are then combined:

    .. math::
        u = \frac{\sigma_s}{\sigma_t} x_t - \alpha_s \left(e^{-r h} - 1\right) \hat{x}_0(x_t, t)

        D = \left(1 - \frac{1}{2r}\right) \hat{x}_0(x_t, t) + \frac{1}{2r} \hat{x}_0(u, s)

        x_{t'} = \frac{\sigma_{t'}}{\sigma_t} x_t - \alpha_{t'} \left(e^{-h} - 1\right) D

    with :math:`h = \lambda_{t'} - \lambda_t` and :math:`r = (\lambda_s - \lambda_t) / h`.
    Both :math:`h` and :math:`r` are clipped from below at ``1e-6``.

    Args:
        model: Diffusion model defining the diffusion process
        timer: Timer object managing the discretization of the time interval
        stochastic_churn_rate: Rate of applying stochastic churning (default: 0.0)
        churn_min: Minimum time threshold for churning (default: 0.0)
        churn_max: Maximum time threshold for churning (default: 0.0)
        noise_inflation_factor: Factor to scale injected noise (default: 1.0)

    References:
        Lu, C., Zhou, Y., Bao, F., Chen, J., Li, C., Zhu, J. (2022).
        "DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models"
        arXiv:2211.01095
    """

    model: DiffusionModel

    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
        """Perform one DPM-Solver++ (2S) integration step in reverse time.

        Args:
            integrator_state: Current state containing (position, rng_key, step)
            predictor: Predictor whose ``score`` function s(x,t), approximating
                ∇ₓ log p(x|t), is passed to ``model.tweedie`` to obtain the
                denoised estimates.

        Returns:
            Updated IntegratorState with the next position computed using the
            DPM-Solver++ (2S) algorithm, the RNG key derived from the current
            one, and the step counter incremented by one.
        """
        _, rng_key, step = integrator_state
        # A single split suffices; splitting the same key again is redundant.
        _, rng_key_next = jax.random.split(rng_key)

        position_churned, t_churned = self._churn_fn(integrator_state)
        t_next = self.timer(step + 1)
        t_mid = (t_churned + t_next) / 2

        signal_level_churned = self.model.signal_level(t_churned)
        signal_level_mid = self.model.signal_level(t_mid)
        signal_level_next = self.model.signal_level(t_next)

        sigma_churned = self.model.noise_level(t_churned)
        sigma_next = self.model.noise_level(t_next)
        sigma_mid = self.model.noise_level(t_mid)

        # lambda_t = log(alpha_t / sigma_t), the log signal-to-noise scale.
        log_scale_churned = jnp.log(signal_level_churned / sigma_churned)
        log_scale_next = jnp.log(signal_level_next / sigma_next)
        log_scale_mid = jnp.log(signal_level_mid / sigma_mid)

        # Lower clip keeps h and r strictly positive, so the divisions by h and
        # by 2 * r below are never by zero.
        # NOTE(review): this presumes lambda increases along the step (noise
        # decreasing in reverse time); confirm against the timer's schedule.
        h = jnp.clip(log_scale_next - log_scale_churned, 1e-6)
        r = jnp.clip((log_scale_mid - log_scale_churned) / h, 1e-6)

        # First stage: denoised estimate at the current point, then move to the
        # intermediate time with a first-order update.
        pred_x0_churned = self.model.tweedie(SDEState(position_churned, t_churned), predictor.score).position
        u = sigma_mid / sigma_churned * position_churned - signal_level_mid * jnp.expm1(-h * r) * pred_x0_churned

        # Second stage: denoised estimate at the intermediate point, linearly
        # combined with the first one.
        pred_x0_mid = self.model.tweedie(SDEState(u, t_mid), predictor.score).position
        D = (1 - 1 / (2 * r)) * pred_x0_churned + (1 / (2 * r)) * pred_x0_mid

        next_position = sigma_next / sigma_churned * position_churned - signal_level_next * jnp.expm1(-h) * D

        return IntegratorState(next_position, rng_key_next, step + 1)
@dataclass
class DDIMIntegrator(ChurnedIntegrator):
    r"""Denoising Diffusion Implicit Models (DDIM) integrator.

    DDIM assumes the same latent noise :math:`\varepsilon` along the entire path.
    The update rule is:

    .. math::
        x_s = \frac{\alpha_s}{\alpha_t} x_t - \left(\frac{\alpha_s \sigma_t}{\alpha_t} - \sigma_s\right) \varepsilon_\theta(x_t, t)

    where :math:`s < t`, and :math:`\varepsilon_\theta(x_t, t)` is the predicted noise.

    Args:
        model: Diffusion model defining the diffusion process
        timer: Timer object managing the discretization of the time interval
        stochastic_churn_rate: Rate of applying stochastic churning (default: 0.0)
        churn_min: Minimum time threshold for churning (default: 0.0)
        churn_max: Maximum time threshold for churning (default: 0.0)
        noise_inflation_factor: Factor to scale injected noise (default: 1.0)

    References:
        Song, J., Meng, C., Ermon, S. (2020). "Denoising Diffusion Implicit Models"
        arXiv:2010.02502
    """

    model: DiffusionModel

    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
        r"""Perform one DDIM step in reverse time.

        Args:
            integrator_state: Current state containing (position, rng_key, step)
            predictor: Predictor providing noise prediction
                :math:`\varepsilon_\theta(x_t, t)`

        Returns:
            Updated IntegratorState with the next position computed using the
            DDIM update rule:

            .. math::
                x_s = \alpha_s \hat{x}_0 + \sigma_s \varepsilon_\theta(x_t, t)

            where:

            - :math:`\hat{x}_0 = (x_t - \sigma_t \varepsilon_\theta(x_t, t)) / \alpha_t`
              is the predicted denoised sample
            - :math:`\alpha_t` and :math:`\sigma_t` are the model's signal and
              noise levels (``model.signal_level`` / ``model.noise_level``),
              used directly rather than square-rooted, at the current (possibly
              churned) time :math:`t`
            - :math:`s` is the next time given by the timer

            This is algebraically identical to the update rule in the class
            docstring.
        """
        _, rng_key, step = integrator_state
        _, rng_key_next = jax.random.split(rng_key)

        position_churned, t_churned = self._churn_fn(integrator_state)
        t_next = self.timer(step + 1)

        signal_level_churned = self.model.signal_level(t_churned)
        signal_level_next = self.model.signal_level(t_next)
        sigma_churned = self.model.noise_level(t_churned)
        sigma_next = self.model.noise_level(t_next)

        # Predict the noise once and reuse it for the next point (deterministic DDIM).
        eps = predictor.noise(position_churned, t_churned)
        pred_x0 = (position_churned - sigma_churned * eps) / signal_level_churned
        position_next = signal_level_next * pred_x0 + sigma_next * eps

        return IntegratorState(position_next, rng_key_next, step + 1)