Denoisers
Unconditional Denoisers
Base Denoiser
- class diffuse.denoisers.denoiser.Denoiser(integrator: Integrator, model: DiffusionModel, predictor: Predictor, x0_shape: Tuple[int, ...])
Bases: BaseDenoiser
Denoiser for generating samples using reverse diffusion.
- Parameters:
integrator – The integrator to use for solving the reverse SDE
model – The diffusion model (SDE) defining the forward process
predictor – The predictor for computing the score/denoised estimate
x0_shape – Shape of the data samples (excluding batch dimension)
- generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], n_steps: int, n_particles: int, keep_history: bool = False, data_sharding: Any | None = None) → Tuple[DenoiserState, Array | None]
Generate denoised samples \(x_0\).
- Parameters:
rng_key – Random key for initialization
n_steps – Number of denoising steps to perform
n_particles – Number of samples to generate (batch size)
keep_history – If True, return the full trajectory of samples
data_sharding – Optional JAX sharding specification for distributed computation
- Returns:
Tuple of (final_state, history), where history is None if keep_history=False
- init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2']) → DenoiserState
Initialize denoiser state
- integrator: Integrator
- model: DiffusionModel
- step(state: DenoiserState) → DenoiserState
Perform one denoising step.
Sample \(x_{t-1} \sim p(x_{t-1} | x_t)\)
- Parameters:
state – Current denoiser state
- Returns:
Updated denoiser state at the previous time step
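The `generate`/`step` contract above can be illustrated with a self-contained 1-D sketch. `generate` initializes particles, applies `step` `n_steps` times, and optionally records the trajectory. The sketch does not use `diffuse`. The following are all assumptions chosen so the score is exact:

- a toy variance-preserving SDE with constant rate `BETA`;
- Gaussian data \(x_0 \sim \mathcal{N}(\mu, s^2)\);
- an Euler-Maruyama step for the reverse SDE.

The names `score`, `step`, and `generate` here are hypothetical helpers. Whether the library's history includes the initial positions is not documented; this sketch includes them.

```python
import math
import random
from typing import List, Optional, Tuple

BETA = 10.0        # constant noise rate of the toy VP SDE (assumption)
MU, S = 2.0, 0.5   # toy data distribution: x0 ~ N(MU, S^2) (assumption)


def score(x: float, t: float) -> float:
    """Exact score of p_t(x) for Gaussian data under the constant-rate VP SDE."""
    alpha = math.exp(-0.5 * BETA * t)
    var = alpha**2 * S**2 + (1.0 - alpha**2)
    return -(x - alpha * MU) / var


def step(x: float, t: float, h: float, rng: random.Random) -> float:
    """One Euler-Maruyama step of the reverse SDE, moving from time t to t - h."""
    drift = -0.5 * BETA * x   # forward drift f(x, t)
    g2 = BETA                 # squared diffusion coefficient g(t)^2
    z = rng.gauss(0.0, 1.0)
    return x - (drift - g2 * score(x, t)) * h + math.sqrt(g2 * h) * z


def generate(seed: int, n_steps: int, n_particles: int,
             keep_history: bool = False) -> Tuple[List[float], Optional[List[List[float]]]]:
    """Run the reverse chain from t = 1 to t = 0; history is None unless requested."""
    rng = random.Random(seed)
    h = 1.0 / n_steps
    xs = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]  # x_T ~ N(0, 1) at T = 1
    history = [xs] if keep_history else None
    for k in range(n_steps):
        t = 1.0 - k * h
        xs = [step(x, t, h, rng) for x in xs]
        if history is not None:
            history.append(xs)
    return xs, history
```

With enough steps and particles, the sample mean and standard deviation of the returned particles approach `MU` and `S`.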
Plug-and-Play Denoiser
- class diffuse.denoisers.pnp_denoiser.PnPDenoiser(integrator: Integrator, sde: SDE, score: Callable[[Array, float], Array], forward_model: ForwardModel, _resample: bool = False)
Bases: object
Conditional denoiser implementation
- batch_step(rng_key: Key[Array, ''] | UInt32[Array, '2'], state: PnPDenoiserState, score: Callable[[Array, float], Array], measurement_state: MeasurementState) → PnPDenoiserState
Batch update step
- forward_model: ForwardModel
- init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], dt: float) → PnPDenoiserState
Initialize denoiser state
- integrator: Integrator
- pooled_posterior_logpdf(rng_key: Key[Array, ''] | UInt32[Array, '2'], t: float, y_cntrst: Array, y_past: Array, design: Array, mask_history: Array)
Compute pooled posterior log probability density
- posterior_logpdf(rng_key: Key[Array, ''] | UInt32[Array, '2'], t: float, y_meas: Array, design_mask: Array)
Compute posterior log probability density
- step(state: PnPDenoiserState, score: Callable[[Array, float], Array]) → PnPDenoiserState
Single step update
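This page does not state the measurement model behind `posterior_logpdf`. As a hedged illustration only, the sketch below assumes a common plug-and-play setup:

- an identity forward operator restricted by a binary mask;
- Gaussian measurement noise with standard deviation `sigma_y`;
- a caller-supplied log-prior.

It combines them with Bayes' rule up to an additive constant. All function names are hypothetical.

```python
import math
from typing import Callable, Sequence


def masked_gaussian_loglik(x: Sequence[float], y: Sequence[float],
                           mask: Sequence[int], sigma_y: float) -> float:
    """log p(y | x) for y_i = x_i + N(0, sigma_y^2), counting only entries with mask_i == 1."""
    const = math.log(sigma_y * math.sqrt(2.0 * math.pi))
    ll = 0.0
    for xi, yi, mi in zip(x, y, mask):
        if mi:  # unobserved entries contribute nothing
            ll += -0.5 * ((yi - xi) / sigma_y) ** 2 - const
    return ll


def log_posterior_unnormalized(x: Sequence[float], y: Sequence[float],
                               mask: Sequence[int], sigma_y: float,
                               log_prior: Callable[[Sequence[float]], float]) -> float:
    """Bayes' rule up to a constant: log p(x | y) = log p(y | x) + log p(x) + C."""
    return masked_gaussian_loglik(x, y, mask, sigma_y) + log_prior(x)


def std_normal_logpdf(x: Sequence[float]) -> float:
    """Example prior: independent standard normal entries."""
    return sum(-0.5 * xi * xi - 0.5 * math.log(2.0 * math.pi) for xi in x)
```

With this design, a mask of all zeros makes the likelihood term vanish, so the posterior reduces to the prior.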
Base Classes
- class diffuse.denoisers.base.BaseDenoiser
Bases: ABC
- abstractmethod generate(rng_key: Key[Array, ''] | UInt32[Array, '2'], measurement_state, n_steps: int, n_particles: int)
Generate samples
- abstractmethod init(position: Array, rng_key: Key[Array, ''] | UInt32[Array, '2'], dt: float) → DenoiserState
Initialize denoiser state
- abstractmethod step(state: DenoiserState, score: Callable[[Array, float], Array]) → DenoiserState
Perform single denoising step
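`BaseDenoiser` is an abstract base class, so a concrete denoiser must override `init`, `step`, and `generate` before it can be instantiated. The sketch below does not import `diffuse`; it uses a stand-in ABC with the same three method names. `ToyState`, `ToyBaseDenoiser`, and `GradientFlowDenoiser` are hypothetical. The deterministic update \(x \leftarrow x + \Delta t \, \nabla \log p_t(x)\) only demonstrates the subclassing pattern; it is not a real sampler.

```python
import random
from abc import ABC, abstractmethod
from typing import Callable, List, NamedTuple


class ToyState(NamedTuple):
    position: List[float]
    t: float
    dt: float


class ToyBaseDenoiser(ABC):
    """Stand-in mirroring the three documented abstract methods (not diffuse's class)."""

    @abstractmethod
    def init(self, position, rng_key, dt): ...

    @abstractmethod
    def step(self, state, score): ...

    @abstractmethod
    def generate(self, rng_key, measurement_state, n_steps, n_particles): ...


class GradientFlowDenoiser(ToyBaseDenoiser):
    """Toy concrete subclass: deterministic Euler steps x <- x + dt * score(x, t)."""

    def __init__(self, score: Callable[[float, float], float]):
        self.score = score

    def init(self, position, rng_key, dt):
        return ToyState(list(position), 1.0, dt)

    def step(self, state, score):
        xs = [x + state.dt * score(x, state.t) for x in state.position]
        return ToyState(xs, state.t - state.dt, state.dt)

    def generate(self, rng_key, measurement_state, n_steps, n_particles):
        # measurement_state is ignored: this toy subclass is unconditional
        rng = random.Random(rng_key)
        start = [rng.gauss(0.0, 1.0) for _ in range(n_particles)]
        state = self.init(start, rng_key, 1.0 / n_steps)
        for _ in range(n_steps):
            state = self.step(state, self.score)
        return state
```

Python enforces the contract at construction time: a subclass that leaves any abstract method unimplemented raises `TypeError` when instantiated.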
- class diffuse.denoisers.base.DenoiserState(integrator_state: IntegratorState)
Bases: NamedTuple
Base state for all denoisers
- integrator_state: IntegratorState
Alias for field number 0
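Because `DenoiserState` is a `NamedTuple`, `integrator_state` is also reachable as index 0, hence "Alias for field number 0". Updates are made by building a new tuple rather than mutating in place. JAX treats named tuples as pytrees, so states of this shape can be carried through transformations such as `jax.lax.scan`.

The plain-Python sketch below reproduces this behaviour. The fields of the stand-in `IntegratorState` are hypothetical, since this page does not document them.

```python
from typing import NamedTuple, Tuple


class IntegratorState(NamedTuple):
    # Hypothetical fields; the real IntegratorState is not documented on this page
    position: Tuple[float, ...]
    t: float


class DenoiserState(NamedTuple):
    integrator_state: IntegratorState


state = DenoiserState(IntegratorState(position=(0.5, -1.0), t=1.0))

# Index 0 and the named attribute refer to the same object
same_object = state[0] is state.integrator_state

# Functional update: _replace returns a new tuple and leaves `state` untouched
new_state = state._replace(
    integrator_state=state.integrator_state._replace(t=0.9)
)
```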
Utilities
- diffuse.denoisers.utils.resample_particles(position: Array, log_weights: Array, rng_key: Array, ess_low: float = 0.2, ess_high: float = 0.5) → Tuple[Array, Array]
Internal helper that performs the resampling step given the particles' log weights.
- Parameters:
position – Current particle positions
log_weights – Log weights of the particles
rng_key – Random number generator key
ess_low – Lower threshold for the effective sample size (ESS)
ess_high – Upper threshold for the effective sample size (ESS)
- Returns:
Tuple of (resampled_position, normalized_log_weights)
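For \(N\) normalized weights \(w_i\), the effective sample size is \(\mathrm{ESS} = 1 / \sum_i w_i^2\). It equals \(N\) for uniform weights and approaches 1 when a single particle dominates.

The sketch below makes two assumptions that this page does not confirm:

- `ess_low` is a fraction of \(N\) below which resampling is triggered;
- resampling is systematic.

The role of `ess_high` is not documented, so the sketch omits it. All function names are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def normalize_log_weights(log_w: Sequence[float]) -> List[float]:
    """Subtract log-sum-exp (computed stably) so that sum(exp(w)) == 1."""
    m = max(log_w)
    lse = m + math.log(sum(math.exp(w - m) for w in log_w))
    return [w - lse for w in log_w]


def ess_fraction(log_w_norm: Sequence[float]) -> float:
    """ESS / N, where ESS = 1 / sum(w_i^2) for normalized weights; lies in (0, 1]."""
    return 1.0 / sum(math.exp(2.0 * w) for w in log_w_norm) / len(log_w_norm)


def systematic_resample(position: Sequence, log_w: Sequence[float],
                        rng: random.Random, ess_low: float = 0.2) -> Tuple[list, List[float]]:
    """Resample only when ESS / N < ess_low; return (positions, normalized log weights)."""
    log_w = normalize_log_weights(log_w)
    n = len(position)
    if ess_fraction(log_w) >= ess_low:
        return list(position), log_w  # weights still healthy: keep particles as they are
    weights = [math.exp(w) for w in log_w]
    u0 = rng.random() / n  # one uniform offset shared by all N evenly spaced points
    idx, cum, out = 0, weights[0], []
    for i in range(n):
        u = u0 + i / n
        while u > cum and idx < n - 1:
            idx += 1
            cum += weights[idx]
        out.append(position[idx])
    return out, [-math.log(n)] * n  # resampled particles carry uniform weights
```

Systematic resampling draws a single uniform number. It typically adds less variance than independent multinomial draws.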
- diffuse.denoisers.utils.weights_tweedie(state_next, measurement_state: MeasurementState, rng_key: Array, sde, score_fn, forward_model, ess_low: float = 0.2, ess_high: float = 0.5) → Tuple[Array, Array]
Compute particle weights with Tweedie’s formula and resample the particles.
- Parameters:
state_next – The current state of the particles
measurement_state – The measurement state containing observations
rng_key – Random number generator key
sde – The SDE object
score_fn – The score function
forward_model – The forward model
ess_low – Lower threshold for ESS (default: 0.2)
ess_high – Upper threshold for ESS (default: 0.5)
- Returns:
Tuple of (position, log_weights)
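For a Gaussian perturbation \(x_t = \alpha_t x_0 + \sigma_t \epsilon\), Tweedie's formula gives the posterior mean \(\mathbb{E}[x_0 \mid x_t] = (x_t + \sigma_t^2 \nabla_{x_t} \log p_t(x_t)) / \alpha_t\).

One natural way to weight particles is to evaluate the measurement likelihood at each particle's denoised estimate. This is an assumption here, not a description of `weights_tweedie`.

The 1-D sketch below makes these further assumptions:

- a Gaussian likelihood with hypothetical noise level `sigma_y`;
- an optional hypothetical `forward` operator;
- a toy VP schedule with Gaussian data, so the score is exact.

The resulting log weights could then be passed to an ESS-gated resampler.

```python
import math
from typing import Callable, List, Sequence


def tweedie_x0(x_t: float, t: float, score: Callable[[float, float], float],
               alpha: Callable[[float], float], sigma: Callable[[float], float]) -> float:
    """Tweedie: E[x0 | x_t] = (x_t + sigma_t^2 * score(x_t, t)) / alpha_t."""
    a, s = alpha(t), sigma(t)
    return (x_t + s * s * score(x_t, t)) / a


def tweedie_log_weights(xs: Sequence[float], y: float, t: float, score, alpha, sigma,
                        sigma_y: float,
                        forward: Callable[[float], float] = lambda x: x) -> List[float]:
    """Gaussian log-likelihood of y at each particle's denoised estimate.
    The normalizing constant is identical for every particle, so it is dropped."""
    return [-0.5 * ((y - forward(tweedie_x0(x, t, score, alpha, sigma))) / sigma_y) ** 2
            for x in xs]


# Toy setup (assumption): constant-rate VP schedule and data x0 ~ N(MU, S^2)
BETA, MU, S = 10.0, 2.0, 0.5


def alpha(t: float) -> float:
    return math.exp(-0.5 * BETA * t)


def sigma(t: float) -> float:
    return math.sqrt(1.0 - alpha(t) ** 2)


def score(x: float, t: float) -> float:
    """Exact score of the Gaussian marginal p_t."""
    var = alpha(t) ** 2 * S**2 + sigma(t) ** 2
    return -(x - alpha(t) * MU) / var
```

For this toy model, `tweedie_x0` reproduces the closed-form Gaussian posterior mean \((\alpha_t s^2 x_t + \sigma_t^2 \mu) / (\alpha_t^2 s^2 + \sigma_t^2)\).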