jax.numpy.compress
- jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)
Compress an array along a given axis using a boolean condition.
JAX implementation of numpy.compress().
- Parameters:
condition (ArrayLike) – 1-dimensional array of conditions. Will be converted to boolean.
a (ArrayLike) – N-dimensional array of values.
axis (int | None) – axis along which to compress. If None (default), then a will be flattened and axis will be set to 0.
size (int | None) – optional static size for output. Must be specified in order for compress to be compatible with JAX transformations like jit() or vmap().
fill_value (ArrayLike) – if size is specified, fill padded entries with this value (default: 0).
out (None) – not implemented by JAX.
- Returns:
An array of dimension a.ndim, compressed along the specified axis.
- Return type:
Array
See also
jax.numpy.extract(): 1D version of compress.
jax.Array.compress(): equivalent functionality as an array method.
Notes
This function does not require strict shape agreement between condition and a. If condition.size > a.shape[axis], then condition will be truncated, and if a.shape[axis] > condition.size, then a will be truncated.
Examples
Compressing along the rows of a 2D array:
>>> import jax
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3, 4],
...                [5, 6, 7, 8],
...                [9, 10, 11, 12]])
>>> condition = jnp.array([True, False, True])
>>> jnp.compress(condition, a, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)
For convenience, you can equivalently use the compress() method of JAX arrays:
>>> a.compress(condition, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)
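The flattening behavior of axis=None, described under Parameters, can be sketched as follows. This is a minimal illustration; the even-value mask and the names is_even and evens are chosen for this example and do not come from the original documentation.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12]])

# Illustrative 1D condition with one entry per element of the flattened array.
is_even = (a.ravel() % 2 == 0)

# axis is left at its default (None), so `a` is flattened before compressing.
evens = jnp.compress(is_even, a)
print(evens)  # [ 2  4  6  8 10 12]
```

The result is 1-dimensional because the compression is applied to the flattened array.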
Note that the condition need not match the shape of the specified axis; here we compress the columns with the length-3 condition. Values beyond the size of the condition are ignored:
>>> jnp.compress(condition, a, axis=1)
Array([[ 1,  3],
       [ 5,  7],
       [ 9, 11]], dtype=int32)
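The opposite mismatch, a condition longer than the axis, might look like the sketch below. The name long_condition is hypothetical, and its trailing entries are deliberately False: this example assumes nothing about how True entries past the end of the axis are handled.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2, 3, 4],
               [5, 6, 7, 8],
               [9, 10, 11, 12]])

# Hypothetical length-5 condition for an axis of length 3.
# The two extra trailing entries are False on purpose (see lead-in).
long_condition = jnp.array([True, False, True, False, False])

# Only the first three entries line up with the three rows of `a`.
rows = jnp.compress(long_condition, a, axis=0)
print(rows)
```

Here the selection is driven by the first three entries of the condition, so rows 0 and 2 are kept.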
The optional size argument lets you specify a static output size so that the output is statically shaped, which allows this function to be used with transformations like jit() and vmap():
>>> f = lambda c, a: jnp.compress(c, a, size=len(a), fill_value=0)
>>> mask = (a % 3 == 0)
>>> jax.vmap(f)(mask, a)
Array([[ 3,  0,  0,  0],
       [ 6,  0,  0,  0],
       [ 9, 12,  0,  0]], dtype=int32)
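The same idea applies under jit(). The following is a minimal sketch rather than a canonical recipe: the function name keep_even, the choice of size=3, and the fill value of -1 are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def keep_even(x):
    # size is a plain Python int, so the output shape (3,) is known at trace time.
    # Slots not filled by a selected value are padded with fill_value.
    return jnp.compress(x % 2 == 0, x, size=3, fill_value=-1)

x = jnp.array([1, 2, 3, 4, 5])
result = keep_even(x)
print(result)  # [ 2  4 -1]
```

Only two values satisfy the condition, so the third output slot holds the fill value. A non-zero fill value such as -1 can make padded entries easier to tell apart from real data when 0 is a valid value.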