Source code for diffuse.neural_network.nn.condUNet

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from collections.abc import Callable

import jax.numpy as jnp
from flax import nnx
from jax.typing import ArrayLike, DTypeLike

from ..blocks import (
    AttnBlock,
    Downsample,
    ResnetBlock,
    TimestepEmbedding,
    TimestepEmbedSequential,
    Timesteps,
    Upsample,
)
from .params import CondUNet2DOutput


[docs] class CondUNet2D(nnx.Module): """Conditional U-Net for diffusion models with timestep conditioning. A U-Net architecture for conditional image generation, featuring hierarchical downsampling/upsampling paths with skip connections, ResNet blocks, attention mechanisms, and sinusoidal timestep embeddings. Args: in_channels: Number of input channels. ch: Base number of channels. ch_mult: Channel multipliers for each resolution level. num_res_blocks: Number of ResNet blocks at each resolution. attention_resolutions: Resolution levels where attention is applied. activation: Activation function to use throughout the network. dropout: Whether to enable dropout in ResNet blocks. num_heads: Number of attention heads for multi-head attention. param_dtype: Data type for parameters. dtype: Data type for computation. rngs: Random number generators for parameter initialization. Example: >>> rngs = nnx.Rngs(42) >>> unet = CondUNet2D(ch=128, attention_resolutions=(8, 16), rngs=rngs) >>> x = jnp.ones((2, 64, 64, 3)) >>> t = jnp.array([100, 200]) >>> output = unet(x, t) """ def __init__( self, in_channels: int = 3, ch: int = 64, ch_mult: tuple[int, ...] = (1, 2, 3, 4), num_res_blocks: int = 2, attention_resolutions: tuple[int, ...] = (1, 2, 4, 8), activation: Callable = nnx.swish, dropout: bool = True, dropout_rate: float = 0.1, num_heads: int = 8, param_dtype: DTypeLike = jnp.float32, dtype: DTypeLike = jnp.float32, rngs: nnx.Rngs = None, ): self.param_dtype = param_dtype self.dtype = dtype time_embed_dim = ch * 4 self.time_proj = Timesteps(embedding_dim=time_embed_dim, max_period=10_000, dtype=dtype) self.time_embedding = TimestepEmbedding( embedding_dim=time_embed_dim, activation=activation, param_dtype=self.param_dtype, dtype=dtype, rngs=rngs, ) current_ch = int(ch_mult[0] * ch) conv_in = nnx.Conv( in_features=in_channels, out_features=current_ch, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) blocks_down = nnx.List([TimestepEmbedSequential(conv_in)]) input_block_channels = [current_ch] ds = 1 for level, mult in enumerate(ch_mult): for _ in range(num_res_blocks): layers = [ ResnetBlock( in_channels=current_ch, out_channels=int(mult * ch), activation=activation, embedding_dim=time_embed_dim, param_dtype=self.param_dtype, dtype=dtype, dropout=dropout, dropout_rate=dropout_rate, rngs=rngs, ) ] current_ch = int(mult * ch) if ds in attention_resolutions: layers.append( AttnBlock( in_channels=current_ch, num_heads=num_heads, param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) ) blocks_down.append(TimestepEmbedSequential(*layers)) input_block_channels.append(current_ch) if level != len(ch_mult) - 1: blocks_down.append( TimestepEmbedSequential( Downsample( in_channels=current_ch, param_dtype=self.param_dtype, dtype=dtype, rngs=rngs, ) ) ) input_block_channels.append(current_ch) ds *= 2 self.blocks_down = blocks_down # mid self.block_mid = TimestepEmbedSequential( ResnetBlock( in_channels=current_ch, out_channels=current_ch, activation=activation, embedding_dim=time_embed_dim, param_dtype=self.param_dtype, dtype=dtype, dropout=dropout, dropout_rate=dropout_rate, rngs=rngs, ), AttnBlock( in_channels=current_ch, num_heads=num_heads, param_dtype=param_dtype, dtype=dtype, rngs=rngs, ), ResnetBlock( in_channels=current_ch, out_channels=current_ch, activation=activation, embedding_dim=time_embed_dim, param_dtype=self.param_dtype, dtype=dtype, dropout=dropout, dropout_rate=dropout_rate, rngs=rngs, ), ) # up blocks_up = nnx.List([]) for level, mult in reversed(list(enumerate(ch_mult))): for i in range(num_res_blocks + 1): ich = input_block_channels.pop() layers = [ ResnetBlock( in_channels=current_ch + ich, out_channels=int(mult * ch), activation=activation, embedding_dim=time_embed_dim, param_dtype=self.param_dtype, dtype=dtype, dropout=dropout, dropout_rate=dropout_rate, rngs=rngs, ) ] current_ch = int(mult * ch) if ds in attention_resolutions: layers.append( AttnBlock( in_channels=current_ch, num_heads=num_heads, param_dtype=self.param_dtype, dtype=dtype, rngs=rngs, ) ) if level and i == num_res_blocks: layers.append( Upsample( in_channels=current_ch, method="pixel_shuffle", scale_factor=2, param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) ) ds //= 2 blocks_up.append(TimestepEmbedSequential(*layers)) self.blocks_up = blocks_up norm_out = nnx.GroupNorm( num_features=current_ch, num_groups=32, epsilon=1e-6, param_dtype=self.param_dtype, dtype=dtype, rngs=rngs, ) activation_out = activation conv_out = nnx.Conv( in_features=current_ch, out_features=in_channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), param_dtype=self.param_dtype, dtype=dtype, rngs=rngs, ) self.out = nnx.Sequential(norm_out, activation_out, conv_out) def __call__(self, x: ArrayLike, t: ArrayLike | None = None) -> CondUNet2DOutput: """Forward pass through the U-Net with timestep conditioning. Args: x: Input tensor of shape (batch, channels, height, width). t: Timestep values for conditioning. If None, defaults to zeros. Returns: CondUNet2DOutput containing the processed tensor. """ squeeze = False if x.ndim < 4: x = jnp.expand_dims(x, 0) squeeze = True h = x if t is None: t = jnp.zeros((x.shape[0], 1)).astype(self.dtype) t_emb = self.time_proj(t) t_emb = self.time_embedding(t_emb) hs = [] for block in self.blocks_down: h = block(h, t_emb) hs.append(h) h = self.block_mid(h, t_emb) for block in self.blocks_up: h = jnp.concatenate([h, hs.pop()], axis=-1) h = block(h, t_emb) h = self.out(h) if squeeze: h = h.squeeze(axis=0) return CondUNet2DOutput(output=h)