jax.numpy.compress
- jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)
Compress an array along a given axis using a boolean condition.
JAX implementation of numpy.compress().
- Parameters:
condition (ArrayLike) – 1-dimensional array of conditions. Will be converted to boolean.
a (ArrayLike) – N-dimensional array of values.
axis (int | None) – axis along which to compress. If None (default), then a will be flattened, and axis will be set to 0.
size (int | None) – optional static size for output. Must be specified in order for compress to be compatible with JAX transformations like jit() or vmap().
fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
out (None) – not implemented by JAX.
- Returns:
An array of dimension a.ndim, compressed along the specified axis (a 1-D array if axis is None, since a is flattened first).
- Return type:
Array
See also
jax.numpy.extract(): 1D version of compress.
jax.Array.compress(): equivalent functionality as an array method.
Notes
This function does not require strict shape agreement between condition and a. If condition.size > a.shape[axis], then condition will be truncated, and if a.shape[axis] > condition.size, then a will be truncated.
Examples
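The truncation rule from the Notes can be checked with a small sketch. The 1-D input and the two conditions below are illustrative choices and do not come from the original examples. The extra trailing entries of the longer condition are deliberately kept False, so the sketch stays within the behavior the Notes describe.

```python
import jax.numpy as jnp

a = jnp.arange(5)  # [0, 1, 2, 3, 4]

# Condition longer than a.shape[0]: only its first 5 entries are used.
# (The two extra trailing entries are False in this illustrative sketch.)
long_cond = jnp.array([True, False, True, False, True, False, False])
kept_long = jnp.compress(long_cond, a)
print(kept_long.tolist())  # [0, 2, 4]

# Condition shorter than a.shape[0]: a is truncated to its first 2 elements,
# so elements 2, 3 and 4 are never considered.
short_cond = jnp.array([False, True])
kept_short = jnp.compress(short_cond, a)
print(kept_short.tolist())  # [1]
```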
Compressing along the rows of a 2D array:
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> condition = jnp.array([True, False, True])
>>> jnp.compress(condition, a, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)
For convenience, you can equivalently use the compress() method of JAX arrays:
>>> a.compress(condition, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)
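When the condition length does match a.shape[axis], compress selects exactly the same elements as ordinary boolean-mask indexing. The sketch below checks this on both axes of the 2D array used above. The masks `row_mask` and `col_mask` are illustrative names chosen for this example.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12]])
row_mask = jnp.array([True, False, True])         # length == a.shape[0]
col_mask = jnp.array([False, True, True, False])  # length == a.shape[1]

rows = jnp.compress(row_mask, a, axis=0)
cols = jnp.compress(col_mask, a, axis=1)

# With matching lengths, compress agrees with boolean-mask indexing.
print(bool(jnp.array_equal(rows, a[row_mask])))     # True
print(bool(jnp.array_equal(cols, a[:, col_mask])))  # True
print(cols.tolist())  # [[2, 3], [6, 7], [10, 11]]
```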
Note that the condition need not match the shape of the specified axis; here we compress the columns with the length-3 condition. Values beyond the size of the condition are ignored:
>>> jnp.compress(condition, a, axis=1)
Array([[ 1,  3],
       [ 5,  7],
       [ 9, 11]], dtype=int32)
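With the default axis=None, a is flattened first and the condition applies to that flat view, so the result is always 1-D. The length-6 condition below is an illustrative choice. Because it is shorter than the 12 flattened elements, the truncation rule from the Notes also applies.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12]])

# axis=None: a is flattened to [1, 2, ..., 12]; the length-6 condition
# only looks at the first 6 flat elements and keeps elements 0, 2 and 5.
cond = jnp.array([True, False, True, False, False, True])
flat = jnp.compress(cond, a)
print(flat.tolist())  # [1, 3, 6]
print(flat.ndim)      # 1
```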
The optional size argument lets you specify a static output size, so the output is statically shaped and this function can be used with transformations like jit() and vmap():
>>> import jax
>>> f = lambda c, a: jnp.compress(c, a, size=len(a), fill_value=0)
>>> mask = (a % 3 == 0)
>>> jax.vmap(f)(mask, a)
Array([[ 3,  0,  0,  0],
       [ 6,  0,  0,  0],
       [ 9, 12,  0,  0]], dtype=int32)
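The same size argument also makes compress usable under jax.jit. The sketch below assumes that the output should be as long as the input. The function name `keep_positive` and the sentinel -1 are hypothetical choices for illustration. The selected values come first, in their original order, followed by fill_value padding.

```python
import jax
import jax.numpy as jnp

@jax.jit
def keep_positive(x):
    # Hypothetical helper: size=x.shape[0] fixes the output length at trace
    # time, and unused slots are padded with -1 (an arbitrary sentinel).
    return jnp.compress(x > 0, x, size=x.shape[0], fill_value=-1)

x = jnp.array([3, -1, 0, 5, -2])
print(keep_positive(x).tolist())  # [3, 5, -1, -1, -1]
```

Choose a sentinel that cannot occur among the selected values, because the padding is otherwise indistinguishable from real entries.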