jax.tree_util.register_dataclass
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())
Extends the set of types that are considered internal nodes in pytrees.
This differs from register_pytree_with_keys_class in that the C++ registries use the optimized C++ dataclass builtin instead of Python flatten and unflatten functions. See Extending pytrees for more information about registering pytrees.
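For comparison, the following sketch registers a hypothetical Point class through the Python-callback route, jax.tree_util.register_pytree_node(), with hand-written flatten and unflatten helpers (Point, _flatten and _unflatten are illustrative names, not part of JAX). It only approximates what register_dataclass() does: data fields become the node's children and metadata fields become its auxiliary data, but here every flatten and unflatten goes through Python callbacks.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical class used only for this illustration.
@dataclasses.dataclass
class Point:
    x: jax.Array  # intended data field (becomes a leaf)
    y: jax.Array  # intended data field (becomes a leaf)
    label: str    # intended metadata field (stored in the treedef)


def _flatten(p):
    # Children are the data fields; auxiliary data holds the metadata fields.
    return (p.x, p.y), (p.label,)


def _unflatten(aux, children):
    (label,) = aux
    x, y = children
    return Point(x=x, y=y, label=label)


# Python-level registration with explicit flatten/unflatten callbacks.
jax.tree_util.register_pytree_node(Point, _flatten, _unflatten)

p = Point(x=jnp.ones(2), y=jnp.zeros(2), label="a")
leaves, treedef = jax.tree.flatten(p)
restored = jax.tree.unflatten(treedef, leaves)
print(len(leaves), restored.label)
```

Writing the callbacks by hand is flexible, but register_dataclass() removes the need to keep them in sync with the class's field list.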
- Parameters:
  - nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.
  - meta_fields (Sequence[str] | None) – metadata field names: these are attributes which will be treated as static when this pytree is passed to jax.jit(). meta_fields is optional only if nodetype is a dataclass, in which case individual fields can be marked static via dataclasses.field(metadata=dict(static=True)) (see examples below). Metadata fields must be static, hashable, immutable objects, as these objects are used to generate JIT cache keys. In particular, metadata fields cannot contain jax.Array or numpy.ndarray objects.
  - data_fields (Sequence[str] | None) – data field names: these are attributes which will be treated as non-static when this pytree is passed to jax.jit(). data_fields is optional only if nodetype is a dataclass, in which case fields are assumed to be data fields unless marked static via dataclasses.field() (see examples below). Data fields must be JAX-compatible objects such as arrays (jax.Array or numpy.ndarray), scalars, or pytrees whose leaves are arrays or scalars. Note that None is a valid data field, as JAX recognizes this as an empty pytree.
  - drop_fields (Sequence[str])
- Returns:
  The input class nodetype is returned unchanged after being added to JAX’s pytree registry, so that register_dataclass() can be used as a decorator.
- Return type:
  Typ
Examples
In JAX v0.4.35 or older, you must specify data_fields and meta_fields in order to use this decorator:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from functools import partial
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...   x: jax.Array
...   y: jax.Array
...   op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
```
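Because register_dataclass() returns the class it was given, the decorator form above is equivalent to calling it as an ordinary function. A minimal sketch with a hypothetical Pair class and explicit field lists:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Pair:
    a: jax.Array
    b: jax.Array
    tag: str


# Called as a plain function rather than a decorator, with explicit field lists.
result = jax.tree_util.register_dataclass(
    Pair, data_fields=["a", "b"], meta_fields=["tag"])

# The same class object comes back, now registered as a pytree node.
leaves = jax.tree.leaves(Pair(a=jnp.ones(2), b=jnp.zeros(2), tag="t"))
print(result is Pair, len(leaves))
```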
Starting in JAX v0.4.36, the data_fields and meta_fields arguments are optional for dataclass() inputs, with fields defaulting to data fields unless marked as static using metadata=dict(static=True) in dataclasses.field():

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass, field
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
...   x: jax.Array  # defaults to non-static data field
...   y: jax.Array  # defaults to non-static data field
...   op: str = field(metadata=dict(static=True))  # marked as static meta field
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
```
Once this class is registered, it can be used with functions in jax.tree and jax.tree_util:

```python
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
```
In particular, this registration allows m to be passed seamlessly through code wrapped in jax.jit() and other JAX transformations, with data_fields being treated as dynamic arguments, and meta_fields being treated as static arguments:

```python
>>> @jax.jit
... def compiled_func(m):
...   if m.op == 'add':
...     return m.x + m.y
...   else:
...     raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
```