Neural Networks

Models

Conditional UNet

class diffuse.neural_network.nn.condUNet.CondUNet2D(*args: Any, **kwargs: Any)

Bases: Module

Conditional U-Net for diffusion models with timestep conditioning.

A U-Net architecture for conditional image generation, featuring hierarchical downsampling/upsampling paths with skip connections, ResNet blocks, attention mechanisms, and sinusoidal timestep embeddings.

Parameters:
  • in_channels – Number of input channels.

  • ch – Base number of channels.

  • ch_mult – Channel multipliers for each resolution level.

  • num_res_blocks – Number of ResNet blocks at each resolution.

  • attention_resolutions – Resolution levels where attention is applied.

  • activation – Activation function to use throughout the network.

  • dropout – Whether to enable dropout in ResNet blocks.

  • num_heads – Number of attention heads for multi-head attention.

  • param_dtype – Data type for parameters.

  • dtype – Data type for computation.

  • rngs – Random number generators for parameter initialization.
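
The parameter list does not spell out how ch and ch_mult combine. The following is a minimal sketch that assumes the common DDPM-style convention: level i uses ch * ch_mult[i] channels, and the spatial size halves between consecutive levels. The helper level_layout and the ch_mult value are hypothetical and not part of diffuse.

```python
# Assumed DDPM-style convention: level i has ch * ch_mult[i] channels and the
# spatial size halves after every level except the deepest one.
# level_layout is a hypothetical helper, not part of diffuse.
def level_layout(image_size: int, ch: int, ch_mult: tuple[int, ...]) -> list[tuple[int, int]]:
    """Return (spatial_resolution, channels) for each resolution level."""
    layout = []
    resolution = image_size
    for i, mult in enumerate(ch_mult):
        layout.append((resolution, ch * mult))
        if i != len(ch_mult) - 1:  # no downsampling after the deepest level
            resolution //= 2
    return layout


print(level_layout(64, ch=128, ch_mult=(1, 2, 2, 4)))
# [(64, 128), (32, 256), (16, 256), (8, 512)]
```

With these hypothetical values, and if attention_resolutions lists spatial sizes rather than level indices, attention_resolutions=(8, 16) would select the two deepest levels.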

Example

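>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffuse.neural_network.nn.condUNet import CondUNet2D
>>> rngs = nnx.Rngs(42)
>>> unet = CondUNet2D(ch=128, attention_resolutions=(8, 16), rngs=rngs)
>>> x = jnp.ones((2, 64, 64, 3))  # (batch, height, width, channels)
>>> t = jnp.array([100, 200])  # one timestep per batch element
>>> output = unet(x, t)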

VAE (Variational Autoencoder)

class diffuse.neural_network.nn.sdVae.SDVae(*args: Any, **kwargs: Any)

Bases: Module

Stable Diffusion Variational Autoencoder (VAE).

A VAE architecture for encoding images into latent representations and decoding them back. It uses an encoder-decoder structure with a diagonal Gaussian posterior. VAEs of this kind are commonly used in latent diffusion models to compress images into a smaller latent space.

Parameters:
  • in_channels – Number of input image channels

  • ch – Base number of channels in the network

  • out_ch – Number of output channels

  • ch_mult – Channel multipliers for each resolution level

  • num_res_blocks – Number of ResNet blocks at each resolution

  • z_channels – Number of latent space channels

  • scale_factor – Scaling factor applied to latent codes (default: 0.18215 for SD)

  • shift_factor – Shift applied to latent codes before scaling

  • activation – Activation function used throughout the network

  • param_dtype – Data type for parameters

  • dtype – Data type for computation

  • rngs – Random number generators for parameter initialization
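
The parameter list does not give the exact formula that combines scale_factor and shift_factor. The following is a minimal sketch that assumes the convention used in Hugging Face diffusers pipelines: subtract the shift, then multiply by the scale before diffusion, and invert both steps before decoding. The function names and the zero shift are hypothetical.

```python
import jax.numpy as jnp

# Assumed convention (as in diffusers pipelines): normalize = (z - shift) * scale,
# and the inverse is applied before decoding. Names here are hypothetical.
SCALE_FACTOR = 0.18215  # SD value quoted above
SHIFT_FACTOR = 0.0      # hypothetical: no shift


def normalize_latent(z, scale=SCALE_FACTOR, shift=SHIFT_FACTOR):
    """Map raw encoder latents to the scale the diffusion model is trained on."""
    return (z - shift) * scale


def denormalize_latent(z_scaled, scale=SCALE_FACTOR, shift=SHIFT_FACTOR):
    """Undo normalize_latent before passing latents to the decoder."""
    return z_scaled / scale + shift


z = jnp.linspace(-5.0, 5.0, 8)
z_back = denormalize_latent(normalize_latent(z))
print(jnp.allclose(z, z_back, atol=1e-5))  # True
```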

decode(z: ArrayLike) → Array

Decode latent representation back to image space.

Parameters:

z – Latent code tensor of shape (batch, z_channels, latent_h, latent_w)

Returns:

Reconstructed image tensor of shape (batch, out_ch, height, width)

encode(x: ArrayLike) → Array | Tuple[Array, Array]

Encode image into latent representation.

Parameters:

x – Input image tensor of shape (batch, channels, height, width)

Returns:

Tuple (latent_code, mean, logvar), where latent_code is the sampled latent representation and mean and logvar parameterize the diagonal Gaussian posterior.
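
A diagonal Gaussian posterior is usually sampled with the reparameterization trick and regularized with its closed-form KL divergence to a standard normal prior. The following is a generic sketch, not the SDVae code. This page does not document whether encode samples or returns the mean, or how the KL term is weighted in training. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_diagonal_gaussian(key, mean, logvar):
    """Reparameterized sample: z = mean + exp(0.5 * logvar) * eps, eps ~ N(0, I)."""
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    return mean + std * eps


def kl_to_standard_normal(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), summed over all non-batch axes."""
    kl = 0.5 * (jnp.square(mean) + jnp.exp(logvar) - 1.0 - logvar)
    return kl.reshape(kl.shape[0], -1).sum(axis=-1)


key = jax.random.PRNGKey(0)
mean = jnp.zeros((2, 4, 8, 8))    # (batch, z_channels, latent_h, latent_w)
logvar = jnp.zeros((2, 4, 8, 8))
z = sample_diagonal_gaussian(key, mean, logvar)
print(z.shape)                              # (2, 4, 8, 8)
print(kl_to_standard_normal(mean, logvar))  # [0. 0.]
```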

Network Parameters

class diffuse.neural_network.nn.params.CondUNet2DOutput(output: Array)

Bases: Mapping

Output of the CondUNet2D model.

output

The processed output tensor from the U-Net

Type:

jax.Array

from_tuple()
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
output: Array
replace(**kwargs)
to_tuple()
values() → an object providing a view on D's values

class diffuse.neural_network.nn.params.SDVaeOutput(output: Array, mean: Array, logvar: Array)

Bases: Mapping

Output of the Stable Diffusion VAE model.

output

The reconstructed/decoded output tensor

Type:

jax.Array

mean

Mean of the latent distribution

Type:

jax.Array

logvar

Log variance of the latent distribution

Type:

jax.Array

from_tuple()
items() → a set-like object providing a view on D's items
keys() → a set-like object providing a view on D's keys
logvar: Array
mean: Array
output: Array
replace(**kwargs)
to_tuple()
values() → an object providing a view on D's values
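
Both output classes combine named attributes with the Mapping interface. Once __getitem__, __iter__ and __len__ are defined, Mapping supplies keys(), items() and values(). The hypothetical ToyVaeOutput below sketches that pattern with a frozen dataclass. Its to_tuple, from_tuple and replace are guesses at the semantics of the documented methods, not their actual implementation.

```python
from collections.abc import Mapping
from dataclasses import dataclass, fields
from dataclasses import replace as dataclass_replace
from typing import Any, Iterator

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyVaeOutput(Mapping):
    """Hypothetical stand-in for SDVaeOutput: a frozen dataclass that is also a read-only mapping."""

    output: Any
    mean: Any
    logvar: Any

    @classmethod
    def _names(cls) -> tuple:
        return tuple(f.name for f in fields(cls))

    # Mapping only needs these three methods; keys(), items(), values(), get()
    # and `in` are then inherited from collections.abc.Mapping.
    def __getitem__(self, key: str) -> Any:
        if key not in self._names():
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self) -> Iterator[str]:
        return iter(self._names())

    def __len__(self) -> int:
        return len(self._names())

    def to_tuple(self) -> tuple:
        return tuple(getattr(self, name) for name in self._names())

    @classmethod
    def from_tuple(cls, values: tuple) -> "ToyVaeOutput":
        return cls(*values)

    def replace(self, **kwargs: Any) -> "ToyVaeOutput":
        # Returns a new instance; the frozen original is left unchanged.
        return dataclass_replace(self, **kwargs)


out = ToyVaeOutput(
    output=jnp.zeros((1, 3, 8, 8)),
    mean=jnp.zeros((1, 4, 2, 2)),
    logvar=jnp.zeros((1, 4, 2, 2)),
)
print(list(out.keys()))  # ['output', 'mean', 'logvar']
recon, mean, logvar = out.to_tuple()
shifted = out.replace(mean=out.mean + 1.0)
```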

Building Blocks#

Attention

Multi-head self-attention block for VAE-GAN and diffusion models.

This module implements spatial self-attention over feature maps, allowing each spatial position to attend to all other positions. Commonly used in the bottleneck of VAE encoders/decoders and UNet architectures for diffusion models.

class diffuse.neural_network.blocks.attention.AttnBlock(*args: Any, **kwargs: Any)

Bases: Module

Multi-head self-attention block with spatial attention mechanism.

Performs self-attention over the spatial dimensions of feature maps, so that each pixel can attend to every other pixel. Multiple heads let the block learn diverse attention patterns. Where available, it uses JAX’s optimized attention implementations, including cuDNN flash attention.

Parameters:
  • in_channels – Number of input channels. Must be divisible by num_heads.

  • num_heads – Number of attention heads. Default is 8.

  • param_dtype – Data type for parameters (weights and biases).

  • dtype – Data type for computation.

  • rngs – Random number generators for parameter initialization.

Raises:

AssertionError – If in_channels is not divisible by num_heads.

Example

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffuse.neural_network.blocks.attention import AttnBlock
>>> rngs = nnx.Rngs(42)
>>> attn = AttnBlock(in_channels=256, num_heads=8, rngs=rngs)
>>> x = jnp.ones((2, 16, 16, 256))  # (batch, height, width, channels)
>>> output = attn(x)  # Same shape as input

attention(h_: ArrayLike) → Array

Compute multi-head self-attention over spatial dimensions.

Parameters:

h_ – Input feature map of shape (batch, height, width, channels).

Returns:

Attention output of shape (batch, height, width, channels).
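
The computation can be sketched in plain jax.numpy. This is an illustrative reimplementation under stated assumptions, not AttnBlock itself. The weight names are hypothetical, the GroupNorm that usually precedes attention is omitted, and the residual connection is assumed. The library version may instead call a fused kernel such as jax.nn.dot_product_attention.

```python
import jax
import jax.numpy as jnp


def spatial_self_attention(x, w_qkv, w_out, num_heads):
    """Multi-head self-attention across all H*W positions of an NHWC map.

    w_qkv: (C, 3C) hypothetical projection to queries, keys and values.
    w_out: (C, C) hypothetical output projection.
    """
    batch, height, width, channels = x.shape
    assert channels % num_heads == 0, "channels must be divisible by num_heads"
    head_dim = channels // num_heads
    tokens = x.reshape(batch, height * width, channels)  # (B, T, C), T = H*W
    q, k, v = jnp.split(tokens @ w_qkv, 3, axis=-1)      # each (B, T, C)
    # Split the channel axis into heads: (B, T, num_heads, head_dim).
    q, k, v = (a.reshape(batch, height * width, num_heads, head_dim) for a in (q, k, v))
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)            # every position attends to every position
    attended = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
    attended = attended.reshape(batch, height, width, channels)
    return x + attended @ w_out                          # residual connection (assumed)


key_x, key_qkv, key_out = jax.random.split(jax.random.PRNGKey(0), 3)
channels = 64
x = jax.random.normal(key_x, (2, 8, 8, channels))
w_qkv = jax.random.normal(key_qkv, (channels, 3 * channels)) / jnp.sqrt(channels)
w_out = jax.random.normal(key_out, (channels, channels)) / jnp.sqrt(channels)
y = spatial_self_attention(x, w_qkv, w_out, num_heads=8)
print(y.shape)  # (2, 8, 8, 64)
```

Flattening H and W into one token axis is what makes the attention spatial. The cost grows with (H*W)^2, which is why such blocks are usually placed at low resolutions.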

ResNet Block

ResNet block with optional timestep conditioning for diffusion models.

This module implements a ResNet block that can optionally receive timestep embeddings for conditioning in diffusion models. Supports FiLM (Feature-wise Linear Modulation) conditioning and configurable dropout.

class diffuse.neural_network.blocks.resnet_block.ResnetBlock(*args: Any, **kwargs: Any)

Bases: TimestepBlock

ResNet block with optional timestep conditioning.

A residual block consisting of two convolutional layers with GroupNorm and activation functions. Optionally accepts timestep embeddings for FiLM conditioning in diffusion models. Includes skip connections for residual learning.

Parameters:
  • in_channels – Number of input channels.

  • out_channels – Number of output channels. If None, uses in_channels.

  • activation – Activation function to use. Default is swish.

  • embedding_dim – Dimension of timestep embeddings for conditioning. If None, no timestep conditioning is applied.

  • param_dtype – Data type for parameters.

  • dtype – Data type for computation.

  • dropout – Whether to apply dropout.

  • dropout_rate – Dropout rate when dropout is enabled.

  • rngs – Random number generators for parameter initialization.

Example

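>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffuse.neural_network.blocks.resnet_block import ResnetBlock
>>> rngs = nnx.Rngs(42)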
>>> block = ResnetBlock(in_channels=128, out_channels=256, rngs=rngs)
>>> x = jnp.ones((2, 32, 32, 128))
>>> output = block(x)  # Shape: (2, 32, 32, 256)
>>>
>>> # With timestep conditioning
>>> block_cond = ResnetBlock(in_channels=128, embedding_dim=512, rngs=rngs)
>>> time_emb = jnp.ones((2, 512))
>>> output = block_cond(x, time_emb)
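
FiLM conditioning can be sketched as a per-channel scale and shift predicted from the timestep embedding. This is a generic sketch under assumptions, not ResnetBlock's code. The projection weights are hypothetical. The (1 + scale) form is the variant many diffusion U-Nets use, while the original FiLM formulation multiplies by the predicted scale directly. This page does not document where ResnetBlock applies the modulation relative to its GroupNorm and convolutions.

```python
import jax.numpy as jnp


def film_modulate(h, time_emb, w_film, b_film):
    """Scale and shift each channel of an NHWC feature map from a timestep embedding.

    w_film: (E, 2C) and b_film: (2C,) are hypothetical projection weights.
    """
    params = time_emb @ w_film + b_film            # (B, 2C)
    scale, shift = jnp.split(params, 2, axis=-1)   # (B, C) each
    scale = scale[:, None, None, :]                # broadcast over height and width
    shift = shift[:, None, None, :]
    return h * (1.0 + scale) + shift               # (1 + scale): zero weights give the identity


h = jnp.ones((2, 32, 32, 128))
time_emb = jnp.ones((2, 512))
w_film = jnp.zeros((512, 2 * 128))
b_film = jnp.concatenate([jnp.zeros(128), jnp.ones(128)])  # scale 0, shift 1
out = film_modulate(h, time_emb, w_film, b_film)
print(out.shape)  # (2, 32, 32, 128)
```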

Time Embedding

Timestep embedding modules for diffusion models.

This module provides sinusoidal positional encodings and MLP-based timestep embeddings commonly used for conditioning diffusion model predictions on the current noise level or timestep.

class diffuse.neural_network.blocks.time_embedding.TimestepEmbedding(*args: Any, **kwargs: Any)

Bases: Module

MLP-based timestep embedding processor.

Transforms sinusoidal timestep embeddings through a two-layer MLP with an activation function. Commonly used to increase expressiveness of the timestep conditioning signal.

Parameters:
  • embedding_dim – Dimension of input and output embeddings

  • activation – Activation function to use between layers

  • param_dtype – Data type for parameters

  • dtype – Data type for computation

  • rngs – Random number generators for parameter initialization
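
The two-layer MLP can be sketched in plain jax.numpy. This minimal sketch assumes that the hidden width equals embedding_dim and that the activation is SiLU (swish). The parameter dictionary and function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_timestep_mlp(key, embedding_dim):
    """Hypothetical parameters for Linear(E, E) -> activation -> Linear(E, E)."""
    key1, key2 = jax.random.split(key)
    std = 1.0 / jnp.sqrt(embedding_dim)
    return {
        "w1": jax.random.normal(key1, (embedding_dim, embedding_dim)) * std,
        "b1": jnp.zeros(embedding_dim),
        "w2": jax.random.normal(key2, (embedding_dim, embedding_dim)) * std,
        "b2": jnp.zeros(embedding_dim),
    }


def timestep_mlp(params, emb, activation=jax.nn.silu):
    """Two linear layers with an activation in between; shape (B, E) -> (B, E)."""
    hidden = activation(emb @ params["w1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]


params = init_timestep_mlp(jax.random.PRNGKey(0), embedding_dim=128)
emb = jnp.ones((4, 128))
print(timestep_mlp(params, emb).shape)  # (4, 128)
```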

class diffuse.neural_network.blocks.time_embedding.Timesteps(*args: Any, **kwargs: Any)

Bases: Module

Sinusoidal timestep embedding layer.

Converts scalar timestep values into high-dimensional sinusoidal embeddings for use as conditioning signals in diffusion models.

Parameters:
  • embedding_dim – Dimension of the output embedding

  • max_period – Maximum period for the sinusoidal functions

  • dtype – Data type for the output embeddings

diffuse.neural_network.blocks.time_embedding.get_sinusoidal_embedding(t: ArrayLike, embedding_dim: int = 64, max_period: int = 10000) → Array

Generate sinusoidal positional embeddings for timesteps.

Creates embeddings using sine and cosine functions at different frequencies, similar to the positional encodings in “Attention Is All You Need”.

Parameters:
  • t – Timestep values to encode, of shape (batch,)

  • embedding_dim – Dimension of the output embedding

  • max_period – Maximum period for the sinusoidal functions

Returns:

Sinusoidal embeddings of shape (batch, embedding_dim)
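
The standard construction can be sketched as follows. It assumes the common diffusion-model layout, in which frequencies are spaced geometrically and the cosine half precedes the sine half. The Transformer paper interleaves sine and cosine instead, and get_sinusoidal_embedding may order them differently. The function name is hypothetical.

```python
import math

import jax.numpy as jnp


def sinusoidal_embedding(t, embedding_dim=64, max_period=10000):
    """Map timesteps of shape (batch,) to sin/cos features of shape (batch, embedding_dim)."""
    half = embedding_dim // 2
    # Frequencies spaced geometrically from 1 toward 1 / max_period.
    freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)
    args = jnp.asarray(t, dtype=jnp.float32)[:, None] * freqs[None, :]  # (batch, half)
    emb = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)     # assumed cos-then-sin order
    if embedding_dim % 2:
        emb = jnp.pad(emb, ((0, 0), (0, 1)))  # odd dimension: append a zero column
    return emb


t = jnp.array([0, 250, 999])
emb = sinusoidal_embedding(t, embedding_dim=64)
print(emb.shape)  # (3, 64)
```

Low-index features oscillate quickly and separate nearby timesteps. High-index features vary slowly and encode coarse position, so max_period controls the longest wavelength.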

Timestep

Base classes for modules that accept timestep embeddings.

class diffuse.neural_network.blocks.timestep.TimestepBlock(*args: Any, **kwargs: Any)

Bases: Module

Base class for modules that can accept optional timestep embeddings.

This abstract class defines the interface for neural network modules that support timestep conditioning, commonly used in diffusion models.

class diffuse.neural_network.blocks.timestep.TimestepEmbedSequential(*args: Any, **kwargs: Any)

Bases: Sequential, TimestepBlock

Sequential container that passes timestep embeddings to compatible layers.

Extends nnx.Sequential to support modules that accept timestep embeddings. Automatically passes time_emb to layers that inherit from TimestepBlock, while calling other layers without the timestep argument.
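
The dispatch rule can be sketched in plain Python. All class names here are hypothetical stand-ins for TimestepBlock and TimestepEmbedSequential. The documented classes extend Module and nnx.Sequential rather than plain objects.

```python
class TimestepBlockSketch:
    """Hypothetical marker base class: subclasses are called as layer(x, time_emb)."""

    def __call__(self, x, time_emb):
        raise NotImplementedError


class AddTimeEmbedding(TimestepBlockSketch):
    """Toy timestep-aware layer: adds the embedding to its input."""

    def __call__(self, x, time_emb):
        return x + time_emb


class Double:
    """Toy layer that knows nothing about timesteps."""

    def __call__(self, x):
        return 2 * x


class TimestepSequentialSketch:
    """Runs layers in order, passing time_emb only to TimestepBlockSketch subclasses."""

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, time_emb=None):
        for layer in self.layers:
            if isinstance(layer, TimestepBlockSketch):
                x = layer(x, time_emb)
            else:
                x = layer(x)
        return x


model = TimestepSequentialSketch(Double(), AddTimeEmbedding(), Double())
print(model(1.0, time_emb=0.5))  # 2 * (2 * 1.0 + 0.5) = 5.0
```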

Encoder

class diffuse.neural_network.blocks.encoder.Encoder(*args: Any, **kwargs: Any)

Bases: Module

Decoder

class diffuse.neural_network.blocks.decoder.Decoder(*args: Any, **kwargs: Any)

Bases: Module

Downsample

class diffuse.neural_network.blocks.downsample.Downsample(*args: Any, **kwargs: Any)

Bases: Module

Upsample

Upsampling blocks for neural networks.

This module provides various upsampling methods including nearest neighbor resize and pixel shuffle (sub-pixel convolution) for increasing spatial resolution in generative models and decoders.

class diffuse.neural_network.blocks.upsample.PixelShuffle(*args: Any, **kwargs: Any)

Bases: Module

Pixel shuffle operation for sub-pixel convolution upsampling.

Rearranges elements of a tensor from the channel dimension into the spatial dimensions. Commonly used in super-resolution and generative models for learnable upsampling.

Parameters:

scale – Upsampling scale factor (both height and width).

Example

>>> import jax.numpy as jnp
>>> from diffuse.neural_network.blocks.upsample import PixelShuffle
>>> shuffle = PixelShuffle(scale=2)
>>> x = jnp.ones((1, 4, 4, 16))  # 16 channels = 4 output channels x (2 x 2)
>>> y = shuffle(x)  # Shape: (1, 8, 8, 4)
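
For channels-last input, pixel shuffle is a reshape followed by a transpose (a depth-to-space operation). This minimal sketch assumes the channel axis is split as (scale, scale, out_channels). PyTorch's nn.PixelShuffle splits channels as (out_channels, scale, scale) instead, so PixelShuffle may use a different ordering. The function name is hypothetical.

```python
import jax.numpy as jnp


def pixel_shuffle(x, scale):
    """Depth-to-space for NHWC input: (B, H, W, C * s * s) -> (B, H * s, W * s, C)."""
    b, h, w, c = x.shape
    assert c % (scale * scale) == 0, "channels must be divisible by scale**2"
    out_c = c // (scale * scale)
    x = x.reshape(b, h, w, scale, scale, out_c)  # split channels into (s_h, s_w, C_out)
    x = x.transpose(0, 1, 3, 2, 4, 5)            # (B, H, s_h, W, s_w, C_out)
    return x.reshape(b, h * scale, w * scale, out_c)


y = pixel_shuffle(jnp.ones((1, 4, 4, 16)), scale=2)
print(y.shape)  # (1, 8, 8, 4)
```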

class diffuse.neural_network.blocks.upsample.Upsample(*args: Any, **kwargs: Any)

Bases: Module

Flexible upsampling block with multiple methods.

Supports different upsampling methods: nearest neighbor resize and pixel shuffle. Includes a convolutional layer after upsampling for feature refinement.

Parameters:
  • in_channels – Number of input channels.

  • method – Upsampling method, either "resize" or "pixel_shuffle".

  • scale_factor – Upsampling scale factor (integer).

  • param_dtype – Data type for parameters.

  • dtype – Data type for computation.

  • rngs – Random number generators for parameter initialization.

Example

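>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from diffuse.neural_network.blocks.upsample import Upsample
>>> rngs = nnx.Rngs(42)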
>>> upsample = Upsample(in_channels=128, method="resize", rngs=rngs)
>>> x = jnp.ones((2, 16, 16, 128))
>>> y = upsample(x)  # Shape: (2, 32, 32, 128)
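
The "resize" path amounts to nearest-neighbour upsampling followed by a convolution. This minimal sketch assumes a 3x3 'SAME' convolution that keeps the channel count. The function names and the random kernel are hypothetical. For integer factors, jnp.repeat along the height and width axes gives nearest-neighbour upsampling, and jax.image.resize(..., method="nearest") is a more general alternative.

```python
import jax
import jax.numpy as jnp


def nearest_upsample(x, scale_factor=2):
    """Nearest-neighbour upsampling of an NHWC tensor by an integer factor."""
    x = jnp.repeat(x, scale_factor, axis=1)     # duplicate rows
    return jnp.repeat(x, scale_factor, axis=2)  # duplicate columns


def conv3x3_same(x, kernel):
    """3x3 convolution with 'SAME' padding; kernel layout is HWIO (3, 3, C_in, C_out)."""
    return jax.lax.conv_general_dilated(
        x, kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((2, 16, 16, 8))
kernel = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 8, 8)) * 0.1
y = conv3x3_same(nearest_upsample(x), kernel)
print(y.shape)  # (2, 32, 32, 8)
```

The convolution after the resize smooths the blocky copies that nearest-neighbour duplication produces, so the block can learn how to refine the upsampled features.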