Neural Networks
Models
Conditional UNet
- class diffuse.neural_network.nn.condUNet.CondUNet2D(*args: Any, **kwargs: Any)
Bases: Module
Conditional U-Net for diffusion models with timestep conditioning.
A U-Net architecture for conditional image generation, featuring hierarchical downsampling/upsampling paths with skip connections, ResNet blocks, attention mechanisms, and sinusoidal timestep embeddings.
- Parameters:
in_channels – Number of input channels.
ch – Base number of channels.
ch_mult – Channel multipliers for each resolution level.
num_res_blocks – Number of ResNet blocks at each resolution.
attention_resolutions – Resolution levels where attention is applied.
activation – Activation function to use throughout the network.
dropout – Whether to enable dropout in ResNet blocks.
num_heads – Number of attention heads for multi-head attention.
param_dtype – Data type for parameters.
dtype – Data type for computation.
rngs – Random number generators for parameter initialization.
Example
>>> rngs = nnx.Rngs(42)
>>> unet = CondUNet2D(ch=128, attention_resolutions=(8, 16), rngs=rngs)
>>> x = jnp.ones((2, 64, 64, 3))
>>> t = jnp.array([100, 200])
>>> output = unet(x, t)
VAE (Variational Autoencoder)
- class diffuse.neural_network.nn.sdVae.SDVae(*args: Any, **kwargs: Any)
Bases: Module
Stable Diffusion Variational Autoencoder (VAE).
A VAE architecture for encoding images into latent representations and decoding them back. Uses an encoder-decoder structure with diagonal Gaussian posterior, commonly used in latent diffusion models for compression.
- Parameters:
in_channels – Number of input image channels
ch – Base number of channels in the network
out_ch – Number of output channels
ch_mult – Channel multipliers for each resolution level
num_res_blocks – Number of ResNet blocks at each resolution
z_channels – Number of latent space channels
scale_factor – Scaling factor applied to latent codes (default: 0.18215 for SD)
shift_factor – Shift applied to latent codes before scaling
activation – Activation function used throughout the network
param_dtype – Data type for parameters
dtype – Data type for computation
rngs – Random number generators for parameter initialization
- decode(z: Array | ndarray | bool | number | bool | int | float | complex) -> Array
Decode latent representation back to image space.
- Parameters:
z – Latent code tensor of shape (batch, z_channels, latent_h, latent_w)
- Returns:
Reconstructed image tensor of shape (batch, out_ch, height, width)
- encode(x: Array | ndarray | bool | number | bool | int | float | complex) -> Array | Tuple[Array, Array]
Encode image into latent representation.
- Parameters:
x – Input image tensor of shape (batch, channels, height, width)
- Returns:
Tuple of (latent_code, mean, logvar) where latent_code is the sampled latent representation and mean/logvar define the diagonal Gaussian posterior
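The diagonal Gaussian posterior and the scale/shift parameters described above can be illustrated with a minimal sketch. It is not the SDVae implementation: the function names `sample_latent` and `unscale_latent`, the default `shift_factor=0.0`, and the sign convention `(z - shift_factor) * scale_factor` are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def sample_latent(key, mean, logvar, scale_factor=0.18215, shift_factor=0.0):
    """Hypothetical helper: reparameterized sample, then shift and scale."""
    std = jnp.exp(0.5 * logvar)                      # logvar -> standard deviation
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    z = mean + std * eps                             # reparameterization trick
    # Assumed convention: shift first, then scale.
    return (z - shift_factor) * scale_factor


def unscale_latent(z, scale_factor=0.18215, shift_factor=0.0):
    """Hypothetical helper: invert the shift/scale before decoding."""
    return z / scale_factor + shift_factor


key = jax.random.PRNGKey(0)
mean = jnp.ones((1, 4, 8, 8))
logvar = jnp.full((1, 4, 8, 8), -100.0)              # near-zero variance
z = sample_latent(key, mean, logvar)
restored = unscale_latent(z)
```

With a very negative `logvar` the sample collapses onto the mean, so `restored` matches `mean` up to floating-point error.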
Network Parameters
- class diffuse.neural_network.nn.params.CondUNet2DOutput(output: Array)
Bases: Mapping
Output of the CondUNet2D model.
- from_tuple()
- items() -> a set-like object providing a view on D's items
- keys() -> a set-like object providing a view on D's keys
- replace(**kwargs)
- to_tuple()
- values() -> an object providing a view on D's values
- class diffuse.neural_network.nn.params.SDVaeOutput(output: Array, mean: Array, logvar: Array)
Bases: Mapping
Output of the Stable Diffusion VAE model.
- from_tuple()
- items() -> a set-like object providing a view on D's items
- keys() -> a set-like object providing a view on D's keys
- replace(**kwargs)
- to_tuple()
- values() -> an object providing a view on D's values
Building Blocks
Attention
Multi-head self-attention block for VAE-GAN and diffusion models.
This module implements spatial self-attention over feature maps, allowing each spatial position to attend to all other positions. Commonly used in the bottleneck of VAE encoders/decoders and UNet architectures for diffusion models.
- class diffuse.neural_network.blocks.attention.AttnBlock(*args: Any, **kwargs: Any)
Bases: Module
Multi-head self-attention block with spatial attention mechanism.
Performs self-attention over spatial dimensions of feature maps, where each pixel can attend to all other pixels. Uses multi-head attention to learn diverse attention patterns. Automatically leverages JAX’s optimized attention implementations including cuDNN flash attention when available.
- Parameters:
in_channels – Number of input channels. Must be divisible by num_heads.
num_heads – Number of attention heads. Default is 8.
param_dtype – Data type for parameters (weights and biases).
dtype – Data type for computation.
rngs – Random number generators for parameter initialization.
- Raises:
AssertionError – If in_channels is not divisible by num_heads.
Example
>>> rngs = nnx.Rngs(42)
>>> attn = AttnBlock(in_channels=256, num_heads=8, rngs=rngs)
>>> x = jnp.ones((2, 16, 16, 256))  # (batch, height, width, channels)
>>> output = attn(x)  # Same shape as input
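To make the "each pixel attends to all other pixels" description concrete, here is a rough sketch of spatial multi-head self-attention. It is not the AttnBlock code: the function name `spatial_self_attention` is hypothetical, the learned query/key/value and output projections are omitted, and plain einsum plus softmax stands in for the optimized attention kernels mentioned above.

```python
import jax
import jax.numpy as jnp


def spatial_self_attention(x, num_heads):
    """Hypothetical sketch: attention over the H*W positions of an NHWC map."""
    b, h, w, c = x.shape
    assert c % num_heads == 0, "channels must be divisible by num_heads"
    head_dim = c // num_heads

    # Flatten the spatial grid into a sequence of H*W tokens per head.
    tokens = x.reshape(b, h * w, num_heads, head_dim)
    q = k = v = tokens  # assumption: learned projections left out for brevity

    # Scaled dot-product scores between every pair of positions.
    logits = jnp.einsum("bqnd,bknd->bnqk", q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)

    # Weighted sum of values, then restore the spatial layout.
    out = jnp.einsum("bnqk,bknd->bqnd", weights, v)
    return out.reshape(b, h, w, c)


x = jnp.ones((2, 8, 8, 64))
out = spatial_self_attention(x, num_heads=8)
```

The score matrix has one row and one column per spatial position, so memory grows quadratically with `height * width`; this is one reason attention is usually restricted to the lower-resolution levels.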
ResNet Block
ResNet block with optional timestep conditioning for diffusion models.
This module implements a ResNet block that can optionally receive timestep embeddings for conditioning in diffusion models. Supports FiLM (Feature-wise Linear Modulation) conditioning and configurable dropout.
- class diffuse.neural_network.blocks.resnet_block.ResnetBlock(*args: Any, **kwargs: Any)
Bases: TimestepBlock
ResNet block with optional timestep conditioning.
A residual block consisting of two convolutional layers with GroupNorm and activation functions. Optionally accepts timestep embeddings for FiLM conditioning in diffusion models. Includes skip connections for residual learning.
- Parameters:
in_channels – Number of input channels.
out_channels – Number of output channels. If None, uses in_channels.
activation – Activation function to use. Default is swish.
embedding_dim – Dimension of timestep embeddings for conditioning. If None, no timestep conditioning is applied.
param_dtype – Data type for parameters.
dtype – Data type for computation.
dropout – Whether to apply dropout.
dropout_rate – Dropout rate when dropout is enabled.
rngs – Random number generators for parameter initialization.
Example
>>> rngs = nnx.Rngs(42)
>>> block = ResnetBlock(in_channels=128, out_channels=256, rngs=rngs)
>>> x = jnp.ones((2, 32, 32, 128))
>>> output = block(x)  # Shape: (2, 32, 32, 256)
>>>
>>> # With timestep conditioning
>>> block_cond = ResnetBlock(in_channels=128, embedding_dim=512, rngs=rngs)
>>> time_emb = jnp.ones((2, 512))
>>> output = block_cond(x, time_emb)
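FiLM conditioning, as referred to above, modulates feature maps with a per-channel scale and shift derived from the timestep embedding. The sketch below shows only that modulation step. The function name `film` and the `1 + scale` convention are assumptions; the projection that would produce `scale_shift` from the time embedding is left out.

```python
import jax.numpy as jnp


def film(h, scale_shift):
    """Hypothetical sketch of Feature-wise Linear Modulation on NHWC features.

    h:           (batch, height, width, channels)
    scale_shift: (batch, 2 * channels), e.g. a projected timestep embedding
    """
    scale, shift = jnp.split(scale_shift, 2, axis=-1)
    # Broadcast the per-channel terms over the spatial dimensions.
    scale = scale[:, None, None, :]
    shift = shift[:, None, None, :]
    # Assumed convention: zero scale and shift leave the features unchanged.
    return h * (1.0 + scale) + shift


h = jnp.ones((2, 32, 32, 256))
scale_shift = jnp.zeros((2, 512))
out = film(h, scale_shift)
```

Using `1 + scale` means a zero-initialized projection starts out as the identity, which is a common choice but not something the source confirms for ResnetBlock.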
Time Embedding
Timestep embedding modules for diffusion models.
This module provides sinusoidal positional encodings and MLP-based timestep embeddings commonly used for conditioning diffusion model predictions on the current noise level or timestep.
- class diffuse.neural_network.blocks.time_embedding.TimestepEmbedding(*args: Any, **kwargs: Any)
Bases: Module
MLP-based timestep embedding processor.
Transforms sinusoidal timestep embeddings through a two-layer MLP with an activation function. Commonly used to increase expressiveness of the timestep conditioning signal.
- Parameters:
embedding_dim – Dimension of input and output embeddings
activation – Activation function to use between layers
param_dtype – Data type for parameters
dtype – Data type for computation
rngs – Random number generators for parameter initialization
- class diffuse.neural_network.blocks.time_embedding.Timesteps(*args: Any, **kwargs: Any)
Bases: Module
Sinusoidal timestep embedding layer.
Converts scalar timestep values into high-dimensional sinusoidal embeddings for use as conditioning signals in diffusion models.
- Parameters:
embedding_dim – Dimension of the output embedding
max_period – Maximum period for the sinusoidal functions
dtype – Data type for the output embeddings
- diffuse.neural_network.blocks.time_embedding.get_sinusoidal_embedding(t: Array | ndarray | bool | number | bool | int | float | complex, embedding_dim: int = 64, max_period: int = 10000) -> Array
Generate sinusoidal positional embeddings for timesteps.
Creates embeddings using sine and cosine functions at different frequencies, similar to the positional encodings in “Attention is All You Need”.
- Parameters:
t – Timestep values to encode
embedding_dim – Dimension of the output embedding
max_period – Maximum period for the sinusoidal functions
- Returns:
Sinusoidal embeddings of shape (batch, embedding_dim)
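As an illustration of the sine/cosine scheme described above, the following is a minimal sketch rather than the body of get_sinusoidal_embedding. The function name `sinusoidal_embedding_sketch`, the sine-then-cosine ordering, and the requirement that `embedding_dim` be even are assumptions.

```python
import math

import jax.numpy as jnp


def sinusoidal_embedding_sketch(t, embedding_dim=64, max_period=10000):
    """Hypothetical sketch: map timesteps of shape (batch,) to (batch, embedding_dim)."""
    t = jnp.atleast_1d(jnp.asarray(t, dtype=jnp.float32))
    half = embedding_dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    exponents = jnp.arange(half, dtype=jnp.float32) / half
    freqs = jnp.exp(-math.log(max_period) * exponents)
    args = t[:, None] * freqs[None, :]
    # Assumed layout: all sine terms first, then all cosine terms.
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)


emb = sinusoidal_embedding_sketch(jnp.array([0, 100, 200]), embedding_dim=64)
```

Low-index channels oscillate quickly with `t` and high-index channels slowly, so nearby timesteps get similar but distinguishable embeddings.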
Timestep
Base classes for modules that accept timestep embeddings.
- class diffuse.neural_network.blocks.timestep.TimestepBlock(*args: Any, **kwargs: Any)
Bases: Module
Base class for modules that can accept optional timestep embeddings.
This abstract class defines the interface for neural network modules that support timestep conditioning, commonly used in diffusion models.
- class diffuse.neural_network.blocks.timestep.TimestepEmbedSequential(*args: Any, **kwargs: Any)
Bases: Sequential, TimestepBlock
Sequential container that passes timestep embeddings to compatible layers.
Extends nnx.Sequential to support modules that accept timestep embeddings. Automatically passes time_emb to layers that inherit from TimestepBlock, while calling other layers without the timestep argument.
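The dispatch rule described above (pass time_emb only to layers that inherit from TimestepBlock) can be sketched in plain Python. All class names below are hypothetical stand-ins defined for this example; they do not use nnx and are not the library's classes.

```python
class SketchTimestepBlock:
    """Hypothetical marker base class: subclasses accept (x, time_emb)."""


class Double:
    """Ordinary layer: takes only x."""

    def __call__(self, x):
        return 2 * x


class AddTime(SketchTimestepBlock):
    """Timestep-aware layer: adds the embedding when one is given."""

    def __call__(self, x, time_emb=None):
        return x if time_emb is None else x + time_emb


class SketchSequential:
    """Hypothetical container mirroring the dispatch rule described above."""

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, time_emb=None):
        for layer in self.layers:
            if isinstance(layer, SketchTimestepBlock):
                x = layer(x, time_emb)   # timestep-aware layers get the embedding
            else:
                x = layer(x)             # all other layers are called with x only
        return x


seq = SketchSequential(Double(), AddTime(), Double())
with_time = seq(1.0, time_emb=0.5)
without_time = seq(1.0)
```

Dispatching on the base class keeps ordinary layers such as convolutions or activations unaware of the conditioning signal, so they can be mixed freely with timestep-aware blocks in one container.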
Encoder
Decoder
Downsample
Upsample
Upsampling blocks for neural networks.
This module provides various upsampling methods including nearest neighbor resize and pixel shuffle (sub-pixel convolution) for increasing spatial resolution in generative models and decoders.
- class diffuse.neural_network.blocks.upsample.PixelShuffle(*args: Any, **kwargs: Any)
Bases: Module
Pixel shuffle operation for sub-pixel convolution upsampling.
Rearranges elements in a tensor from channel dimension to spatial dimensions. Commonly used in super-resolution and generative models for learnable upsampling.
- Parameters:
scale – Upsampling scale factor (both height and width).
Example
>>> shuffle = PixelShuffle(scale=2)
>>> x = jnp.ones((1, 4, 4, 16))  # 4 channels per output pixel after 2x upsample
>>> y = shuffle(x)  # Shape: (1, 8, 8, 4)
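The channel-to-space rearrangement can be written with a reshape and a transpose. The sketch below assumes NHWC layout and one particular ordering of the sub-pixel channels; the library's PixelShuffle may order them differently, and the function name `pixel_shuffle_sketch` is hypothetical.

```python
import jax.numpy as jnp


def pixel_shuffle_sketch(x, scale):
    """Hypothetical sketch: (B, H, W, C * scale**2) -> (B, H * scale, W * scale, C)."""
    b, h, w, c = x.shape
    assert c % (scale * scale) == 0, "channels must be divisible by scale**2"
    c_out = c // (scale * scale)
    # Split channels into (row offset, column offset, output channel).
    x = x.reshape(b, h, w, scale, scale, c_out)
    # Interleave the offsets with the spatial axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)    # (B, H, scale, W, scale, C_out)
    return x.reshape(b, h * scale, w * scale, c_out)


x = jnp.arange(1 * 4 * 4 * 16, dtype=jnp.float32).reshape(1, 4, 4, 16)
y = pixel_shuffle_sketch(x, scale=2)
```

The operation only moves values around, so the output holds exactly the same elements as the input at a higher spatial resolution and lower channel count.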
- class diffuse.neural_network.blocks.upsample.Upsample(*args: Any, **kwargs: Any)
Bases: Module
Flexible upsampling block with multiple methods.
Supports different upsampling methods: nearest neighbor resize and pixel shuffle. Includes a convolutional layer after upsampling for feature refinement.
- Parameters:
in_channels – Number of input channels.
method – Upsampling method, either “resize” or “pixel_shuffle”.
scale_factor – Upsampling scale factor (integer).
param_dtype – Data type for parameters.
dtype – Data type for computation.
rngs – Random number generators for parameter initialization.
Example
>>> rngs = nnx.Rngs(42)
>>> upsample = Upsample(in_channels=128, method="resize", rngs=rngs)
>>> x = jnp.ones((2, 16, 16, 128))
>>> y = upsample(x)  # Shape: (2, 32, 32, 128)
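For the "resize" method, nearest-neighbor upsampling amounts to repeating each pixel along height and width. The sketch below covers only that resize step, using `jnp.repeat`. The function name `nearest_upsample_sketch` and the default factor of 2 are assumptions, and the convolution that Upsample applies afterwards is omitted.

```python
import jax.numpy as jnp


def nearest_upsample_sketch(x, scale_factor=2):
    """Hypothetical sketch: nearest-neighbor upsampling of an NHWC tensor."""
    x = jnp.repeat(x, scale_factor, axis=1)   # repeat rows
    x = jnp.repeat(x, scale_factor, axis=2)   # repeat columns
    return x


x = jnp.arange(4, dtype=jnp.float32).reshape(1, 2, 2, 1)
y = nearest_upsample_sketch(x)
```

Each input pixel becomes a `scale_factor` by `scale_factor` block of identical values, which is why a convolution usually follows to smooth and refine the result.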