jax.lax.associative_scan
- jax.lax.associative_scan(fn, elems, reverse=False, axis=0)
Performs a scan with an associative binary operation, in parallel.
For an introduction to associative scans, see [BLE1990].
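The following is a minimal sketch of one classic parallel strategy, recursive doubling in the style of Hillis and Steele. It shows what "in parallel" can mean here: after about log2(n) rounds, each of which applies fn to many pairs at once, every position holds its prefix. The helper name hillis_steele_scan is hypothetical. The sketch handles only a single array scanned along axis 0, and it is not meant to mirror how jax.lax.associative_scan is implemented internally. It performs O(n log n) applications of fn, whereas work-efficient schemes such as those described in [BLE1990] need only O(n).

```python
import jax.numpy as jnp
from jax import lax


def hillis_steele_scan(fn, xs):
    """Hypothetical helper: inclusive scan of `xs` along axis 0 by recursive doubling.

    Illustrative only; this is not how jax.lax.associative_scan is implemented.
    Requires `fn` to be associative and to act elementwise over axis 0.
    """
    n = xs.shape[0]
    offset = 1
    while offset < n:
        # Every position i >= offset absorbs the partial result `offset`
        # positions earlier; the earlier element is always the left argument,
        # so non-commutative operators are handled correctly.
        combined = fn(xs[:-offset], xs[offset:])
        xs = jnp.concatenate([xs[:offset], combined], axis=0)
        offset *= 2
    return xs


xs = jnp.arange(10)
print(hillis_steele_scan(jnp.add, xs))    # partial sums via recursive doubling
print(lax.associative_scan(jnp.add, xs))  # same values from the library routine
```

Both calls produce the partial sums 0, 1, 3, 6, ..., 45. Keeping the earlier operand on the left is what lets the same sketch handle non-commutative operators such as matrix multiplication.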
- Parameters:
  - fn (Callable) – A Python callable implementing an associative binary operation with signature r = fn(a, b). Function fn must be associative, i.e., it must satisfy the equation fn(a, fn(b, c)) == fn(fn(a, b), c). The inputs and result are (possibly nested Python tree structures of) array(s) matching elems. Each array has a dimension in place of the axis dimension, whose size can differ from that of elems. fn should be applied elementwise over the axis dimension (for example, by using jax.vmap() over the elementwise function). The result r has the same shape (and structure) as the two inputs a and b.
  - elems – A (possibly nested Python tree structure of) array(s), each with an axis dimension of size num_elems.
  - reverse (bool) – A boolean stating whether the scan should be reversed with respect to the axis dimension.
  - axis (int) – An integer identifying the axis over which the scan should occur.
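As an illustration of a tree-structured elems and a custom fn, the sketch below solves the first-order linear recurrence h[t] = a[t] * h[t-1] + b[t] (with h[-1] = 0) by scanning over pairs (a[t], b[t]). Each pair stands for the affine map h -> a * h + b, and composing affine maps is associative, so the combine function qualifies as fn. Because * and + already act elementwise, the combine function handles the extra axis dimension without jax.vmap. The function names are hypothetical; lax.scan appears only as a sequential reference.

```python
import jax.numpy as jnp
from jax import lax


def affine_combine(left, right):
    # Each element (a, b) represents the affine map h -> a * h + b.
    # Applying `left` first and then `right` gives
    # h -> (a_l * a_r) * h + (a_r * b_l + b_r).
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


def linear_recurrence(a, b):
    # Hypothetical helper: h[t] = a[t] * h[t-1] + b[t], h[-1] = 0, computed in parallel.
    # `elems` is the tuple (a, b); the second component of each prefix is h[t].
    _, h = lax.associative_scan(affine_combine, (a, b))
    return h


def linear_recurrence_sequential(a, b):
    # Sequential reference using lax.scan over time steps.
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h

    _, hs = lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return hs


a = jnp.array([0.5, 2.0, 1.0, 0.1])
b = jnp.array([1.0, 1.0, 2.0, 3.0])
print(linear_recurrence(a, b))             # h = [1.0, 3.0, 5.0, 3.5] up to float rounding
print(linear_recurrence_sequential(a, b))  # same values, computed step by step
```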
- Returns:
A (possibly nested Python tree structure of) array(s) of the same shape and structure as elems, in which the k’th element of axis is the result of recursively applying fn to combine the first k elements of elems along axis. For example, given elems = [a, b, c, ...], the result would be [a, fn(a, b), fn(fn(a, b), c), ...]. If elems = [..., x, y, z] and reverse is true, the result is [..., fn(fn(z, y), x), fn(z, y), z].
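The forward definition above can be checked against a plain sequential loop. The sketch below uses a hypothetical forward-fill operator, fn(a, b) = jnp.where(b != 0, b, a), which keeps the most recent non-zero value. It is associative but not commutative, so agreement with the loop also confirms that the earlier element is passed as the first argument. The names forward_fill_op and sequential_scan are illustrative and not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def forward_fill_op(a, b):
    # Keep `b` where it is non-zero, otherwise carry `a` forward.
    # Associative: fn(a, fn(b, c)) and fn(fn(a, b), c) both return c if c != 0,
    # else b if b != 0, else a. Not commutative: swapping a and b changes the result.
    return jnp.where(b != 0, b, a)


def sequential_scan(fn, xs):
    # Direct transcription of the definition: [x0, fn(x0, x1), fn(fn(x0, x1), x2), ...].
    out = [xs[0]]
    for x in xs[1:]:
        out.append(fn(out[-1], x))
    return jnp.stack(out)


xs = jnp.array([0, 3, 0, 0, 5, 0])
print(lax.associative_scan(forward_fill_op, xs))  # values 0, 3, 3, 3, 5, 5
print(sequential_scan(forward_fill_op, xs))       # same values
```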
Example 1: partial sums of an array of numbers:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4))
Array([0, 1, 3, 6], dtype=int32)
Example 2: partial products of an array of matrices:
>>> import jax
>>> mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
>>> partial_prods = lax.associative_scan(jnp.matmul, mats)
>>> partial_prods.shape
(4, 2, 2)
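Because matrix multiplication is not commutative, Example 2 is a useful check that element k of the result is mats[0] @ mats[1] @ ... @ mats[k], in that order. The sketch below compares the partial products against a sequential left-to-right reduction with functools.reduce. A tolerance is needed because the parallel scan groups the products differently, so float32 rounding can differ slightly. The tolerances are a guess aimed at CPU; accelerators that use reduced matmul precision by default may need looser ones.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

mats = jax.random.uniform(jax.random.key(0), (4, 2, 2))
partial_prods = lax.associative_scan(jnp.matmul, mats)

# Reference: multiply the first k + 1 matrices strictly left to right.
expected = jnp.stack(
    [functools.reduce(jnp.matmul, list(mats[: k + 1])) for k in range(mats.shape[0])]
)
print(jnp.allclose(partial_prods, expected, rtol=1e-3, atol=1e-3))
```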
Example 3: reversed partial sums of an array of numbers:
>>> lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
Array([6, 6, 5, 3], dtype=int32)
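None of the examples above exercises the axis parameter. The sketch below scans a 2-D array within each row (axis=1), forward and reversed, and checks both results against jnp.cumsum. The reversed case is written as flip, cumulative sum, flip, which matches the reverse semantics described under Returns for the commutative jnp.add.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# Forward partial sums within each row: [[0, 1, 3], [3, 7, 12]]
row_sums = lax.associative_scan(jnp.add, x, axis=1)

# Reversed partial sums within each row: [[3, 3, 2], [12, 9, 5]]
rev_row_sums = lax.associative_scan(jnp.add, x, reverse=True, axis=1)

print(jnp.array_equal(row_sums, jnp.cumsum(x, axis=1)))
print(jnp.array_equal(rev_row_sums, jnp.flip(jnp.cumsum(jnp.flip(x, axis=1), axis=1), axis=1)))
```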
[BLE1990] Blelloch, Guy E. 1990. “Prefix Sums and Their Applications.” Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.