Source code for diffuse.neural_network.blocks.decoder

# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
from typing import Callable

import jax.numpy as jnp

from flax import nnx
from jax import Array
from jax.typing import ArrayLike, DTypeLike

from .attention import AttnBlock
from .resnet_block import ResnetBlock
from .upsample import Upsample


[docs] class Decoder(nnx.Module): def __init__( self, ch: int, out_ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int, activation: Callable = nnx.swish, dropout: bool = False, param_dtype: DTypeLike = jnp.float32, dtype: DTypeLike = jnp.float32, rngs: nnx.Rngs = None, ): self.activation = activation self.dropout = dropout num_resolutions = len(ch_mult) block_in = ch * ch_mult[num_resolutions - 1] conv_z_in = nnx.Conv( in_features=z_channels, out_features=block_in, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) res_block_mid_in = ResnetBlock( in_channels=block_in, out_channels=block_in, activation=self.activation, param_dtype=param_dtype, dtype=dtype, dropout=self.dropout, rngs=rngs, ) mid_attn = AttnBlock(in_channels=block_in, param_dtype=param_dtype, dtype=dtype, rngs=rngs) res_block_mid_out = ResnetBlock( in_channels=block_in, out_channels=block_in, activation=self.activation, param_dtype=param_dtype, dtype=dtype, dropout=self.dropout, rngs=rngs, ) self.mid = nnx.Sequential(conv_z_in, res_block_mid_in, mid_attn, res_block_mid_out) blocks_up = nnx.List([]) for i_level in reversed(range(num_resolutions)): block = nnx.List([]) block_out = ch * ch_mult[i_level] for _ in range(num_res_blocks + 1): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, activation=self.activation, param_dtype=param_dtype, dtype=dtype, dropout=self.dropout, rngs=rngs, ) ) block_in = block_out if i_level != 0: block.append( Upsample( in_channels=block_in, method="resize", scale_factor=2, param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) ) blocks_up.extend(block) self.up = nnx.Sequential(*blocks_up) self.norm_out = nnx.GroupNorm( num_features=block_in, num_groups=32, epsilon=1e-6, param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) self.conv_out = nnx.Conv( in_features=block_in, out_features=out_ch, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), param_dtype=param_dtype, dtype=dtype, rngs=rngs, ) def __call__(self, z: ArrayLike) -> Array: h = self.mid(z) h = self.up(h) h = self.norm_out(h) h = self.activation(h) h = self.conv_out(h) return h