jax.tree_util.register_pytree_node_class
jax.tree_util.register_pytree_node_class(cls)
Extends the set of types that are considered internal nodes in pytrees.
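To make "internal node" concrete, the sketch below compares an unregistered class, which JAX treats as one opaque leaf, with a registered one whose children become leaves. The class names `Plain` and `Node` are hypothetical and exist only for illustration.

```python
import jax


class Plain:
    """Unregistered class: JAX treats each instance as a single leaf."""
    def __init__(self, a, b):
        self.a = a
        self.b = b


@jax.tree_util.register_pytree_node_class
class Node:
    """Registered class: JAX looks inside it for child leaves."""
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # Return (children, aux_data); here there is no static aux_data.
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild an instance from the children produced by tree_flatten.
        return cls(*children)


plain_leaves = jax.tree_util.tree_leaves(Plain(1, 2))  # one leaf: the object itself
node_leaves = jax.tree_util.tree_leaves(Node(1, 2))    # two leaves: 1 and 2
```
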
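This function is a thin wrapper around
register_pytree_node(), and provides a class-oriented interface: the class must define a tree_flatten() method that returns a (children, aux_data) pair, and a tree_unflatten(aux_data, children) classmethod that rebuilds an instance from those two values.
- Parameters:
cls (Typ) – a type to register as a pytree
- Returns:
The input class cls is returned unchanged after being added to JAX’s pytree registry. This return value allows register_pytree_node_class to be used as a decorator.
- Return type:
Typ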
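Because the class is returned unchanged, the function can also be called directly instead of used as a decorator. The minimal sketch below, built around a hypothetical `Pair` class, checks the identity of the return value and round-trips an instance through `jax.tree_util.tree_flatten` and `tree_unflatten`. The comment describing the equivalent `register_pytree_node` call is an approximation of what the wrapper does, not a quote of its source.

```python
import jax


class Pair:
    """Hypothetical two-field container."""
    def __init__(self, first, second):
        self.first = first
        self.second = second

    def tree_flatten(self):
        # (children, aux_data): both fields are children, no static data.
        return (self.first, self.second), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Non-decorator form. This is roughly equivalent to
#   jax.tree_util.register_pytree_node(
#       Pair, lambda p: p.tree_flatten(), Pair.tree_unflatten)
registered = jax.tree_util.register_pytree_node_class(Pair)
assert registered is Pair  # the same class object comes back

# Flatten into leaves plus a treedef, then rebuild with new leaves.
leaves, treedef = jax.tree_util.tree_flatten(Pair(1.0, 2.0))
rebuilt = jax.tree_util.tree_unflatten(treedef, [10.0, 20.0])
```
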
See also
register_static(): simpler API for registering a static pytree.
register_dataclass(): simpler API for registering a dataclass.
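For comparison, here is a sketch of the two simpler APIs, assuming a JAX version that provides both `jax.tree_util.register_static` and `jax.tree_util.register_dataclass` (the fields are passed by keyword to stay compatible across versions). `Config` and `Point` are hypothetical classes.

```python
import dataclasses

import jax


@jax.tree_util.register_static
class Config:
    """Static pytree: it has no leaves, so JAX treats it as pure metadata."""
    def __init__(self, mode):
        self.mode = mode


@dataclasses.dataclass
class Point:
    x: float
    y: float
    label: str  # metadata, not a traced value


# data_fields become leaves; meta_fields are stored as static data in the treedef.
jax.tree_util.register_dataclass(
    Point, data_fields=["x", "y"], meta_fields=["label"]
)

config_leaves = jax.tree_util.tree_leaves(Config("fast"))
point_leaves = jax.tree_util.tree_leaves(Point(1.0, 2.0, "p"))
```

Neither API requires writing tree_flatten or tree_unflatten by hand, which is the trade-off against register_pytree_node_class: less control over how children and auxiliary data are split, in exchange for less boilerplate.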
Examples
Here we’ll define a custom container that will be compatible with
jax.jit() and other JAX transformations:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.tree_util.register_pytree_node_class
... class MyContainer:
...   def __init__(self, x, y):
...     self.x = x
...     self.y = y
...   def tree_flatten(self):
...     # children are (x, y); there is no static aux_data
...     return ((self.x, self.y), None)
...   @classmethod
...   def tree_unflatten(cls, aux_data, children):
...     return cls(*children)
...
>>> m = MyContainer(jnp.zeros(4), jnp.arange(4))
>>> def f(m):
...   return m.x + 2 * m.y
>>> jax.jit(f)(m)
Array([0., 2., 4., 6.], dtype=float32)
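The example above passes None as aux_data. The sketch below, using a hypothetical `Scaled` class, shows what aux_data is for: values placed there are static, so they reach the jitted function as plain Python objects rather than tracers, and transformations such as `jax.tree_util.tree_map` and `jax.grad` carry them through unchanged while acting only on the children. This assumes aux_data is hashable, which jit needs in order to cache compiled functions.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Scaled:
    """Hypothetical container: `value` is traced, `scale` is static metadata."""
    def __init__(self, value, scale):
        self.value = value
        self.scale = scale

    def tree_flatten(self):
        children = (self.value,)
        aux_data = self.scale  # hashable; a new value triggers a retrace under jit
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


@jax.jit
def apply(s):
    # s.scale is a plain Python float here, not a tracer.
    return s.value * s.scale


s = Scaled(jnp.arange(3.0), 2.0)
out = apply(s)

# tree_map visits only the children; aux_data is passed through unchanged.
doubled = jax.tree_util.tree_map(lambda v: v * 2, s)

# grad returns a Scaled whose child holds the gradient with respect to `value`.
g = jax.grad(lambda t: jnp.sum(t.value * t.scale))(s)
```
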