jax.tree.flatten_with_path
- jax.tree.flatten_with_path(tree, is_leaf=None, is_leaf_takes_path=False)
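Flattens a pytree like
tree_flatten, but also returns each leaf’s key path.
- Parameters:
tree – the pytree to flatten. Custom node types should be registered with jax.tree_util.register_pytree_with_keys so that their children get descriptive keys.
is_leaf – an optional function called at each flattening step. If it returns True for a node, that node is treated as a leaf and its subtree is not traversed.
is_leaf_takes_path – if True, is_leaf receives each node’s key path in addition to the node itself.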
- Returns:
A pair whose first element is a list of (key path, leaf) pairs, one per leaf, in the same order as the leaves returned by tree_flatten. The second element is a treedef representing the structure of the flattened tree.
- Return type:
tuple[list[tuple[tree_util.KeyPath, Any]], tree_util.PyTreeDef]
Examples
>>> import jax
>>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
>>> path_vals
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
>>> treedef
PyTreeDef([*, {'x': *}])