# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass
from functools import partial
from jaxtyping import Array, PRNGKeyArray
from typing import NamedTuple, Tuple
import jax
import jax.numpy as jnp
from diffuse.diffusion.sde import DiffusionModel
from diffuse.timer.base import Timer
from diffuse.predictor import Predictor
__all__ = ["Integrator", "IntegratorState", "ChurnedIntegrator"]
[docs]
class IntegratorState(NamedTuple):
"""State container for numerical integrators.
Attributes:
position: Current state vector/tensor
rng_key: JAX random number generator key
step: Current integration step counter (default: 0)
"""
position: Array
rng_key: PRNGKeyArray
step: int = 0
[docs]
@dataclass
class Integrator:
"""Base class for numerical integrators of diffusion processes.
Provides the basic interface for implementing various numerical integration
schemes for both deterministic and stochastic differential equations.
Args:
model: Diffusion model defining the diffusion process
timer: Timer object managing the discretization of the time interval
"""
model: DiffusionModel
timer: Timer
[docs]
def init(self, position: Array, rng_key: PRNGKeyArray) -> IntegratorState:
"""Initialize the integrator state.
Args:
position: Initial state vector/tensor
rng_key: JAX random number generator key
Returns:
Initial IntegratorState
"""
return IntegratorState(position, rng_key)
def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
"""Perform one integration step.
Args:
integrator_state: Current state of the integration
predictor: Neural network predictor providing score/noise/velocity/x0 predictions
Returns:
Updated IntegratorState
"""
...
[docs]
@dataclass
class ChurnedIntegrator(Integrator):
"""Integrator with stochastic churning for improved sampling.
Implements the stochastic churning mechanism that can help improve sampling
quality by occasionally injecting controlled noise into the process.
Args:
model: Diffusion model defining the diffusion process
timer: Timer object managing the discretization of the time interval
stochastic_churn_rate: Rate of applying stochastic churning (default: 0.0)
churn_min: Minimum time threshold for churning (default: 0.0)
churn_max: Maximum time threshold for churning (default: 0.0)
noise_inflation_factor: Factor to scale injected noise (default: 1.0)
"""
stochastic_churn_rate: float = 0.0
churn_min: float = 0.0
churn_max: float = 0.0
noise_inflation_factor: float = 1.0
def _churn_fn(self, integrator_state: IntegratorState) -> Tuple[Array, float]:
"""Apply stochastic churning to the current state.
Args:
integrator_state: Current integration state
Returns:
Tuple of (churned_position, churned_time)
"""
position, _, step = integrator_state
t = self.timer(step)
_apply_stochastic_churn = partial(
apply_stochastic_churn,
stochastic_churn_rate=self.stochastic_churn_rate,
churn_min=self.churn_min,
churn_max=self.churn_max,
noise_inflation_factor=self.noise_inflation_factor,
model=self.model,
timer=self.timer,
)
position_churned, t_churned = jax.lax.cond(
self.stochastic_churn_rate > 0,
_apply_stochastic_churn,
lambda _: (position, t),
integrator_state,
)
return position_churned, t_churned
def next_churn_noise_level(
t: float,
stochastic_churn_rate: float,
churn_min: float,
churn_max: float,
timer: Timer,
) -> float:
"""Compute the next noise level for stochastic churning.
Determines the appropriate noise level based on the current time and churning
parameters, ensuring the noise stays within specified bounds.
Args:
t: Current time
stochastic_churn_rate: Rate of stochastic churning
churn_min: Minimum time threshold for churning
churn_max: Maximum time threshold for churning
timer: Timer object managing time discretization
Returns:
Next noise level for churning
"""
churn_rate = jnp.where(
stochastic_churn_rate / timer.n_steps - jnp.sqrt(2) + 1 > 0,
jnp.sqrt(2) - 1,
stochastic_churn_rate / timer.n_steps,
)
churn_rate = jnp.where(t > churn_min, jnp.where(t < churn_max, churn_rate, 0), 0)
t_churned = t * (1 + churn_rate)
# Ensure churned time doesn't exceed timer bounds
return jnp.minimum(t_churned, timer.tf)
def apply_stochastic_churn(
integrator_state: IntegratorState,
stochastic_churn_rate: float,
churn_min: float,
churn_max: float,
noise_inflation_factor: float,
model: DiffusionModel,
timer: Timer,
) -> Tuple[Array, float]:
"""Apply stochastic churning to the current sample.
Implements the stochastic churning mechanism by:
1. Computing the next noise level
2. Adjusting the position using the noise schedule
3. Injecting scaled random noise
Args:
integrator_state: Current integration state
stochastic_churn_rate: Rate of stochastic churning
churn_min: Minimum time threshold for churning
churn_max: Maximum time threshold for churning
noise_inflation_factor: Factor to scale injected noise
model: DiffusionModel object defining the diffusion process
timer: Timer object managing time discretization
Returns:
Tuple of (churned_position, churned_time)
Notes:
The churning process follows:
x_churned = sqrt(α_churned/α) * x + sqrt(1 - α_churned/α) * ε * noise_factor
where:
- α values are computed from the noise schedule
- ε is standard normal noise
"""
position, rng_key, step = integrator_state
t = timer(step)
t_churned = next_churn_noise_level(t, stochastic_churn_rate, churn_min, churn_max, timer)
alpha = model.signal_level(t)
sigma = model.noise_level(t)
sigma_churned = model.noise_level(t_churned)
alpha_churned = model.signal_level(t_churned)
delta_sigma = jnp.sqrt(jnp.maximum(0.0, sigma_churned**2 - (alpha_churned * sigma / alpha) ** 2))
new_position = (alpha_churned / alpha) * position + jax.random.normal(rng_key, position.shape) * delta_sigma * noise_inflation_factor
return new_position, t_churned