Source code for diffuse.integrator.stochastic

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass

import jax
import jax.numpy as jnp

from diffuse.integrator.base import IntegratorState, Integrator
from diffuse.diffusion.sde import DiffusionModel
from diffuse.predictor import Predictor

# Names exported by ``from <this module> import *``; only the integrator class is public.
__all__ = ["EulerMaruyamaIntegrator"]


@dataclass
class EulerMaruyamaIntegrator(Integrator):
    r"""Euler-Maruyama stochastic integrator for Stochastic Differential Equations (SDEs).

    Implements the Euler-Maruyama method for numerical integration of SDEs of
    the form:

    .. math::
        dX(t) = \mu(X,t)dt + \sigma(X,t)dW(t)

    where:

    - :math:`\mu(X,t)` is the drift term:
      :math:`\beta(t) \cdot (0.5 X + \nabla_x \log p(x|t))`
    - :math:`\sigma(X,t)` is the diffusion term: :math:`\sqrt{\beta(t)}`
    - :math:`dW(t)` is the Wiener process increment
    - :math:`\beta(t)` is the noise schedule

    The method advances the solution using the discrete approximation:

    .. math::
        X(t + dt) = X(t) + \mu(X,t)dt + \sigma(X,t)\sqrt{dt} \cdot \mathcal{N}(0,1)

    This is the simplest stochastic integration scheme with strong order 0.5
    convergence for general SDEs.

    Args:
        model: Diffusion model defining the diffusion process
        timer: Timer object managing the discretization of the time interval
    """

    model: DiffusionModel

    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
        r"""Perform one Euler-Maruyama integration step.

        Args:
            integrator_state: Current state containing (position, rng_key, step)
            predictor: Predictor providing score function
                :math:`\nabla_x \log p(x|t)`

        Returns:
            Updated IntegratorState holding the next position, a fresh RNG key
            derived from the incoming one, and the step counter incremented
            by one.

        Notes:
            The integration step implements:

            .. math::
                dx = \text{drift} \cdot dt + \text{diffusion} \cdot \sqrt{dt} \cdot \varepsilon

            where:

            - :math:`\text{drift} = \beta(t) \cdot (0.5 \cdot x + \nabla_x \log p(x|t))`
            - :math:`\text{diffusion} = \sqrt{\beta(t)}`
            - :math:`\varepsilon \sim \mathcal{N}(0,1)`

            The incoming RNG key is split before any sampling: one subkey
            draws the noise, the other is carried to the next step, so no key
            is consumed twice.
        """
        position, rng_key, step = integrator_state

        # Step size between the current and the next time point. It is fed to
        # sqrt() below, so the timer is presumably decreasing in `step`
        # (t > t_next) -- TODO confirm against the Timer implementation.
        t, t_next = self.timer(step), self.timer(step + 1)
        dt = t - t_next

        # Only the diffusion coefficient g(t) is used; the forward drift
        # coefficient f(t) is folded into the expression below.
        _, g_t = self.model.sde_coefficients(t)

        # For reverse-time: drift = f(t)*x - g(t)^2*score, but rearranged as: g(t)^2 * (0.5*x + score)
        # Since f(t) = -0.5*beta(t) and g(t) = sqrt(beta(t)), we have beta(t) = g(t)^2
        # NOTE(review): this hard-codes f(t) = -0.5 * g(t)^2; confirm it holds
        # for every DiffusionModel used with this integrator.
        drift = g_t * g_t * (0.5 * position + predictor.score(position, t))
        diffusion = g_t

        # Split BEFORE sampling. Drawing the noise from `rng_key` and then
        # splitting that same key would reuse it, which JAX's PRNG design
        # forbids because the noise and the next key could be correlated.
        noise_key, rng_key_next = jax.random.split(rng_key)
        noise = jax.random.normal(noise_key, position.shape) * jnp.sqrt(dt)

        dx = drift * dt + diffusion * noise
        return IntegratorState(position + dx, rng_key_next, step + 1)