jax.tree_util.register_pytree_node
- jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func, flatten_with_keys_func=None)
- Extends the set of types that are considered internal nodes in pytrees.
- See the Examples section below for usage.
- Parameters:
- nodetype (type[T]) – a Python type to register as a pytree. 
- flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – a function to be used during flattening. It takes a value of type nodetype and returns a pair: (1) an iterable of the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and passed to unflatten_func.
- unflatten_func (Callable[[_AuxData, _Children], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.
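- flatten_with_keys_func (Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None) – optional. Like flatten_func, but returns an iterable of (key, child) pairs instead of bare children, together with the same auxiliary data. The keys are used by path-aware utilities such as jax.tree_util.tree_flatten_with_path.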
 
- Return type:
- None 
- See also
- register_static(): simpler API for registering a static pytree.
- register_dataclass(): simpler API for registering a dataclass.
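As a rough sketch of the simpler register_dataclass() route mentioned above, assuming a hypothetical Params dataclass: fields listed in data_fields become leaves, while fields in meta_fields are stored in the treedef and so play the role of hashable aux_data. The fields are passed explicitly here rather than inferred.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Params:
    """Hypothetical dataclass: `w` and `b` are array leaves, `n` is static metadata."""

    w: jax.Array
    b: jax.Array
    n: int


# data_fields -> children (traced); meta_fields -> treedef (static, must be hashable).
jax.tree_util.register_dataclass(Params, data_fields=["w", "b"], meta_fields=["n"])


@jax.jit
def scale(p: Params) -> Params:
    # `p.n` is static, so it is a plain Python int inside the traced function.
    return Params(w=p.w * p.n, b=p.b * p.n, n=p.n)


params = Params(w=jnp.ones(3), b=jnp.zeros(3), n=2)
out = scale(params)
```

No flatten_func or unflatten_func is written by hand; the field lists carry the same information.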
- Examples
- First we’ll define a custom type:

>>> import jax
>>> import jax.numpy as jnp
>>> class MyContainer:
...     def __init__(self, size):
...         self.x = jnp.zeros(size)
...         self.y = jnp.ones(size)
...         self.size = size

- If we try using this in a JIT-compiled function, we’ll get an error because JAX does not yet know how to handle this type:

>>> m = MyContainer(size=5)
>>> def f(m):
...     return m.x + m.y + jnp.arange(m.size)
>>> jax.jit(f)(m)
Traceback (most recent call last):
  ...
TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

- To make JAX recognize our object, we must register it as a pytree:

>>> def flatten_func(obj):
...     children = (obj.x, obj.y)  # children must contain arrays & pytrees
...     aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...     return (children, aux_data)
...
>>> def unflatten_func(aux_data, children):
...     # Here we avoid `__init__` because it has extra logic we don't require:
...     obj = object.__new__(MyContainer)
...     obj.x, obj.y = children
...     obj.size, = aux_data
...     return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

- With this registration in place, we can use instances of this type in JIT-compiled functions:

>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)
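Registration also makes the type work with the general tree utilities, not only jax.jit. The sketch below repeats the MyContainer definitions from the example (so it runs on its own) and inspects the result with jax.tree_util.tree_flatten, tree_map and tree_structure. It shows that only x and y become leaves, while size, being auxiliary data, is part of the treedef, so containers of different sizes have different tree structures.

```python
import jax
import jax.numpy as jnp


class MyContainer:
    def __init__(self, size):
        self.x = jnp.zeros(size)
        self.y = jnp.ones(size)
        self.size = size


def flatten_func(obj):
    return (obj.x, obj.y), (obj.size,)


def unflatten_func(aux_data, children):
    obj = object.__new__(MyContainer)
    obj.x, obj.y = children
    (obj.size,) = aux_data
    return obj


jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

m = MyContainer(size=5)

# Only the children become leaves; `size` is stored in the treedef as aux data.
leaves, treedef = jax.tree_util.tree_flatten(m)

# tree_map applies the function to every leaf and rebuilds a MyContainer via unflatten_func.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, m)

# Aux data is part of the treedef, so a different size gives a different structure;
# jax.jit keys its cache on input structure and would trace again for it.
treedef_3 = jax.tree_util.tree_structure(MyContainer(size=3))
treedef_5 = jax.tree_util.tree_structure(MyContainer(size=5))
```

This is why aux_data must be hashable and comparable: JAX compares treedefs, including their aux data, to decide whether two inputs share a structure.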
