jax.tree_util module

Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like data structures, such as nested tuples, lists, and dicts. We call these structures pytrees. They are trees in that they are defined recursively (any object that is not a registered pytree node type is itself a pytree, namely a leaf, and any pytree node whose children are pytrees is a pytree) and can be operated on recursively (mapping operations build new containers, so object identity is not preserved, and the structures cannot contain reference cycles).
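As a brief illustration, here is a sketch that uses only tree_leaves, tree_structure, and tree_map from this module. The nested dict below has dict, list, and tuple internal nodes and four leaves. Mapping over it builds new containers rather than returning the original objects.

```python
import jax

# Dicts, lists, and tuples are internal nodes; numbers and strings are leaves.
tree = {"weights": [1.0, 2.0], "config": ("relu", 3)}

# Dict entries are flattened in sorted key order, so "config" comes first.
leaves = jax.tree_util.tree_leaves(tree)
treedef = jax.tree_util.tree_structure(tree)
print(leaves)              # ['relu', 3, 1.0, 2.0]
print(treedef.num_leaves)  # 4

# Even an identity map returns a new, equal container, not the same object.
mapped = jax.tree_util.tree_map(lambda x: x, tree)
print(mapped == tree, mapped is tree)  # True False
```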

The set of Python types that are considered pytree nodes (i.e. that can be mapped over, rather than treated as leaves) is extensible. There is a single module-level registry of types, and class hierarchy is ignored. Registering a new pytree node type makes that type, in effect, transparent to the utility functions in this module.
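Because class hierarchy is ignored, a subclass does not inherit the pytree behavior of its parent. In the sketch below, the class name MyList is hypothetical. A plain list is flattened into its elements, while an instance of an unregistered list subclass comes back as a single leaf.

```python
import jax

class MyList(list):
    """Hypothetical list subclass that is never registered as a pytree node."""

# A built-in list is a registered node, so its elements are the leaves.
print(jax.tree_util.tree_leaves([1, 2, 3]))  # [1, 2, 3]

# The subclass is not treated as a list node, so the whole object is one leaf.
sub_leaves = jax.tree_util.tree_leaves(MyList([1, 2, 3]))
print(len(sub_leaves), type(sub_leaves[0]).__name__)  # 1 MyList
```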

The primary purpose of this module is to enable interoperability between user-defined data structures and JAX transformations (e.g. jit). It is not meant to be a general-purpose library for handling tree-like data structures.
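For example, a jax.jit-compiled function can take a nested dict of arrays directly, because JAX flattens pytree arguments into their leaves before tracing. This is a minimal sketch; the function name predict and the parameter names w and b are chosen only for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def predict(params, x):
    # params is an ordinary dict; jit flattens it into its array leaves.
    return params["w"] * x + params["b"]

params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
print(float(predict(params, jnp.array(3.0))))  # 7.0

# The same structure can be updated leaf by leaf with tree_map.
halved = jax.tree_util.tree_map(lambda p: p / 2, params)
```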

See the JAX pytrees note for examples.

List of Functions

Partial(func, *args, **kw)

A version of functools.partial that works in pytrees.

all_leaves(iterable[, is_leaf])

Tests whether all elements in the given iterable are leaves.

register_dataclass(nodetype[, data_fields, ...])

Extends the set of types that are considered internal nodes in pytrees, by registering a dataclass whose data fields are children and whose metadata fields are static.

register_pytree_node(nodetype, flatten_func, ...)

Extends the set of types that are considered internal nodes in pytrees, by registering a type with a flatten function and an unflatten function.

register_pytree_node_class(cls)

Extends the set of types that are considered internal nodes in pytrees, as a class decorator for classes that define tree_flatten and tree_unflatten.

register_pytree_with_keys(nodetype, ...[, ...])

Extends the set of types that are considered internal nodes in pytrees, by registering a type with a flatten function that also returns a key for each child.

register_pytree_with_keys_class(cls)

Extends the set of types that are considered internal nodes in pytrees, as a class decorator for classes that define tree_flatten_with_keys and tree_unflatten.

register_static(cls)

Registers cls as a pytree with no leaves.

tree_flatten_with_path(tree[, is_leaf, ...])

Alias of jax.tree.flatten_with_path().

tree_leaves_with_path(tree[, is_leaf, ...])

Alias of jax.tree.leaves_with_path().

tree_map_with_path(f, tree, *rest[, ...])

Alias of jax.tree.map_with_path().

treedef_children(treedef)

Return a list of treedefs for the immediate children.

treedef_is_leaf(treedef)

Return True if the treedef represents a leaf.

treedef_tuple(treedefs)

Makes a tuple treedef from an iterable of child treedefs.

KeyEntry

Type variable for a single entry of a key path, such as a SequenceKey, DictKey, or GetAttrKey.

KeyPath

Alias for tuple[KeyEntry, ...]: the sequence of key entries leading from the root of a pytree to one of its nodes or leaves.

keystr(keys, *[, simple, separator])

Helper to pretty-print a tuple of keys.
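Two of the smaller utilities from the table above can be sketched together. Partial wraps a function so that its bound arguments are pytree leaves while the function itself is kept as static data, which lets it be passed into a jax.jit-compiled function. all_leaves checks whether every element of an iterable is a leaf. The helper apply is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, all_leaves

def apply(f, x):
    # Hypothetical helper: call whatever function-like pytree it receives.
    return f(x)

# The bound argument 1.0 is a leaf; jnp.add is stored as static data.
add_one = Partial(jnp.add, 1.0)
print(jax.tree_util.tree_leaves(add_one))   # [1.0]
print(float(jax.jit(apply)(add_one, 2.0)))  # 3.0

# all_leaves is True only if no element is itself a pytree node.
print(all_leaves([1, 2.0, "a"]))  # True
print(all_leaves([1, (2, 3)]))    # False
```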
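Here is a sketch of the two basic registration paths, using the hypothetical classes Point and Pair. With register_pytree_node, the flatten function returns a pair (children, aux_data), and the unflatten function rebuilds the object from (aux_data, children). register_pytree_node_class does the same through tree_flatten and tree_unflatten methods defined on the class. The aux_data is static and should be hashable.

```python
import jax

class Point:
    """Hypothetical 2-D point with a static label."""
    def __init__(self, x, y, label):
        self.x, self.y, self.label = x, y, label

def flatten_point(p):
    # (children, aux_data): x and y are children; the label is static.
    return (p.x, p.y), p.label

def unflatten_point(label, children):
    x, y = children
    return Point(x, y, label)

jax.tree_util.register_pytree_node(Point, flatten_point, unflatten_point)

@jax.tree_util.register_pytree_node_class
class Pair:
    """Hypothetical container registered through the class decorator."""
    def __init__(self, first, second):
        self.first, self.second = first, second

    def tree_flatten(self):
        return (self.first, self.second), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

p = jax.tree_util.tree_map(lambda v: v + 1, Point(1, 2, "origin"))
print(p.x, p.y, p.label)  # 2 3 origin

# Children may themselves be pytrees: the list inside Pair is flattened too.
leaves, treedef = jax.tree_util.tree_flatten(Pair(1, [2, 3]))
print(leaves)  # [1, 2, 3]
rebuilt = jax.tree_util.tree_unflatten(treedef, [10, 20, 30])
print(rebuilt.first, rebuilt.second)  # 10 [20, 30]
```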
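register_dataclass and register_static cover two common special cases. In this sketch, the class names are hypothetical. The dataclass's w and b fields are data fields and become leaves, and its name field is a metadata field carried as static data. The register_static class contributes no leaves at all. For clarity, data_fields and meta_fields are passed explicitly here.

```python
import dataclasses
import jax

@dataclasses.dataclass
class Layer:
    """Hypothetical layer: w and b are numeric data, name is metadata."""
    w: float
    b: float
    name: str

jax.tree_util.register_dataclass(
    Layer, data_fields=["w", "b"], meta_fields=["name"]
)

@jax.tree_util.register_static
class Activation:
    """Hypothetical static config object; it has no leaves."""
    def __init__(self, kind):
        self.kind = kind

model = {"layer": Layer(w=1.0, b=2.0, name="dense"), "act": Activation("relu")}

# Only the data fields are leaves; "name" and the Activation are static.
print(jax.tree_util.tree_leaves(model))  # [1.0, 2.0]

doubled = jax.tree_util.tree_map(lambda x: x * 2, model)
print(doubled["layer"])  # Layer(w=2.0, b=4.0, name='dense')
```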
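A custom type registered with register_pytree_with_keys or register_pytree_with_keys_class supplies its own key entries, which then appear in every *_with_path result. This sketch uses the GetAttrKey entry type from jax.tree_util with the hypothetical classes Box (function form) and Interval (decorator form).

```python
import jax
from jax.tree_util import GetAttrKey, keystr

class Box:
    """Hypothetical box with lower and upper corners."""
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

def box_flatten_with_keys(box):
    # Return ((key, child), ...) pairs plus static aux data (none here).
    return ((GetAttrKey("lo"), box.lo), (GetAttrKey("hi"), box.hi)), None

def box_unflatten(aux_data, children):
    return Box(*children)

jax.tree_util.register_pytree_with_keys(Box, box_flatten_with_keys, box_unflatten)

@jax.tree_util.register_pytree_with_keys_class
class Interval:
    """Hypothetical interval registered through the class decorator."""
    def __init__(self, start, end):
        self.start, self.end = start, end

    def tree_flatten_with_keys(self):
        return ((GetAttrKey("start"), self.start), (GetAttrKey("end"), self.end)), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

tree = [Box(0.0, 1.0), Interval(2, 5)]
paths = [keystr(p) for p, _ in jax.tree_util.tree_leaves_with_path(tree)]
print(paths)  # ['[0].lo', '[0].hi', '[1].start', '[1].end']
```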
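The treedef helpers operate on structures rather than values. In this sketch, the treedef of (1, [2, 3]) has two immediate children: a leaf and a two-element list. treedef_tuple reassembles those children into an equal tuple treedef.

```python
import jax
from jax.tree_util import treedef_children, treedef_is_leaf, treedef_tuple

treedef = jax.tree_util.tree_structure((1, [2, 3]))

children = treedef_children(treedef)
print(len(children))                 # 2
print(treedef_is_leaf(children[0]))  # True
print(treedef_is_leaf(children[1]))  # False

# Rebuilding a tuple from the child treedefs gives back the original structure.
print(treedef_tuple(children) == treedef)  # True
```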

Legacy APIs

These APIs are now accessed via jax.tree.
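Since each legacy name is an alias, the old and new spellings are interchangeable. In this small sketch, tree_map and jax.tree.map give the same result on the same input.

```python
import jax

tree = {"x": 1, "y": (2, 3)}

legacy = jax.tree_util.tree_map(lambda v: v * 10, tree)
current = jax.tree.map(lambda v: v * 10, tree)
print(legacy == current)  # True
print(current)            # {'x': 10, 'y': (20, 30)}
```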

tree_all(tree, *[, is_leaf])

Alias of jax.tree.all().

tree_broadcast(prefix_tree, full_tree[, is_leaf])

Alias of jax.tree.broadcast().

tree_flatten(tree[, is_leaf])

Alias of jax.tree.flatten().

tree_leaves(tree[, is_leaf])

Alias of jax.tree.leaves().

tree_map(f, tree, *rest[, is_leaf])

Alias of jax.tree.map().

tree_reduce(function, tree[, initializer, ...])

Alias of jax.tree.reduce().

tree_reduce_associative(operation, tree, *)

Alias of jax.tree.reduce_associative().

tree_structure(tree[, is_leaf])

Alias of jax.tree.structure().

tree_transpose(outer_treedef, inner_treedef, ...)

Alias of jax.tree.transpose().

tree_unflatten(treedef, leaves)

Alias of jax.tree.unflatten().
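Here is a sketch of a round trip through tree_flatten and tree_unflatten. Flattening returns the leaves together with a treedef. Unflattening that treedef with a new list of leaves (here, each leaf squared) rebuilds the same structure.

```python
import jax

tree = {"a": 1, "b": (2, [3, 4])}

leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1, 2, 3, 4]

# tree_structure returns the same treedef without the leaves.
print(jax.tree_util.tree_structure(tree) == treedef)  # True

squared = jax.tree_util.tree_unflatten(treedef, [x * x for x in leaves])
print(squared)  # {'a': 1, 'b': (4, [9, 16])}
```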
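tree_reduce and tree_transpose can be sketched using operator.add from the standard library as the reducing function. For tree_transpose, the outer and inner treedefs here are built from placeholder values with tree_structure.

```python
import operator
import jax

tree = {"a": 1, "b": (2, 3)}

# tree_reduce folds the leaves with a binary function, optionally from an initializer.
print(jax.tree_util.tree_reduce(operator.add, tree))      # 6
print(jax.tree_util.tree_reduce(operator.add, tree, 10))  # 16

# tree_transpose swaps an outer list-of-2 structure with an inner tuple-of-2 structure.
outer = jax.tree_util.tree_structure([0, 0])
inner = jax.tree_util.tree_structure((0, 0))
print(jax.tree_util.tree_transpose(outer, inner, [(1, 2), (3, 4)]))  # ([1, 3], [2, 4])
```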