# Source code for diffuse.utils.mapping
# Copyright 2025 Jacopo Iollo <jacopo.iollo@inria.fr>, Geoffroy Oudoumanessah <geoffroy.oudoumanessah@inria.fr>
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
import jax
from jaxtyping import PyTree
from functools import partial
from typing import TypeVar
# Type variable bounded by PyTree; used in pmapper's signature so that the
# annotated return type is the same as the annotated type of its input ``x``.
T = TypeVar("T", bound=PyTree)
[docs]
def make_in_axes_except(x: PyTree, except_path: str) -> PyTree:
    """
    Creates an in_axes PyTree where all leaves are 0 except for the specified path which gets None.

    A leaf is excluded when ``except_path`` is exactly equal to one component of
    the leaf's key path (attribute name, dict key or sequence index, compared as
    strings), or to the full dot-joined key path (e.g. ``"inner.step"``).
    Matching is done on whole components, so ``"step"`` does not exclude a
    sibling field called ``"timestep"``.

    Args:
        x: The PyTree to create in_axes for
        except_path: The path/field name to exclude (will get None instead of 0)

    Returns:
        A PyTree with the same structure as x but with 0s, and None at every
        leaf whose key path matches ``except_path``

    Example:
        class State(NamedTuple):
            position: Array
            step: int

        state = State(position=jnp.array([1,2,3]), step=0)
        in_axes = make_in_axes_except(state, "step")
        # Returns: State(position=0, step=None)
    """

    def _component(entry) -> str:
        # JAX key-path entries carry their payload under different attribute
        # names: GetAttrKey -> .name, DictKey / FlattenedIndexKey -> .key,
        # SequenceKey -> .idx. Fall back to str() for custom key types.
        for attr in ("name", "key", "idx"):
            if hasattr(entry, attr):
                return str(getattr(entry, attr))
        return str(entry)

    def _set_axes(path, _):
        components = [_component(entry) for entry in path]
        # Exact comparison on components: a substring test on str(path) would
        # also hit the key-class repr itself (e.g. "name", "key", "Key") and
        # any field whose name merely contains except_path.
        if except_path in components or except_path == ".".join(components):
            return None
        return 0

    return jax.tree_util.tree_map_with_path(_set_axes, x)
[docs]
def pmap_reshaping(x: PyTree) -> PyTree:
    """
    Splits the leading axis of every array leaf into ``(num_devices, -1)``.

    Each leaf of shape ``(n, *rest)`` is reshaped to
    ``(num_devices, n // num_devices, *rest)``; leaves with an empty shape
    (0-dimensional arrays) are returned unchanged.

    Args:
        x: PyTree whose leaves expose ``.shape`` and ``.reshape``. The leading
            axis of each non-scalar leaf must be divisible by the device count,
            otherwise the underlying ``reshape`` call raises.

    Returns:
        A PyTree with the same structure as x and the reshaped leaves
    """
    # jax.pmap maps over the devices attached to the current process, so the
    # leading axis has to be sized by the local device count. On a
    # single-process setup this equals jax.device_count().
    num_devices = jax.local_device_count()

    def _split_leading_axis(leaf):
        if len(leaf.shape) == 0:
            return leaf
        return leaf.reshape((num_devices, -1, *leaf.shape[1:]))

    return jax.tree_util.tree_map(_split_leading_axis, x)
[docs]
def pmap_unshaping(x: PyTree):
    """
    Merges the two leading axes of every array leaf into a single axis.

    Each leaf of shape ``(a, b, *rest)`` is reshaped to ``(a * b, *rest)``;
    leaves with an empty shape (0-dimensional arrays) are returned unchanged.

    Args:
        x: PyTree whose leaves expose ``.shape`` and ``.reshape``

    Returns:
        A PyTree with the same structure as x and the flattened leaves
    """

    def _merge_leading_axes(leaf):
        # Nothing to merge on a 0-dimensional leaf.
        if len(leaf.shape) == 0:
            return leaf
        trailing = leaf.shape[2:]
        return leaf.reshape((-1, *trailing))

    return jax.tree_util.tree_map(_merge_leading_axes, x)
[docs]
def pmapper(fn, x: T, batch_size: int = None, **kwargs) -> T:
    """
    Applies ``fn`` over the leading axis of ``x``, spread across devices with ``jax.pmap``.

    The input is first reshaped with ``pmap_reshaping`` so that each leaf has a
    leading device axis. On every device the shard is then processed with
    ``jax.lax.map`` and the per-device results are merged back into a single
    leading axis with ``pmap_unshaping``.

    Args:
        fn: Function applied to each element along the leading axis
        x: PyTree of inputs; every leaf is mapped over its axis 0
        batch_size: Forwarded unchanged to ``jax.lax.map`` as ``batch_size``
        **kwargs: Extra keyword arguments bound to ``fn`` via ``functools.partial``

    Returns:
        The outputs of ``fn`` with the device axis and the per-device axis
        merged back into one leading axis
    """
    bound_fn = partial(fn, **kwargs)

    def _per_device(shard):
        # Sequential (optionally batched) map over this device's shard.
        return jax.lax.map(f=bound_fn, xs=shard, batch_size=batch_size)

    # Every leaf of the input is mapped over its leading axis; jax.pmap
    # expects one in_axes entry per positional argument, hence the 1-tuple.
    leaf_axes = jax.tree_util.tree_map(lambda _: 0, x)
    run_on_devices = jax.pmap(_per_device, axis_name="devices", in_axes=(leaf_axes,))

    sharded_inputs = pmap_reshaping(x)
    sharded_outputs = run_on_devices(sharded_inputs)
    return pmap_unshaping(sharded_outputs)