# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, NamedTuple, Union
import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray
class SDEState(NamedTuple):
position: Array
t: Array
class Schedule(ABC):
T: float
@abstractmethod
def __call__(self, t: Array) -> Array:
pass
@abstractmethod
def integrate(self, t: Array, s: Array) -> Array:
pass
[docs]
@dataclass
class LinearSchedule:
r"""Linear noise schedule for diffusion processes.
Implements a linear interpolation between minimum and maximum noise levels:
.. math::
\beta(t) = \beta_{\min} + \frac{\beta_{\max} - \beta_{\min}}{T - t_0}(t - t_0)
Args:
b_min: The minimum noise value :math:`\beta_{\min}`
b_max: The maximum noise value :math:`\beta_{\max}`
t0: The starting time :math:`t_0`
T: The ending time :math:`T`
"""
b_min: float
b_max: float
t0: float
T: float
def __call__(self, t: Array) -> Array:
r"""Evaluate the linear schedule at time t.
Args:
t: Time at which to evaluate the schedule
Returns:
Schedule value :math:`\beta(t)`
"""
b_min, b_max, t0, T = self.b_min, self.b_max, self.t0, self.T
return (b_max - b_min) / (T - t0) * t + (b_min * T - b_max * t0) / (T - t0)
[docs]
def integrate(self, t: Array, s: Array) -> Array:
r"""Compute integral :math:`\int_s^t \beta(\tau) d\tau`.
Args:
t: Upper integration bound
s: Lower integration bound
Returns:
Integral value
"""
b_min, b_max, t0, T = self.b_min, self.b_max, self.t0, self.T
slope = (b_max - b_min) / (T - t0)
intercept = (b_min * T - b_max * t0) / (T - t0)
return 0.5 * (t - s) * (slope * (t + s) + 2 * intercept)
[docs]
@dataclass
class CosineSchedule(Schedule):
r"""Cosine noise schedule for improved denoising.
Implements the cosine schedule from Nichol & Dhariwal (2021) which provides
better signal-to-noise ratio properties than linear schedules. The schedule is
based on:
.. math::
\bar{\alpha}(t) = \frac{\cos\left(\frac{t/T + s}{1+s} \cdot \frac{\pi}{2}\right)^2}{\cos\left(\frac{s}{1+s} \cdot \frac{\pi}{2}\right)^2}
Args:
b_min: The minimum beta value :math:`\beta_{\min}`
b_max: The maximum beta value :math:`\beta_{\max}`
t0: The starting time :math:`t_0`
T: The ending time :math:`T`
s: Offset parameter for numerical stability (default: 0.008)
References:
Nichol, A., & Dhariwal, P. (2021). Improved Denoising Diffusion Probabilistic Models.
arXiv:2102.09672
"""
b_min: float
b_max: float
t0: float
T: float
s: float = 0.008
def __call__(self, t: Array) -> Array:
r"""Evaluate the cosine schedule at time t.
Args:
t: Time at which to evaluate the schedule
Returns:
Schedule value :math:`\beta(t)` clipped to [:math:`\beta_{\min}`, :math:`\beta_{\max}`]
"""
t_normalized = (t - self.t0) / (self.T - self.t0)
beta_t = jnp.pi * jnp.tan(0.5 * jnp.pi * (t_normalized + self.s) / (1 + self.s)) / (self.T * (1 + self.s))
beta_t = jnp.clip(beta_t, self.b_min, self.b_max)
return beta_t
[docs]
def integrate(self, t: Array, s: Array) -> Array:
r"""Compute integral :math:`\int_s^t \beta(\tau) d\tau` using :math:`\bar{\alpha}` values.
Returns :math:`\log(\bar{\alpha}(s) / \bar{\alpha}(t))`
Args:
t: Upper integration bound
s: Lower integration bound
Returns:
Integral value
"""
time_scale = self.T - self.t0
offset_scale = 1 + self.s
t_norm = (t - self.t0) / time_scale
s_norm = (s - self.t0) / time_scale
f0 = jnp.cos(self.s / offset_scale * jnp.pi * 0.5) ** 2
ft = jnp.cos((t_norm + self.s) / offset_scale * jnp.pi * 0.5) ** 2
fs = jnp.cos((s_norm + self.s) / offset_scale * jnp.pi * 0.5) ** 2
alpha_t = jnp.clip(ft / f0, 0.001, 0.9999)
alpha_s = jnp.clip(fs / f0, 0.001, 0.9999)
return jnp.log(alpha_s / alpha_t)
class DiffusionModel(ABC):
@abstractmethod
def noise_level(self, t: Array) -> Array:
pass
@abstractmethod
def signal_level(self, t: Array) -> Array:
pass
@abstractmethod
def sde_coefficients(self, t: Array) -> tuple[Array, Array]:
"""Compute SDE coefficients f(t) and g(t) for dx_t = f(t) x_t dt + g(t) dW_t."""
pass
def snr(self, t: Array) -> Array:
"""
Compute Signal-to-Noise Ratio (SNR) at timestep t.
For general interpolation x_t = α_t x_0 + σ_t ε:
SNR(t) = α_t² / σ_t²
"""
noise_level = self.noise_level(t)
signal_level = self.signal_level(t)
return (signal_level * signal_level) / (noise_level * noise_level + 1e-8)
def score(self, state: SDEState, state_0: SDEState) -> Array:
"""
Closed-form expression for the score function ∇ₓ log p(xₜ | x₀) of the Gaussian transition kernel.
From docs: ∇log p_t(x_t|x_0) = -1/σ_t² (x_t - α_t x_0)
"""
x, t = state.position, state.t
x0, _t0 = state_0.position, state_0.t
sigma_t = self.noise_level(t)
signal_level_t = self.signal_level(t)
return -(x - signal_level_t * x0) / (sigma_t * sigma_t)
def tweedie(self, state: SDEState, score_fn: Callable) -> SDEState:
"""
Tweedie's formula to compute E[x_0 | x_t].
From docs: x̂_0 = 1/α_t (x_t + σ_t² ∇log p_t(x_t))
"""
x, t = state.position, state.t
sigma_t = self.noise_level(t)
signal_level_t = self.signal_level(t)
return SDEState((x + sigma_t * sigma_t * score_fn(x, t)) / signal_level_t, jnp.zeros_like(t))
def path(self, key: PRNGKeyArray, state: SDEState, ts: Array, return_noise: bool = False) -> Union[SDEState, tuple[SDEState, Array]]:
"""
Samples from the general interpolation: x_t = α_t x_0 + σ_t ε
"""
x = state.position
sigma_t = self.noise_level(ts)
signal_level_t = self.signal_level(ts)
noise = jax.random.normal(key, x.shape, dtype=x.dtype)
res = signal_level_t * x + sigma_t * noise
return (SDEState(res, ts), noise) if return_noise else SDEState(res, ts)
[docs]
@dataclass
class SDE(DiffusionModel):
r"""Variance Preserving (VP) SDE for diffusion models.
Implements the forward SDE:
.. math::
dX(t) = -\frac{1}{2}\beta(t) X(t) dt + \sqrt{\beta(t)} dW(t)
where :math:`\beta(t)` is the noise schedule and :math:`dW(t)` is the Wiener process.
This formulation preserves the variance of the data distribution and is the
standard choice for diffusion probabilistic models.
Args:
beta: Noise schedule (LinearSchedule or CosineSchedule)
"""
beta: Schedule
def __post_init__(self):
self.tf = self.beta.T
[docs]
def sde_coefficients(self, t: Array) -> tuple[Array, Array]:
r"""Compute SDE coefficients :math:`f(t)` and :math:`g(t)`.
For the VP-SDE: :math:`dX(t) = -\frac{1}{2}\beta(t) X(t) dt + \sqrt{\beta(t)} dW(t)`
Returns:
Tuple of drift coefficient :math:`f(t) = -\frac{1}{2}\beta(t)` and
diffusion coefficient :math:`g(t) = \sqrt{\beta(t)}`
"""
beta_t = self.beta(t)
f_t = -0.5 * beta_t
g_t = jnp.sqrt(beta_t)
return f_t, g_t
[docs]
def noise_level(self, t: Array) -> Array:
r"""Compute noise level :math:`\sigma(t)` for diffusion process.
The solution to the VP-SDE is:
.. math::
X(t) = \alpha(t) X_0 + \sigma(t) \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)
where :math:`\alpha(t) = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right)` and
:math:`\sigma(t) = \sqrt{1 - \alpha^2(t)}`
Returns:
Noise level :math:`\sigma(t)` clipped for numerical stability
"""
alpha = jnp.exp(-self.beta.integrate(t, jnp.zeros_like(t)))
sigma = jnp.sqrt(1 - alpha)
sigma = jnp.clip(sigma, 0.001, 0.9999)
return sigma
[docs]
def signal_level(self, t: Array) -> Array:
r"""Compute signal level :math:`\alpha(t) = \exp\left(-\frac{1}{2}\int_0^t \beta(s) ds\right)`.
Returns:
Signal level clipped for numerical stability
"""
alpha = jnp.sqrt(jnp.exp(-self.beta.integrate(t, jnp.zeros_like(t))))
alpha = jnp.clip(alpha, 0.001, 0.9999)
return alpha
[docs]
@dataclass
class Flow(DiffusionModel):
r"""Rectified Flow diffusion model with straight-line interpolation paths.
Implements the rectified flow formulation from Liu et al. (2022) with linear schedules:
.. math::
\alpha(t) = 1 - t, \quad \sigma(t) = t
This creates straight-line interpolation paths:
.. math::
x_t = (1-t)x_0 + t\varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)
which are more amenable to ODE-based sampling with fewer discretization steps.
Args:
tf: Final time for the diffusion process (default: 1.0)
References:
Liu, X., Gong, C., & Liu, Q. (2022). Flow straight and fast: Learning to
generate and transfer data with rectified flow. arXiv:2209.03003
"""
tf: float = 1.0
[docs]
def noise_level(self, t: Array) -> Array:
r"""Compute noise level :math:`\sigma(t) = t`.
Returns:
Noise level clipped for numerical stability
"""
return jnp.clip(t / self.tf, 0.001, 0.999)
[docs]
def signal_level(self, t: Array) -> Array:
r"""Compute signal level :math:`\alpha(t) = 1 - t`.
Returns:
Signal level clipped for numerical stability
"""
return jnp.clip(1 - t / self.tf, 0.001, 0.999)
[docs]
def sde_coefficients(self, t: Array) -> tuple[Array, Array]:
r"""Compute SDE coefficients for rectified flow.
Returns drift :math:`f(t) = -\frac{1}{1-t}` and diffusion :math:`g(t) = \sqrt{\frac{2t}{1-t}}`
Returns:
Tuple of drift and diffusion coefficients
"""
t_safe = jnp.clip(t / self.tf, 0.001, 0.999)
f_t = -1.0 / (1 - t_safe)
g_t = jnp.sqrt(2 * t_safe / (1 - t_safe))
return f_t, g_t
[docs]
@dataclass
class EDM(DiffusionModel):
r"""Efficient Diffusion Model (EDM) from Karras et al. (2022).
Implements the EDM formulation with constant signal and increasing noise:
.. math::
\alpha(t) = 1, \quad \sigma(t) = t
This creates the simple interpolation:
.. math::
x_t = x_0 + t\varepsilon, \quad \varepsilon \sim \mathcal{N}(0, I)
which simplifies the probability-flow ODE and is particularly effective
with Heun's integration method.
Args:
tf: Final time for the diffusion process (default: 1.0)
References:
Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the
design space of diffusion-based generative models. NeurIPS 35, 26565-26577.
"""
tf: float = 1.0
[docs]
def noise_level(self, t: Array) -> Array:
r"""Compute noise level :math:`\sigma(t) = t`.
Returns:
Noise level clipped for numerical stability
"""
return jnp.clip(t, 0.001, 0.999)
[docs]
def signal_level(self, t: Array) -> Array:
r"""Compute signal level :math:`\alpha(t) = 1`.
Returns:
Constant signal level of 1
"""
return jnp.ones_like(t)
[docs]
def sde_coefficients(self, t: Array) -> tuple[Array, Array]:
r"""Compute SDE coefficients for EDM.
Returns drift :math:`f(t) = 0` and diffusion :math:`g(t) = 1`
Returns:
Tuple of zero drift and unit diffusion coefficients
"""
f_t = jnp.zeros_like(t)
g_t = jnp.ones_like(t)
return f_t, g_t
def check_snr(model: DiffusionModel, t: Array, tolerance: float = 1e-3) -> Array:
"""
Check if SNR at timestep t is effectively zero.
Args:
model: DiffusionModel instance
t: Timestep to check
tolerance: Tolerance for considering SNR as zero
Returns:
True if SNR is effectively zero
"""
return jnp.all(model.snr(t) < tolerance)