Source code for diffuse.neural_network.blocks.downsample
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
import jax.numpy as jnp
from flax import nnx
from jax import Array
from jax.typing import ArrayLike, DTypeLike
[docs]
class Downsample(nnx.Module):
def __init__(
self,
in_channels: int,
param_dtype: DTypeLike = jnp.float32,
dtype: DTypeLike = jnp.float32,
rngs: nnx.Rngs = None,
):
self.conv = nnx.Conv(
in_features=in_channels,
out_features=in_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=(0, 0),
param_dtype=param_dtype,
dtype=dtype,
rngs=rngs,
)
def __call__(self, x: ArrayLike) -> Array:
pad_width = ((0, 0), (0, 1), (0, 1), (0, 0))
x = jnp.pad(array=x, pad_width=pad_width, mode="constant", constant_values=0)
return self.conv(x)