# Source code for diffuse.base_forward_model

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from typing import Protocol, NamedTuple
from jaxtyping import Array


class MeasurementState(NamedTuple):
    """Immutable bundle of measurement data used during conditional sampling.

    Instances behave like ordinary tuples ``(y, mask_history)`` and can be
    rebuilt with ``_replace``.

    Attributes:
        y: The measured/observed data
        mask_history: History of measurement masks (for partial observations)
    """

    # Observed data produced by the measurement process.
    y: Array
    # Masks recorded so far; shape/convention is defined by the concrete
    # forward model (not fixed here) — TODO confirm against implementations.
    mask_history: Array
class ForwardModel(Protocol):
    """Protocol defining the interface for forward models in inverse problems.

    Forward models implement measurement operators and their adjoint
    operators for conditional generation tasks (e.g., inpainting,
    super-resolution, denoising).

    This is a structural (duck-typed) interface: for static type checkers,
    any object exposing a ``std`` attribute and ``apply`` / ``adjoint`` /
    ``restore`` methods with these signatures satisfies it without
    inheriting from this class. The protocol is not decorated with
    ``@runtime_checkable``, so ``isinstance`` checks against it are not
    supported at runtime.

    Attributes:
        std: Standard deviation of measurement noise
    """

    std: float

    def apply(self, img: Array, measurement_state: MeasurementState) -> Array:
        """Apply the forward measurement operator.

        Args:
            img: Input image/data
            measurement_state: Current measurement state

        Returns:
            Measured/degraded output
        """
        ...

    def adjoint(
        self, meas: Array, measurement_state: MeasurementState
    ) -> Array:
        """Apply the adjoint of the measurement operator.

        Args:
            meas: Array in the measurement space
            measurement_state: Current measurement state

        Returns:
            Array in the original image/data space corresponding to
            Aᵀ meas, where A is the forward operator implemented by
            ``apply``.
        """
        ...

    def restore(
        self, img: Array, measurement_state: MeasurementState
    ) -> Array:
        """Apply the restoration operator associated with the measurement.

        Args:
            img: Array to which the restoration operator is applied.
                NOTE(review): this interface does not specify whether
                ``img`` lives in the measurement space or the data space;
                confirm against concrete implementations.
            measurement_state: Current measurement state

        Returns:
            Restored output in the original data space
        """
        ...