jax.tree.map_with_path
- jax.tree.map_with_path(f, tree, *rest, is_leaf=None, is_leaf_takes_path=False)
Maps a multi-input function over pytree key paths and args to produce a new pytree.
This is a more powerful alternative to tree_map that also takes the key path of each leaf as an input argument.
- Parameters:
f (Callable[..., Any]) – function that takes 2 + len(rest) arguments, namely the key path followed by the corresponding leaf of each pytree.
tree (Any) – a pytree to be mapped over, with each leaf’s key path as the first positional argument and the leaf itself as the second argument to f.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (Callable[..., bool] | None) – an optional predicate; when it returns True for a node, that node is treated as a leaf and is not traversed further.
is_leaf_takes_path (bool) – if True, is_leaf receives the node’s key path in addition to the node itself.
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by f(kp, x, *xs), where kp is the key path of the corresponding leaf in tree, x is that leaf’s value, and xs is the tuple of values at the corresponding nodes in rest.
- Return type:
Any
Examples
>>> import jax
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
[1, 3, 5]