jax.tree.reduce_associative

jax.tree.reduce_associative(operation, tree, *, identity=<jax._src.tree_util.Unspecified object>, is_leaf=None)

Perform a reduction over a pytree with an associative binary operation.

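Because the operation is associative, the leaves can be combined pairwise in a balanced binary tree instead of strictly left to right. The combinations at each level are independent of one another, so the reduction has depth logarithmic in the number of leaves rather than the linear depth of a sequential fold.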
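To make the logarithmic-depth claim concrete, here is a minimal pure-Python sketch of a balanced pairwise reduction over an already-flattened list of leaves. The helper `pairwise_reduce` is hypothetical and is not JAX's implementation; it assumes the leaves were obtained by flattening the pytree beforehand (for example with `jax.tree.leaves`), and it also reports how many levels the reduction took.

```python
import operator
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")


def pairwise_reduce(operation: Callable[[T, T], T], leaves: List[T]) -> Tuple[T, int]:
    """Hypothetical helper: reduce `leaves` level by level, combining adjacent pairs.

    Returns the reduced value and the number of levels (the depth) it took.
    """
    if not leaves:
        # Nothing to combine; a real API would need an identity element here.
        raise ValueError("cannot reduce an empty sequence without an identity")
    level = list(leaves)
    depth = 0
    while len(level) > 1:
        # Combine neighbours (0, 1), (2, 3), ...; these calls do not depend on each other.
        next_level = [operation(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2 == 1:
            # An odd element out is carried over unchanged, keeping left-to-right order.
            next_level.append(level[-1])
        level = next_level
        depth += 1
    return level[0], depth


# Six leaves, as in the doctest below: 3 levels instead of 5 sequential additions.
print(pairwise_reduce(operator.add, [1, 2, 3, 4, 5, 6]))  # (21, 3)
```

Because neighbours are always combined in their original order, the sketch only relies on associativity; a non-commutative operation such as string concatenation still produces the left-to-right result.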

Parameters:
  • operation (Callable[[T, T], T]) – the associative binary operation

  • tree (Any) – the pytree to reduce

  • identity (T | tree_util.Unspecified) – the identity element of the associative binary operation. It is used only when the tree has no leaves, and may be omitted otherwise.

  • is_leaf (Callable[[Any], bool] | None) – an optional function called at each flattening step. If it returns True for an object, flattening stops there and the whole subtree is treated as a single leaf; if it returns False, flattening continues into the object.

Returns:

the reduced value

Return type:

T

Examples

>>> import jax
>>> import operator
>>> jax.tree.reduce_associative(operator.add, [1, (2, 3), [4, 5, 6]])
21