Source code for diffuse.denoisers.pnp_denoiser

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from dataclasses import dataclass
from typing import Callable, Tuple, NamedTuple

from jaxtyping import Array, PRNGKeyArray

from diffuse.integrator.base import Integrator
from diffuse.diffusion.sde import SDE, SDEState
from diffuse.base_forward_model import ForwardModel, MeasurementState


[docs] class PnPDenoiserState(NamedTuple): position: Array auxiliary: Array log_weights: Array
[docs] @dataclass class PnPDenoiser: """Conditional denoiser implementation""" # Required attributes from base class integrator: Integrator sde: SDE score: Callable[[Array, float], Array] forward_model: ForwardModel _resample: bool = False
[docs] def init(self, position: Array, rng_key: PRNGKeyArray, dt: float) -> PnPDenoiserState: """Initialize denoiser state""" pass
[docs] def step(self, state: PnPDenoiserState, score: Callable[[Array, float], Array]) -> PnPDenoiserState: """Single step update""" pass
[docs] def batch_step( self, rng_key: PRNGKeyArray, state: PnPDenoiserState, score: Callable[[Array, float], Array], measurement_state: MeasurementState, ) -> PnPDenoiserState: """Batch update step""" pass
[docs] def posterior_logpdf( self, rng_key: PRNGKeyArray, t: float, y_meas: Array, design_mask: Array, ): """Compute posterior log probability density""" pass
[docs] def pooled_posterior_logpdf( self, rng_key: PRNGKeyArray, t: float, y_cntrst: Array, y_past: Array, design: Array, mask_history: Array, ): """Compute pooled posterior log probability density""" pass
[docs] def y_noiser(self, mask: Array, key: PRNGKeyArray, state: SDEState, ts: float) -> SDEState: """Add noise to measurements""" pass
def _resampling(self, position: Array, log_weights: Array, rng_key: PRNGKeyArray) -> Tuple[Array, Array]: """Resample particles based on weights""" pass