Source code for diffuse.base_forward_model
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from typing import Protocol, NamedTuple
from jaxtyping import Array
[docs]
class MeasurementState(NamedTuple):
    """Immutable record of the measurement data used in conditional sampling.

    Attributes:
        y: Observed (measured) data.
        mask_history: Record of the measurement masks, relevant when the
            observations are only partial.
    """

    y: Array
    mask_history: Array
[docs]
class ForwardModel(Protocol):
    """Structural interface that forward models for inverse problems satisfy.

    A forward model exposes a measurement operator together with its adjoint,
    as required by conditional generation tasks such as inpainting,
    super-resolution or denoising.

    Attributes:
        std: Standard deviation of the measurement noise.
    """

    std: float

    def apply(self, img: Array, measurement_state: MeasurementState) -> Array:
        """Run the forward measurement operator on ``img``.

        Args:
            img: Image (or other data) to be measured.
            measurement_state: Measurement state in effect for this call.

        Returns:
            The measured / degraded counterpart of ``img``.
        """
        ...

    def adjoint(self, meas: Array, measurement_state: MeasurementState) -> Array:
        """Run the adjoint of the measurement operator on ``meas``.

        Args:
            meas: Array living in the measurement space.
            measurement_state: Measurement state in effect for this call.

        Returns:
            Array in the original image/data space, i.e. Aᵀ meas.
        """
        ...

    def restore(self, img: Array, measurement_state: MeasurementState) -> Array:
        """Run the restoration operator tied to the measurement.

        Args:
            img: Data the restoration operator is applied to.
            measurement_state: Measurement state in effect for this call.

        Returns:
            The restored output, expressed in the original data space.
        """
        # NOTE(review): the previous wording described ``img`` as the input of
        # the adjoint; confirm with implementers which space ``img`` lives in.
        ...