# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
"""Network adapter providing all prediction types (score, noise, velocity, x0) from any trained network.
This module implements general conversions between different diffusion model parameterizations:
1. **Score parameterization**: Predicts the score function ∇log p_t(x)
2. **Noise parameterization**: Predicts the noise ε added during the forward process
3. **Velocity parameterization**: Predicts the velocity field u_t(x) for probability flow ODEs
4. **x0 parameterization**: Predicts the denoised data x̂_0
The conversions use the general SDE formulation dx_t = f(t) x_t dt + g(t) dW_t
where f(t) and g(t) are model-specific coefficients:
- **SDE**: f(t) = -β(t)/2, g(t) = √β(t)
- **Flow**: f(t) = -1/(1-t), g(t) = √(2t/(1-t))
- **EDM**: f(t) = 0, g(t) = 1
Velocity conversion uses the probability flow ODE:
u_t(x) = f(t) x - (g(t)²/2) ∇log p_t(x)
"""
from typing import Callable, Dict
from dataclasses import dataclass
from jaxtyping import Array
from diffuse.diffusion.sde import DiffusionModel
# Conversion functions from score to other types
def score_to_noise(score_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a score predictor so that it returns a noise prediction.

    The returned callable evaluates ``-σ_t · score_fn(x, t)``, where
    ``σ_t = model.noise_level(t)``.

    Args:
        score_fn: Callable ``(x, t) -> score``.
        model: Object providing ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> noise``.
    """

    def noise_fn(x: Array, t: Array) -> Array:
        # The noise level is queried before the network is evaluated.
        return -model.noise_level(t) * score_fn(x, t)

    return noise_fn
def score_to_velocity(score_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a score predictor so that it returns the probability-flow velocity.

    With ``(f_t, g_t) = model.sde_coefficients(t)`` the returned callable
    evaluates

        u_t(x) = f_t · x - (g_t² / 2) · score_fn(x, t)

    This general form replaces the earlier rectified-flow-specific
    implementation.

    Args:
        score_fn: Callable ``(x, t) -> score``.
        model: Object providing ``sde_coefficients(t) -> (f_t, g_t)``.

    Returns:
        Callable ``(x, t) -> velocity``.
    """

    def velocity_fn(x: Array, t: Array) -> Array:
        grad_log_p = score_fn(x, t)
        drift, diffusion = model.sde_coefficients(t)
        half_diffusion_sq = diffusion * diffusion / 2
        return drift * x - half_diffusion_sq * grad_log_p

    return velocity_fn
def score_to_x0(score_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a score predictor so that it returns a denoised-data estimate.

    With ``α_t = model.signal_level(t)`` and ``σ_t = model.noise_level(t)`` the
    returned callable evaluates

        x̂_0 = (x + σ_t² · score_fn(x, t)) / (α_t + 1e-8)

    The ``1e-8`` offset keeps the division finite when ``α_t`` is 0.

    Args:
        score_fn: Callable ``(x, t) -> score``.
        model: Object providing ``signal_level(t)`` and ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> x0``.
    """

    def x0_fn(x: Array, t: Array) -> Array:
        signal = model.signal_level(t)
        spread = model.noise_level(t)
        grad_log_p = score_fn(x, t)
        numerator = x + spread * spread * grad_log_p
        return numerator / (signal + 1e-8)

    return x0_fn
# Conversion functions from noise to other types
def noise_to_score(noise_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a noise predictor so that it returns a score.

    The returned callable evaluates ``-noise_fn(x, t) / (σ_t + 1e-8)`` with
    ``σ_t = model.noise_level(t)``. The ``1e-8`` offset keeps the division
    finite when ``σ_t`` is 0.

    Args:
        noise_fn: Callable ``(x, t) -> noise``.
        model: Object providing ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> score``.
    """

    def score_fn(x: Array, t: Array) -> Array:
        spread = model.noise_level(t)
        eps_pred = noise_fn(x, t)
        denominator = spread + 1e-8
        return -eps_pred / denominator

    return score_fn
def noise_to_velocity(noise_fn: Callable, model: DiffusionModel) -> Callable:
    """Convert a noise predictor into a velocity field via the score.

    The chain noise -> score -> velocity is assembled once, when this factory
    is called, instead of being rebuilt on every evaluation of the returned
    callable. The values produced are unchanged.

    Args:
        noise_fn: Callable ``(x, t) -> noise``.
        model: Diffusion model forwarded to both intermediate converters.

    Returns:
        Callable ``(x, t) -> velocity``.
    """
    # Compose noise -> score -> velocity a single time, up front.
    score_fn = noise_to_score(noise_fn, model)
    return score_to_velocity(score_fn, model)
def noise_to_x0(noise_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a noise predictor so that it returns a denoised-data estimate.

    With ``α_t = model.signal_level(t)`` and ``σ_t = model.noise_level(t)`` the
    returned callable evaluates

        x̂_0 = (x - σ_t · noise_fn(x, t)) / (α_t + 1e-8)

    The ``1e-8`` offset keeps the division finite when ``α_t`` is 0.

    Args:
        noise_fn: Callable ``(x, t) -> noise``.
        model: Object providing ``signal_level(t)`` and ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> x0``.
    """

    def x0_fn(x: Array, t: Array) -> Array:
        signal = model.signal_level(t)
        spread = model.noise_level(t)
        eps_pred = noise_fn(x, t)
        numerator = x - spread * eps_pred
        return numerator / (signal + 1e-8)

    return x0_fn
# Conversion functions from velocity to other types
def velocity_to_score(velocity_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a velocity predictor so that it returns a score.

    Inverts the probability-flow relation ``u_t(x) = f_t x - (g_t²/2) ∇log p_t(x)``.
    With ``(f_t, g_t) = model.sde_coefficients(t)`` the returned callable
    evaluates

        ∇log p_t(x) = 2 · (f_t · x - velocity_fn(x, t)) / (g_t² + 1e-8)

    The ``1e-8`` offset keeps the division finite when ``g_t`` is 0. This
    general form replaces the earlier rectified-flow-specific implementation.

    Args:
        velocity_fn: Callable ``(x, t) -> velocity``.
        model: Object providing ``sde_coefficients(t) -> (f_t, g_t)``.

    Returns:
        Callable ``(x, t) -> score``.
    """

    def score_fn(x: Array, t: Array) -> Array:
        u_pred = velocity_fn(x, t)
        drift, diffusion = model.sde_coefficients(t)
        numerator = 2 * (drift * x - u_pred)
        denominator = diffusion * diffusion + 1e-8
        return numerator / denominator

    return score_fn
def velocity_to_noise(velocity_fn: Callable, model: DiffusionModel) -> Callable:
    """Convert a velocity predictor into a noise predictor via the score.

    The chain velocity -> score -> noise is assembled once, when this factory
    is called, instead of being rebuilt on every evaluation of the returned
    callable. The values produced are unchanged.

    Args:
        velocity_fn: Callable ``(x, t) -> velocity``.
        model: Diffusion model forwarded to both intermediate converters.

    Returns:
        Callable ``(x, t) -> noise``.
    """
    # Compose velocity -> score -> noise a single time, up front.
    score_fn = velocity_to_score(velocity_fn, model)
    return score_to_noise(score_fn, model)
def velocity_to_x0(velocity_fn: Callable, model: DiffusionModel) -> Callable:
    """Convert a velocity predictor into a denoised-data predictor via the score.

    The chain velocity -> score -> x0 is assembled once, when this factory is
    called, instead of being rebuilt on every evaluation of the returned
    callable. The values produced are unchanged.

    Args:
        velocity_fn: Callable ``(x, t) -> velocity``.
        model: Diffusion model forwarded to both intermediate converters.

    Returns:
        Callable ``(x, t) -> x0``.
    """
    # Compose velocity -> score -> x0 a single time, up front.
    score_fn = velocity_to_score(velocity_fn, model)
    return score_to_x0(score_fn, model)
# Conversion functions from x0 to other types
def x0_to_score(x0_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a denoised-data predictor so that it returns a score.

    With ``α_t = model.signal_level(t)`` and ``σ_t = model.noise_level(t)`` the
    returned callable evaluates

        ∇log p_t(x) = (α_t · x0_fn(x, t) - x) / (σ_t² + 1e-8)

    The ``1e-8`` offset keeps the division finite when ``σ_t`` is 0.

    Args:
        x0_fn: Callable ``(x, t) -> x0``.
        model: Object providing ``signal_level(t)`` and ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> score``.
    """

    def score_fn(x: Array, t: Array) -> Array:
        clean_estimate = x0_fn(x, t)
        signal = model.signal_level(t)
        spread = model.noise_level(t)
        residual = signal * clean_estimate - x
        return residual / (spread * spread + 1e-8)

    return score_fn
def x0_to_noise(x0_fn: Callable, model: DiffusionModel) -> Callable:
    """Wrap a denoised-data predictor so that it returns a noise prediction.

    With ``α_t = model.signal_level(t)`` and ``σ_t = model.noise_level(t)`` the
    returned callable evaluates

        ε = (x - α_t · x0_fn(x, t)) / (σ_t + 1e-8)

    The ``1e-8`` offset keeps the division finite when ``σ_t`` is 0.

    Args:
        x0_fn: Callable ``(x, t) -> x0``.
        model: Object providing ``signal_level(t)`` and ``noise_level(t)``.

    Returns:
        Callable ``(x, t) -> noise``.
    """

    def noise_fn(x: Array, t: Array) -> Array:
        clean_estimate = x0_fn(x, t)
        signal = model.signal_level(t)
        spread = model.noise_level(t)
        residual = x - signal * clean_estimate
        return residual / (spread + 1e-8)

    return noise_fn
def x0_to_velocity(x0_fn: Callable, model: DiffusionModel) -> Callable:
    """Convert a denoised-data predictor into a velocity field via the score.

    The chain x0 -> score -> velocity is assembled once, when this factory is
    called, instead of being rebuilt on every evaluation of the returned
    callable. The values produced are unchanged.

    Args:
        x0_fn: Callable ``(x, t) -> x0``.
        model: Diffusion model forwarded to both intermediate converters.

    Returns:
        Callable ``(x, t) -> velocity``.
    """
    # Compose x0 -> score -> velocity a single time, up front.
    score_fn = x0_to_score(x0_fn, model)
    return score_to_velocity(score_fn, model)
# Identity functions
def identity(fn: Callable, model: DiffusionModel) -> Callable:
    """Return ``fn`` itself, without wrapping it.

    ``model`` is never read; it is accepted so that this function has the same
    ``(fn, model)`` signature as the other converters.
    """
    passthrough = fn
    return passthrough
# Registry of all conversion functions, laid out as a table:
# CONVERSIONS[source_type][target_type] is a factory taking (fn, model) and
# returning a callable (x, t) -> prediction of the target type.
# Diagonal entries use `identity`, since no conversion is needed.
CONVERSIONS: Dict[str, Dict[str, Callable]] = {
    "score": {"score": identity, "noise": score_to_noise, "velocity": score_to_velocity, "x0": score_to_x0},
    "noise": {"score": noise_to_score, "noise": identity, "velocity": noise_to_velocity, "x0": noise_to_x0},
    "velocity": {"score": velocity_to_score, "noise": velocity_to_noise, "velocity": identity, "x0": velocity_to_x0},
    "x0": {"score": x0_to_score, "noise": x0_to_noise, "velocity": x0_to_velocity, "x0": identity},
}
@dataclass
class Predictor:
    r"""Network adapter providing all prediction types (score, noise, velocity, x0).

    Automatically converts between different diffusion model parameterizations:

    - **Score**: Predicts :math:`\nabla \log p_t(x)`
    - **Noise**: Predicts :math:`\varepsilon` added during forward process
    - **Velocity**: Predicts velocity field :math:`u_t(x)` for probability flow ODEs
    - **x0**: Predicts denoised data :math:`\hat{x}_0`

    The four converted callables are built once, in ``__post_init__``. They are
    not rebuilt if ``model``, ``network`` or ``prediction_type`` is reassigned
    afterwards.

    Args:
        model: Diffusion model defining the diffusion process
        network: The trained neural network (e.g., UNet)
        prediction_type: Type of prediction the network outputs ("score", "noise", "velocity", or "x0")

    Raises:
        ValueError: If ``prediction_type`` is not a key of ``CONVERSIONS``.
    """

    model: DiffusionModel
    network: Callable
    prediction_type: str

    def __post_init__(self):
        """Validate ``prediction_type`` and cache the four converted callables."""
        if self.prediction_type not in CONVERSIONS:
            available = ", ".join(CONVERSIONS)
            raise ValueError(f"Unknown prediction type '{self.prediction_type}'. Available: {available}")

        # Look up the row for the network's native parameterization once, then
        # build and cache one callable per target type.
        converters = CONVERSIONS[self.prediction_type]
        self._score_fn = converters["score"](self.network, self.model)
        self._noise_fn = converters["noise"](self.network, self.model)
        self._velocity_fn = converters["velocity"](self.network, self.model)
        self._x0_fn = converters["x0"](self.network, self.model)

    def score(self, x: Array, t: Array) -> Array:
        r"""Get score function :math:`\nabla \log p_t(x)`.

        Args:
            x: Current state
            t: Current time

        Returns:
            Score prediction
        """
        return self._score_fn(x, t)

    def noise(self, x: Array, t: Array) -> Array:
        r"""Get noise prediction :math:`\varepsilon_\theta(x,t)`.

        Args:
            x: Current state
            t: Current time

        Returns:
            Noise prediction
        """
        return self._noise_fn(x, t)

    def velocity(self, x: Array, t: Array) -> Array:
        r"""Get velocity field :math:`u_t(x)`.

        Args:
            x: Current state
            t: Current time

        Returns:
            Velocity prediction
        """
        return self._velocity_fn(x, t)

    def x0(self, x: Array, t: Array) -> Array:
        r"""Get denoised prediction :math:`\hat{x}_0(x,t)`.

        Args:
            x: Current state
            t: Current time

        Returns:
            Denoised data prediction
        """
        return self._x0_fn(x, t)