jax.numpy.extract
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)
Return the elements of an array that satisfy a condition.
JAX implementation of numpy.extract().
- Parameters:
condition (ArrayLike) – array of conditions. Will be converted to boolean and flattened to 1D.
arr (ArrayLike) – array of values to extract. Will be flattened to 1D.
size (int | None) – optional static size for output. Must be specified in order for extract to be compatible with JAX transformations like jit() or vmap().
fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
- Returns:
1D array of extracted entries. If size is specified, the result will have shape (size,) and be right-padded with fill_value. If size is not specified, the output shape will depend on the number of True entries in condition.
- Return type:
Notes
This function does not require strict shape agreement between condition and arr. If condition.size > arr.size, then condition will be truncated, and if arr.size > condition.size, then arr will be truncated.
See also
jax.numpy.compress(): multi-dimensional version of extract.
Examples
Extract values from a 1D array:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)
In the simplest case, this is equivalent to boolean indexing:
>>> x[mask]
Array([2, 4, 6], dtype=int32)
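Because both inputs are flattened to 1D, the same equivalence should hold for multi-dimensional inputs of matching shape. The following is a minimal sketch; the names x2d and mask2d are illustrative and not part of the original examples:

```python
import jax.numpy as jnp

# Hypothetical 2D input; extract flattens both arguments before selecting.
x2d = jnp.array([[1, 2, 3],
                 [4, 5, 6]])
mask2d = x2d > 3

# The result is always 1D, regardless of the input rank.
flat_result = jnp.extract(mask2d, x2d)

# Boolean indexing with a same-shape mask selects the same elements
# in the same (row-major) order.
indexed_result = x2d[mask2d]
```

Both flat_result and indexed_result are 1D arrays holding the selected values in row-major order.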
For use with JAX transformations, you can pass the size argument to specify a static shape for the output, along with an optional fill_value that defaults to zero:
>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)
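As a sketch of what this looks like inside a transformation, the snippet below wraps the call in jax.jit. The function name evens_padded and the sentinel fill_value of -1 are assumptions chosen for illustration; a sentinel that cannot occur in the data makes padded entries easy to tell apart from real ones.

```python
import jax
import jax.numpy as jnp

@jax.jit
def evens_padded(x):
    # x.shape[0] is a static Python int under jit, so it is a valid `size`.
    return jnp.extract(x % 2 == 0, x, size=x.shape[0], fill_value=-1)

x = jnp.array([1, 2, 3, 4, 5, 6])
padded = evens_padded(x)

# The count of valid (non-padded) entries can be computed separately.
n_valid = int(jnp.sum(x % 2 == 0))
valid = padded[:n_valid]
```

Here the output shape is fixed at (6,) no matter how many entries match, which is what allows the function to be traced and compiled.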
Notice that unlike with boolean indexing, extract does not require strict agreement between the sizes of the array and condition, and will effectively truncate both to the minimum size:
>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)
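One way to read the truncation rule is as slicing both inputs to their common length before applying a boolean mask. The sketch below spells that out for the short-mask case; the variable names n, manual, and via_extract are illustrative only, and this is a description of the observable behavior rather than of how the library implements it.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4, 5, 6])
short_mask = jnp.array([False, True])

# Truncate both inputs to the shorter of the two lengths.
n = min(short_mask.size, x.size)

# Manual equivalent: slice, then apply ordinary boolean indexing.
manual = x[:n][short_mask[:n]]

# extract performs the truncation implicitly.
via_extract = jnp.extract(short_mask, x)
```

Only the first two elements of x are considered, so the single True entry selects the value at index 1.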