# Source code for diffuse.neural_network.nn.params

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from chex import dataclass

from jax import Array


@dataclass
class CondUNet2DOutput:
    """Container for the result of a CondUNet2D forward pass.

    Attributes
    ----------
    output
        Tensor produced by the U-Net after processing its input.
    """

    output: Array


@dataclass
class SDVaeOutput:
    """Container for the result of the Stable Diffusion VAE.

    Attributes
    ----------
    output
        Tensor reconstructed/decoded by the VAE.
    mean
        Mean of the latent distribution.
    logvar
        Logarithm of the latent distribution's variance.
    """

    # Field order is part of the positional constructor signature.
    output: Array
    mean: Array
    logvar: Array