# Source code for diffuse.neural_network.blocks.attention

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
"""Multi-head self-attention block for VAE-GAN and diffusion models.

This module implements spatial self-attention over feature maps, allowing each
spatial position to attend to all other positions. Commonly used in the bottleneck
of VAE encoders/decoders and UNet architectures for diffusion models.
"""

import jax.numpy as jnp
from einops import rearrange
from flax import nnx
from jax import Array
from jax.typing import ArrayLike, DTypeLike


class AttnBlock(nnx.Module):
    """Multi-head self-attention block with spatial attention mechanism.

    Performs self-attention over the spatial dimensions of a feature map. The
    input is group-normalised, projected to queries, keys and values with 1x1
    convolutions, and the ``height * width`` positions are flattened into a
    single sequence so that each pixel can attend to all other pixels. The
    attended features are projected with a final 1x1 convolution and added
    back to the input (residual connection).

    The attention computation itself is delegated to
    ``nnx.dot_product_attention``; which kernel that resolves to (for example
    cuDNN flash attention) depends on the installed Flax/JAX version and the
    backend, not on this class.

    Args:
        in_channels: Number of input channels (the output has the same number).
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads. Default is 8.
        param_dtype: Data type for parameters (weights and biases).
        dtype: Data type for computation.
        rngs: Random number generators for parameter initialization. It is
            forwarded unchanged to every sub-layer; the ``None`` default is
            kept only for signature compatibility, so pass a real
            ``nnx.Rngs``.
        num_groups: Number of groups used by the group normalisation.
            Default is 32, the value that was previously hard-coded.
            NOTE(review): ``nnx.GroupNorm`` is expected to reject channel
            counts that are not a multiple of this value - confirm against
            the Flax version in use.

    Raises:
        AssertionError: If in_channels is not divisible by num_heads.

    Example:
        >>> rngs = nnx.Rngs(42)
        >>> attn = AttnBlock(in_channels=256, num_heads=8, rngs=rngs)
        >>> x = jnp.ones((2, 16, 16, 256))  # (batch, height, width, channels)
        >>> output = attn(x)  # Same shape as input
    """

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 8,
        param_dtype: DTypeLike = jnp.float32,
        dtype: DTypeLike = jnp.float32,
        rngs: nnx.Rngs | None = None,
        num_groups: int = 32,
    ):
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type and message are unchanged.
        if in_channels % num_heads != 0:
            raise AssertionError(
                f"in_channels ({in_channels}) must be divisible by num_heads ({num_heads})"
            )

        self.dtype = dtype
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads

        # Sub-layers are built in the order norm, q, k, v, proj_out so that
        # keys are drawn from `rngs` in the same sequence as before.
        self.norm = nnx.GroupNorm(
            num_features=in_channels,
            num_groups=num_groups,
            epsilon=1e-6,
            param_dtype=param_dtype,
            dtype=self.dtype,
            rngs=rngs,
        )

        # All four projections are identical channel-preserving 1x1 convolutions.
        conv_kwargs = dict(
            in_features=in_channels,
            out_features=in_channels,
            kernel_size=(1, 1),
            param_dtype=param_dtype,
            dtype=self.dtype,
            rngs=rngs,
        )
        self.q = nnx.Conv(**conv_kwargs)
        self.k = nnx.Conv(**conv_kwargs)
        self.v = nnx.Conv(**conv_kwargs)
        self.proj_out = nnx.Conv(**conv_kwargs)

    def attention(self, h_: ArrayLike) -> Array:
        """Compute multi-head self-attention over spatial dimensions.

        Args:
            h_: Input feature map of shape (batch, height, width, channels).

        Returns:
            Attention output of shape (batch, height, width, channels). The
            output projection and the residual connection are applied in
            ``__call__``, not here.
        """
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Only the spatial extent is needed to undo the flattening below.
        _, h, w, _ = q.shape

        # Reshape to multi-head format: [batch, seq_len, num_heads, head_dim]
        to_heads = "b h w (nh hd) -> b (h w) nh hd"
        q = rearrange(q, to_heads, nh=self.num_heads, hd=self.head_dim)
        k = rearrange(k, to_heads, nh=self.num_heads, hd=self.head_dim)
        v = rearrange(v, to_heads, nh=self.num_heads, hd=self.head_dim)

        # Backend/kernel selection is left to nnx.dot_product_attention.
        h_ = nnx.dot_product_attention(q, k, v, dtype=self.dtype)

        # Reshape back to spatial format
        h_ = rearrange(
            h_,
            "b (h w) nh hd -> b h w (nh hd)",
            h=h,
            w=w,
            nh=self.num_heads,
            hd=self.head_dim,
        )
        return h_

    def __call__(self, x: ArrayLike) -> Array:
        """Forward pass with residual connection.

        Args:
            x: Input feature map of shape (batch, height, width, channels).

        Returns:
            Output with residual connection: x + proj_out(attention(x)).
        """
        return x + self.proj_out(self.attention(x))