Utilities

Logger

Mapping

diffuse.utils.mapping.make_in_axes_except(x: PyTree, except_path: str) -> PyTree

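Creates an in_axes PyTree (as used by jax.vmap or jax.pmap) in which every leaf is 0, meaning mapped over its leading axis, except the leaf at except_path, which is None, meaning not mapped and passed unchanged to every call.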

Parameters:
  • x – The PyTree to build in_axes for.

  • except_path – The path or field name to exclude; that leaf gets None instead of 0.

Returns:

A PyTree with the same structure as x, holding 0 at every leaf except None at except_path.
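One way to build the in_axes tree described above is to map over x together with each leaf's key path and return None for the matching leaf. This is a hypothetical sketch, not diffuse's implementation: the name make_in_axes_except_sketch is illustrative, and it assumes except_path is compared against the last key of each leaf's path (a NamedTuple field name or a dict key). The real function may treat nested paths differently.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


def make_in_axes_except_sketch(x: Any, except_path: str) -> Any:
    """Hypothetical stand-in: 0 for every leaf, None where the last path key equals except_path."""

    def leaf_axis(path, _leaf):
        # The last key of a path is GetAttrKey(name=...) for NamedTuple fields,
        # DictKey(key=...) for dict entries and SequenceKey(idx=...) for lists.
        last = path[-1] if path else None
        name = getattr(last, "name", None)
        if name is None:
            name = getattr(last, "key", None)
        # Assumption: match only on the final path component.
        return None if name == except_path else 0

    # tree_map_with_path rebuilds a tree of the same structure as x.
    return jax.tree_util.tree_map_with_path(leaf_axis, x)


class State(NamedTuple):
    position: jax.Array
    step: int


state = State(position=jnp.array([1, 2, 3]), step=0)
print(make_in_axes_except_sketch(state, "step"))  # State(position=0, step=None)
```

Matching only the last key keeps the sketch short; a real implementation might instead accept a dotted path to reach nested fields.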

Example

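from typing import NamedTuple
import jax.numpy as jnp
from jax import Array
from diffuse.utils.mapping import make_in_axes_except

class State(NamedTuple):
    position: Array
    step: int

state = State(position=jnp.array([1, 2, 3]), step=0)
in_axes = make_in_axes_except(state, "step")
# Returns: State(position=0, step=None)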
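The returned tree is meant to be passed as in_axes to a JAX transformation such as jax.vmap: leaves marked 0 are mapped over their leading axis, while the None leaf is broadcast unchanged to every batch element. A minimal sketch, assuming a hypothetical per-state function advance; in_axes is written out by hand as the value the example above says make_in_axes_except returns, so the snippet runs without diffuse installed.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    position: jax.Array
    step: int


def advance(state: State) -> jax.Array:
    # Hypothetical function written for a single, unbatched state: position has shape (3,).
    return state.position * 2 + state.step


# A batch of 4 positions that all share one scalar step.
batched = State(position=jnp.arange(12.0).reshape(4, 3), step=1)

# Map over axis 0 of position; broadcast step to every batch element.
in_axes = State(position=0, step=None)

# in_axes is wrapped in a 1-tuple because advance takes one positional argument.
out = jax.vmap(advance, in_axes=(in_axes,))(batched)
print(out.shape)  # (4, 3)
```

Marking step as None avoids having to tile a shared scalar into an array of shape (4,) just to satisfy vmap.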

diffuse.utils.mapping.pmap_reshaping(x: PyTree) -> PyTree
diffuse.utils.mapping.pmap_unshaping(x: PyTree)
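pmap_reshaping and pmap_unshaping are documented only by their signatures. Judging from the names, they likely split a batch across devices for jax.pmap and merge the result back, but that is an assumption. The sketch below shows that common pattern with hypothetical stand-ins (pmap_reshaping_sketch, pmap_unshaping_sketch); it is not diffuse's implementation, and it assumes every leaf's leading axis is divisible by jax.local_device_count().

```python
import jax
import jax.numpy as jnp


def pmap_reshaping_sketch(x):
    """Hypothetical: split each leaf's leading axis into (n_devices, per_device)."""
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda a: a.reshape((n, -1) + a.shape[1:]), x)


def pmap_unshaping_sketch(x):
    """Hypothetical inverse: merge the (n_devices, per_device) axes back into one."""
    return jax.tree_util.tree_map(lambda a: a.reshape((-1,) + a.shape[2:]), x)


# Batch size 8 must be divisible by the device count (1 on a default CPU setup).
batch = {"position": jnp.ones((8, 3)), "weight": jnp.arange(8.0)}
sharded = pmap_reshaping_sketch(batch)  # leading axis is now the device axis
doubled = jax.pmap(lambda t: jax.tree_util.tree_map(lambda a: a * 2, t))(sharded)
restored = pmap_unshaping_sketch(doubled)
print(restored["position"].shape)  # (8, 3)
```

Reshaping rather than slicing keeps the per-device chunks contiguous, so unshaping is a plain reshape that restores the original batch order.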
diffuse.utils.mapping.pmapper(fn, x: T, batch_size: int = None, **kwargs) -> T