# Source code for diffuse.denoisers.base: base state and abstract base class for all denoisers.

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from abc import ABC, abstractmethod
from typing import Callable, NamedTuple
from jaxtyping import Array, PRNGKeyArray
from diffuse.integrator.base import IntegratorState


class DenoiserState(NamedTuple):
    """Common state record shared by every denoiser.

    Returned by ``BaseDenoiser.init`` and threaded through
    ``BaseDenoiser.step``; it currently holds only the integrator's state.
    """

    # State object of the underlying integrator
    # (type imported from ``diffuse.integrator.base``).
    integrator_state: IntegratorState


class BaseDenoiser(ABC):
    """Abstract interface that concrete denoisers implement.

    Subclasses must provide all three operations below:

    * ``init`` builds the starting :class:`DenoiserState`;
    * ``step`` advances a ``DenoiserState`` by one step using a score function;
    * ``generate`` is the top-level sampling entry point.
    """

    @abstractmethod
    def init(
        self,
        position: Array,
        rng_key: PRNGKeyArray,
        dt: float,
    ) -> DenoiserState:
        """Create the initial denoiser state.

        Args:
            position: Starting array for the denoiser.
            rng_key: JAX PRNG key.
            dt: Float time increment (presumably the step size; confirm in
                concrete subclasses).

        Returns:
            A freshly initialised ``DenoiserState``.
        """
        ...

    @abstractmethod
    def step(
        self,
        state: DenoiserState,
        score: Callable[[Array, float], Array],
    ) -> DenoiserState:
        """Advance ``state`` by a single denoising step.

        Args:
            state: Current denoiser state.
            score: Callable taking an array and a float and returning an
                array (presumably the score evaluated at a position and time).

        Returns:
            The updated ``DenoiserState``.
        """
        ...

    @abstractmethod
    def generate(
        self,
        rng_key: PRNGKeyArray,
        measurement_state,
        n_steps: int,
        n_particles: int,
    ):
        """Produce samples.

        Args:
            rng_key: JAX PRNG key.
            measurement_state: Measurement information; its type is left to
                concrete subclasses.
            n_steps: Integer step count (presumably the number of denoising
                steps).
            n_particles: Integer particle count (presumably the number of
                samples drawn in parallel).
        """
        ...