Forward Models

Predictor

class diffuse.predictor.Predictor(model: DiffusionModel, network: Callable, prediction_type: str)

Bases: object

Adapter around a trained network that exposes all four prediction types (score, noise, velocity, x0), whichever one the network itself outputs.

It automatically converts between the following diffusion-model parameterizations:

  • Score: Predicts the score \(\nabla_x \log p_t(x)\)

  • Noise: Predicts the noise \(\varepsilon\) added during the forward process

  • Velocity: Predicts the velocity field \(u_t(x)\) of the probability-flow ODE

  • x0: Predicts the denoised data \(\hat{x}_0\)
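Under the common assumption that the forward process is Gaussian, \(x_t = \alpha_t x_0 + \sigma_t \varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\) (the actual schedule comes from the `DiffusionModel` and is not specified on this page), the four parameterizations are linked by the standard identities below, with \(\hat{\varepsilon}(x,t) = \mathbb{E}[\varepsilon \mid x_t = x]\). This is a sketch of the relationships, not necessarily the exact formulas `Predictor` implements.

```latex
\begin{aligned}
\nabla_x \log p_t(x) &= -\frac{\hat{\varepsilon}(x,t)}{\sigma_t}, \\
\hat{x}_0(x,t) &= \frac{x - \sigma_t\,\hat{\varepsilon}(x,t)}{\alpha_t}, \\
u_t(x) &= \dot{\alpha}_t\,\hat{x}_0(x,t) + \dot{\sigma}_t\,\hat{\varepsilon}(x,t)
        = \frac{\dot{\alpha}_t}{\alpha_t}\,x
          + \Big(\dot{\sigma}_t - \frac{\dot{\alpha}_t\,\sigma_t}{\alpha_t}\Big)\,\hat{\varepsilon}(x,t).
\end{aligned}
```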

Parameters:
  • model – Diffusion model defining the diffusion process

  • network – The trained neural network (e.g., UNet)

  • prediction_type – Type of prediction the network outputs (“score”, “noise”, “velocity”, or “x0”)
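The conversion logic can be sketched in plain Python, assuming a Gaussian forward process \(x_t = \alpha_t x_0 + \sigma_t \varepsilon\). `GaussianSchedule` and `PredictorSketch` are hypothetical names for illustration only; they are not part of `diffuse`, and the real `Predictor` obtains its schedule from the `DiffusionModel` rather than from explicit callables.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class GaussianSchedule:
    """Hypothetical schedule for x_t = alpha(t) * x0 + sigma(t) * eps."""
    alpha: Callable
    sigma: Callable
    dalpha: Callable  # d alpha / dt
    dsigma: Callable  # d sigma / dt


@dataclass(frozen=True)
class PredictorSketch:
    """Illustrative adapter; the noise prediction is used as the common pivot."""
    schedule: GaussianSchedule
    network: Callable
    prediction_type: str  # "score", "noise", "velocity" or "x0"

    def noise(self, x, t):
        s, out = self.schedule, self.network(x, t)
        a, sig = s.alpha(t), s.sigma(t)
        if self.prediction_type == "noise":
            return out
        if self.prediction_type == "score":      # score = -eps / sigma
            return -sig * out
        if self.prediction_type == "x0":         # x = a * x0 + sigma * eps
            return (x - a * out) / sig
        if self.prediction_type == "velocity":   # u = da * x0 + ds * eps, x0 eliminated
            da, ds = s.dalpha(t), s.dsigma(t)
            return (a * out - da * x) / (a * ds - da * sig)
        raise ValueError(f"unknown prediction_type: {self.prediction_type!r}")

    def score(self, x, t):
        return -self.noise(x, t) / self.schedule.sigma(t)

    def x0(self, x, t):
        s = self.schedule
        return (x - s.sigma(t) * self.noise(x, t)) / s.alpha(t)

    def velocity(self, x, t):
        s = self.schedule
        return s.dalpha(t) * self.x0(x, t) + s.dsigma(t) * self.noise(x, t)


# Usage with a linear schedule alpha_t = 1 - t, sigma_t = t and a toy "network".
linear = GaussianSchedule(alpha=lambda t: 1 - t, sigma=lambda t: t,
                          dalpha=lambda t: -1.0, dsigma=lambda t: 1.0)
p_eps = PredictorSketch(linear, lambda x, t: 0.5 * x, "noise")
# Re-wrap each derived output as if a network had been trained to produce it.
p_x0 = PredictorSketch(linear, p_eps.x0, "x0")
p_vel = PredictorSketch(linear, p_eps.velocity, "velocity")
p_score = PredictorSketch(linear, p_eps.score, "score")
```

Routing every output type through \(\hat{\varepsilon}\) keeps the number of conversion formulas linear in the number of types. Only arithmetic is used, so the same methods accept floats, NumPy arrays or `jax.Array` values; note that the "score" and "x0" branches divide by \(\sigma_t\) or \(\alpha_t\) and are singular where those vanish.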

noise(x: Array, t: Array) → Array

Get the noise prediction \(\varepsilon_\theta(x, t)\).

Parameters:
  • x – Current state

  • t – Current time

Returns:

Noise prediction

score(x: Array, t: Array) → Array

Get the score function \(\nabla_x \log p_t(x)\).

Parameters:
  • x – Current state

  • t – Current time

Returns:

Score prediction
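To see why a noise prediction determines the score, take one-dimensional data \(x_0 \sim \mathcal{N}(0, s^2)\) under the assumed forward process \(x_t = \alpha_t x_0 + \sigma_t \varepsilon\). Then \(p_t = \mathcal{N}(0, \alpha_t^2 s^2 + \sigma_t^2)\), so both the exact score and the optimal noise prediction \(\mathbb{E}[\varepsilon \mid x_t = x]\) have closed forms. The sketch checks numerically that they satisfy Tweedie's identity \(\nabla_x \log p_t(x) = -\mathbb{E}[\varepsilon \mid x_t = x] / \sigma_t\); the linear schedule \(\alpha_t = 1 - t\), \(\sigma_t = t\) is an arbitrary choice for illustration.

```python
def exact_score(x, t, s=1.0):
    """Score of p_t = N(0, alpha_t^2 s^2 + sigma_t^2) for the linear schedule."""
    alpha, sigma = 1.0 - t, t
    var = alpha**2 * s**2 + sigma**2
    return -x / var


def optimal_noise(x, t, s=1.0):
    """E[eps | x_t = x] for jointly Gaussian (eps, x_t), using Cov(eps, x_t) = sigma_t."""
    alpha, sigma = 1.0 - t, t
    var = alpha**2 * s**2 + sigma**2
    return sigma * x / var


def score_from_noise(eps_hat, t):
    # Tweedie's identity; sigma_t = t for the linear schedule
    return -eps_hat / t


for x, t, s in [(0.7, 0.3, 1.0), (-2.0, 0.9, 2.5)]:
    exact = exact_score(x, t, s)
    via_noise = score_from_noise(optimal_noise(x, t, s), t)
    print(f"x={x:+.1f}, t={t:.1f}, s={s:.1f}: exact={exact:.6f}, via noise={via_noise:.6f}")
```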

velocity(x: Array, t: Array) → Array

Get the velocity field \(u_t(x)\).

Parameters:
  • x – Current state

  • t – Current time

Returns:

Velocity prediction
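A velocity prediction is what a probability-flow ODE sampler consumes: starting from noise at \(t = 1\), it integrates \(\mathrm{d}x/\mathrm{d}t = u_t(x)\) backwards to \(t = 0\). The sketch below uses explicit Euler steps. So that it runs without a trained network, the hypothetical `velocity_fn` (a stand-in for `Predictor.velocity`) returns the exact velocity for standard-normal data under the assumed schedule \(\alpha_t = 1 - t\), \(\sigma_t = t\).

```python
def velocity_fn(x, t):
    """Exact PF-ODE velocity for x0 ~ N(0, 1), alpha_t = 1 - t, sigma_t = t.

    u = E[-x0 + eps | x_t = x] = (sigma_t - alpha_t) * x / Var(x_t).
    """
    var = (1.0 - t) ** 2 + t ** 2       # Var(x_t)
    return (2.0 * t - 1.0) * x / var


def euler_sample(x1, velocity, n_steps=1000):
    """Integrate dx/dt = velocity(x, t) from t = 1 down to t = 0 with explicit Euler."""
    h = 1.0 / n_steps
    x = x1
    for i in range(n_steps):
        t = 1.0 - i * h                  # current time, decreasing
        x = x - h * velocity(x, t)       # step from t to t - h
    return x


x1 = 1.5                                 # a noise sample at t = 1
sample = euler_sample(x1, velocity_fn)
```

For this toy case each trajectory scales like \(\sqrt{\operatorname{Var}(x_t)}\), and \(\operatorname{Var}(x_0) = \operatorname{Var}(x_1) = 1\), so `sample` should come back close to `x1` up to Euler discretization error, which makes it a convenient sanity check for a sampler loop.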

x0(x: Array, t: Array) → Array

Get the denoised prediction \(\hat{x}_0(x, t)\).

Parameters:
  • x – Current state

  • t – Current time

Returns:

Denoised data prediction

Forward Model Protocol

class diffuse.base_forward_model.ForwardModel(*args, **kwargs)

Bases: Protocol

Protocol defining the interface for forward models in inverse problems.

Forward models implement a measurement operator, its adjoint, and an associated restoration operator for conditional generation tasks (e.g., inpainting, super-resolution, denoising).

std

Standard deviation of measurement noise

Type:

float

adjoint(meas: Array, measurement_state: MeasurementState) → Array

Apply the adjoint of the measurement operator.

Parameters:
  • meas – Array in the measurement space

  • measurement_state – Current measurement state

Returns:

Array in the original image/data space equal to \(A^\top\,\mathrm{meas}\), where \(A\) is the measurement operator implemented by apply
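Because `adjoint` must be the transpose of `apply`, a standard way to check an implementation is the dot-product test: for random \(x\) and \(y\), \(\langle A x, y\rangle = \langle x, A^\top y\rangle\) up to floating-point error. The NumPy sketch below uses a hypothetical 2x average-pooling operator (a simple super-resolution forward model) in place of a real `ForwardModel`; the same check would apply to `apply` and `adjoint` called with a `MeasurementState`.

```python
import numpy as np


def pool_apply(x):
    """Hypothetical forward operator A: average non-overlapping pairs (length n -> n/2, n even)."""
    return 0.5 * (x[0::2] + x[1::2])


def pool_adjoint(y):
    """Transpose A^T: each coarse value is copied to both fine positions, scaled by 1/2."""
    return 0.5 * np.repeat(y, 2)


def dot_product_test(apply, adjoint, n, seed=0):
    """Return |<A x, y> - <x, A^T y>| / |<A x, y>| for random x and y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n)
    y = rng.standard_normal(apply(x).shape)
    lhs = np.vdot(apply(x), y)
    rhs = np.vdot(x, adjoint(y))
    return abs(lhs - rhs) / abs(lhs)


rel_err = dot_product_test(pool_apply, pool_adjoint, n=64)
# Forgetting the 1/2 factor in the adjoint gives a relative error of about 1.
wrong_err = dot_product_test(pool_apply, lambda y: np.repeat(y, 2), n=64)
```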

apply(img: Array, measurement_state: MeasurementState) → Array

Apply the forward measurement operator.

Parameters:
  • img – Input image/data

  • measurement_state – Current measurement state

Returns:

Measured/degraded output

restore(img: Array, measurement_state: MeasurementState) → Array

Apply the restoration operator associated with the measurement.

Parameters:
  • img – Data to apply the restoration operator to

  • measurement_state – Current measurement state

Returns:

Restored output in the original data space

class diffuse.base_forward_model.MeasurementState(y: Array, mask_history: Array)

Container for measurement information during conditional sampling.

y

The measured/observed data

Type:

jax.Array

mask_history

History of measurement masks (for partial observations)

Type:

jax.Array
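Since `ForwardModel` is a `Protocol`, any class with a `std` attribute and `apply`, `adjoint` and `restore` methods satisfies it structurally, without inheriting from it. Below is a minimal NumPy sketch of an inpainting model. `InpaintingSketch`, its fixed `mask` field, the stand-in `ForwardModel` and `MeasurementState` definitions, the layout of `mask_history`, and treating `restore` as the pseudo-inverse are all assumptions for illustration; the real `diffuse` classes may read the mask from `mask_history` and define `restore` differently.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable

import numpy as np


@dataclass
class MeasurementState:
    """Stand-in mirroring the documented fields (not the diffuse class)."""
    y: np.ndarray
    mask_history: np.ndarray


@runtime_checkable
class ForwardModel(Protocol):
    """Stand-in for the documented protocol."""
    std: float
    def apply(self, img, measurement_state): ...
    def adjoint(self, meas, measurement_state): ...
    def restore(self, img, measurement_state): ...


@dataclass
class InpaintingSketch:
    mask: np.ndarray   # hypothetical: 1 = observed pixel, 0 = missing
    std: float = 0.0   # measurement noise standard deviation

    def apply(self, img, measurement_state):
        # A x: zero out unobserved pixels
        return self.mask * img

    def adjoint(self, meas, measurement_state):
        # A binary mask is a diagonal 0/1 matrix, so A^T = A
        return self.mask * meas

    def restore(self, img, measurement_state):
        # Assumption: pseudo-inverse A^+, which for a 0/1 mask is again the mask
        return self.mask * img


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4))
mask = (rng.uniform(size=(4, 4)) > 0.5).astype(x.dtype)
model = InpaintingSketch(mask=mask, std=0.1)

# y = A x + std * noise, with noise kept only on observed pixels (assumption).
# This sketch never reads measurement_state, so None is passed here.
y = model.apply(x, None) + model.std * mask * rng.standard_normal(x.shape)
state = MeasurementState(y=y, mask_history=mask[None])  # assumed: masks stacked on axis 0
```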