jax.tree_util.tree_structure
- jax.tree_util.tree_structure(tree, is_leaf=None)
Alias of jax.tree.structure().
- Parameters:
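tree (Any) – the pytree whose structure is extracted.
is_leaf (None | Callable[[Any], bool]) – an optional predicate called at each flattening step; if it returns True, the current object is treated as a leaf instead of being traversed further.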
- Return type:
PyTreeDef
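A minimal usage sketch, assuming a JAX version that provides the `jax.tree` namespace (where `jax.tree.structure()` lives). The nested `tree` value is a hypothetical input chosen for illustration. It shows that the returned PyTreeDef records the container layout but not the leaf values, that `None` is an empty subtree rather than a leaf, and how `is_leaf` can stop traversal early.

```python
import jax

# Hypothetical example input: a dict holding a tuple and a list.
tree = {"a": (1, 2), "b": [3.0, None]}

# Default flattening: 1, 2 and 3.0 are leaves; None is an empty node.
treedef = jax.tree_util.tree_structure(tree)
print(treedef)
print(treedef.num_leaves)

# Treat any tuple as a single leaf, so (1, 2) is not traversed.
tuple_as_leaf = jax.tree_util.tree_structure(
    tree, is_leaf=lambda x: isinstance(x, tuple)
)
print(tuple_as_leaf.num_leaves)

# The alias and the jax.tree function produce equal tree definitions.
print(jax.tree.structure(tree) == treedef)
```

Because the PyTreeDef holds only structure, it can be paired with a new list of leaves (for example via `jax.tree_util.tree_unflatten`) to rebuild a tree of the same shape.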