Utilities
Logger
Mapping
- diffuse.utils.mapping.make_in_axes_except(x: PyTree, except_path: str) → PyTree
Creates an in_axes PyTree (as passed to jax.vmap) in which every leaf is 0 except the specified path, which gets None.
- Parameters:
x – The PyTree to create in_axes for.
except_path – The path or field name to exclude; it gets None instead of 0.
- Returns:
A PyTree with the same structure as x, in which every leaf is 0 except the excluded path, which is None.
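The documented behaviour can be approximated with jax.tree_util.tree_map_with_path. This is a minimal sketch, not the diffuse implementation: make_in_axes_except_sketch is a hypothetical name, and it assumes except_path is a dot-separated path of field names or dict keys, with every leaf at or below that path mapped to None.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import DictKey, GetAttrKey, SequenceKey


def _entry_name(entry: Any) -> str:
    # Turn one key-path entry into a plain string name.
    if isinstance(entry, GetAttrKey):  # NamedTuple / dataclass fields
        return entry.name
    if isinstance(entry, DictKey):  # dict keys
        return str(entry.key)
    if isinstance(entry, SequenceKey):  # list / tuple positions
        return str(entry.idx)
    return str(entry)


def make_in_axes_except_sketch(x: Any, except_path: str) -> Any:
    """Hypothetical re-implementation: 0 for every leaf, None at or under except_path."""

    def axis_for(path, _leaf):
        # Assumption: paths are compared as dot-separated names, e.g. "a.b".
        dotted = ".".join(_entry_name(e) for e in path)
        # Exclude the leaf itself or anything nested below it.
        if dotted == except_path or dotted.startswith(except_path + "."):
            return None
        return 0

    return jax.tree_util.tree_map_with_path(axis_for, x)


class State(NamedTuple):
    position: jax.Array
    step: int


state = State(position=jnp.array([1, 2, 3]), step=0)
in_axes = make_in_axes_except_sketch(state, "step")  # State(position=0, step=None)
```

Mapping a whole subtree to None (rather than a single leaf) is a design choice of this sketch; the original documentation only describes excluding a single path or field.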
Example
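    from typing import NamedTuple

    import jax.numpy as jnp
    from jax import Array
    from diffuse.utils.mapping import make_in_axes_except

    class State(NamedTuple):
        position: Array
        step: int

    state = State(position=jnp.array([1, 2, 3]), step=0)
    in_axes = make_in_axes_except(state, "step")  # Returns: State(position=0, step=None)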
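To illustrate why such an in_axes value is useful, it can be passed to jax.vmap to batch over position while sharing a single step across the batch. The snippet below writes in_axes out by hand (the value the Example above says make_in_axes_except returns) so it runs without diffuse installed; the advance function is hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    step: int


def advance(state: State) -> jax.Array:
    # Hypothetical per-sample update: shift one position vector by the shared step.
    return state.position + state.step


# Three 2-D positions stacked along axis 0, sharing a single scalar step.
batched = State(position=jnp.arange(6.0).reshape(3, 2), step=1)

# What make_in_axes_except(batched, "step") is documented to return.
in_axes = State(position=0, step=None)

# The top-level in_axes must be an int, None, or a plain tuple with one entry
# per positional argument, so the NamedTuple is wrapped in a 1-tuple.
out = jax.vmap(advance, in_axes=(in_axes,))(batched)
print(out.shape)  # (3, 2)
```

Because step is mapped to None, vmap passes the same scalar to every sample instead of trying to slice it along axis 0.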