Source code for diffuse.timer.base

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass

import jax
import jax.numpy as jnp


[docs] @dataclass class Timer: """Base Timer class for scheduling time steps in diffusion processes. Args: n_steps: Number of discrete time steps """ n_steps: int def __call__(self, step: int) -> float: ...
[docs] @dataclass class VpTimer(Timer): """Variance Preserving Timer that implements linear interpolation between final and initial time. Args: n_steps: Number of discrete time steps eps: Initial time value tf: Final time value """ eps: float tf: float def __call__(self, step: int) -> float: """Compute time value for given step. Args: step (int): Current step number Returns: float: Interpolated time value between tf and eps """ return self.tf + step / self.n_steps * (self.eps - self.tf)
[docs] @dataclass class HeunTimer(Timer): """Heun Timer implementing power-law scaling of noise levels. This timer discretizes noise space rather than time space, using a power-law relationship to schedule noise levels. It is designed to be used with sampling methods that are defined on noise levels (like EDM - Elucidating the Design Space of Diffusion-Based Generative Models) rather than time-based approaches. Args: n_steps: Number of discrete time steps rho: Power scaling factor (default: 7.0) sigma_min: Minimum noise level (default: 0.002) sigma_max: Maximum noise level (default: 0.002) """ rho: float = 7.0 sigma_min: float = 0.002 sigma_max: float = 80.0 def __call__(self, step: int) -> float: """Compute noise level for given step using power-law scaling. Args: step (int): Current step number Returns: float: Noise level at current step """ sigma_max_rho = self.sigma_max ** (1 / self.rho) sigma_min_rho = self.sigma_min ** (1 / self.rho) return (sigma_max_rho + step / self.n_steps * (sigma_min_rho - sigma_max_rho)) ** self.rho
[docs] @dataclass class DDIMTimer(Timer): """Denoising Diffusion Implicit Models (DDIM) Timer. Implements custom time scheduling for DDIM as described in https://arxiv.org/pdf/2206.00364. Uses a power-law interpolation between c_1 and c_2 with exponent j0. Args: n_steps (int): Number of discrete time steps n_time_training (int): Number of training timesteps c_1 (float, optional): Lower bound parameter. Defaults to 0.001 c_2 (float, optional): Upper bound parameter. Defaults to 0.008 j0 (int, optional): Power-law exponent. Defaults to 8 """ n_time_training: int c_1: float = 0.001 c_2: float = 0.008 j0: int = 8 def __post_init__(self): def body_fun(u, i): alpha = self._alpha(i) alpha_next = self._alpha(i - 1) maxi = jnp.maximum(alpha_next / alpha, self.c_1) u_next = jnp.sqrt((u**2 + 1) / maxi - 1) return u_next, u_next indices = jnp.arange(self.n_time_training, 0, -1) _, self.u_list = jax.lax.scan(body_fun, 0.0, indices) def __call__(self, step: int) -> float: """Compute time value for given step using DDIM scheduling. Args: step (int): Current step number Returns: float: Time value at current step """ j = jnp.floor(self.j0 + (self.n_time_training - 1 - self.j0) * step / (self.n_steps - 1) + 0.5).astype(int).item() return self.u_list[j] def _alpha(self, j: int) -> float: return jnp.sin(0.5 * jnp.pi * j / (self.n_time_training * (self.c_2 + 1))) ** 2
if __name__ == "__main__": timer = DDIMTimer(n_steps=100, n_time_training=1000, c_1=0.001, c_2=0.008, j0=8) print(timer(0)) print(timer(50)) print(timer(100))