jax.numpy.unique_all

jax.numpy.unique_all(x, /, *, size=None, fill_value=None)

Return unique values from x, along with indices, inverse indices, and counts.

JAX implementation of numpy.unique_all(); this is equivalent to calling jax.numpy.unique() with return_index, return_inverse, return_counts, and equal_nan set to True.
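As a quick check of this equivalence, the sketch below compares the four outputs of jnp.unique_all against a direct call to jnp.unique with the flags named above, on small 1D integer arrays. The helper name same_as_unique is hypothetical, chosen only for illustration.

```python
import jax.numpy as jnp

def same_as_unique(x):
    """Compare jnp.unique_all(x) with jnp.unique called with all return flags set."""
    expected = jnp.unique(x, return_index=True, return_inverse=True,
                          return_counts=True, equal_nan=True)
    actual = jnp.unique_all(x)
    # Both results are 4-tuples: (values, indices, inverse_indices, counts).
    return all(bool(jnp.array_equal(a, e)) for a, e in zip(actual, expected))

print(same_as_unique(jnp.array([3, 4, 1, 3, 1])))
```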

Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_all to be used in such contexts.
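A minimal sketch of the static-size pattern, assuming a padded length of 4 and a fill value of -1 (both arbitrary choices for illustration; padded_unique is a hypothetical helper name). Fixing size gives every output a known shape, so the call can run under jax.jit and can also be batched over the rows of a 2D array with jax.vmap.

```python
import jax
import jax.numpy as jnp

@jax.jit
def padded_unique(x):
    # size is a static Python int, so every output has a fixed shape at trace
    # time; unused slots in `values` are filled with -1.
    return jnp.unique_all(x, size=4, fill_value=-1)

single = padded_unique(jnp.array([3, 4, 1, 3]))
print(single.values)  # unique values, padded to length 4

# The same static size lets jax.vmap batch the call over the rows of a 2D array.
batched = jax.vmap(padded_unique)(jnp.array([[3, 4, 1, 3],
                                             [2, 2, 2, 2]]))
print(batched.values)
```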

Parameters:
  • x (ArrayLike) – N-dimensional array from which unique values will be extracted.

  • size (int | None) – If specified, return only the first size sorted unique elements. If there are fewer unique elements than size, the return value is padded with fill_value.

  • fill_value (ArrayLike | None) – When size is specified and there are fewer unique elements than size, fill the remaining entries with fill_value. Defaults to the minimum unique value.
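A short sketch of the cases these two parameters cover, using the same input as the Examples below: truncation when size is smaller than the number of unique values, padding with the default fill, and padding with an explicit fill_value of 0 (an arbitrary choice for illustration).

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # three unique values: 1, 3, 4

# size smaller than the number of unique values: keep only the first two sorted ones.
truncated = jnp.unique_all(x, size=2)

# size larger than the number of unique values: pad. With fill_value left as None,
# the padding defaults to the minimum unique value (here 1).
padded_default = jnp.unique_all(x, size=5)

# An explicit fill_value replaces that default.
padded_zero = jnp.unique_all(x, size=5, fill_value=0)

print(truncated.values, padded_default.values, padded_zero.values)
```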

Returns:

  • values:

    An array of shape (n_unique,) containing the unique values from x.

  • indices:

    An array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.

  • inverse_indices:

    An array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.

  • counts:

    An array of shape (n_unique,). Contains the number of occurrences of each unique value in x.
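The properties listed above can be checked mechanically. The sketch below uses a hypothetical helper, check_unique_all, to test them on a random 1D integer array. It assumes 1D input, matching the "For 1D inputs" qualifications above, and relies on jnp.argmax returning the first maximal index to locate first occurrences.

```python
import jax
import jax.numpy as jnp

def check_unique_all(x):
    """Check the documented properties of jnp.unique_all on a 1D array x."""
    values, indices, inverse_indices, counts = jnp.unique_all(x)
    # matches[i, j] is True where x[j] equals the i-th unique value.
    matches = values[:, None] == x[None, :]
    return {
        # values are sorted and strictly increasing (no duplicates).
        "values_sorted": bool(jnp.all(values[1:] > values[:-1])),
        # argmax over a boolean row returns the first True, i.e. the first occurrence.
        "first_occurrence": bool(jnp.all(indices == jnp.argmax(matches, axis=1))),
        "inverse_reconstructs_x": bool(jnp.array_equal(values[inverse_indices], x)),
        "counts_match": bool(jnp.all(counts == matches.sum(axis=1))),
    }

x = jax.random.randint(jax.random.PRNGKey(0), (20,), 0, 5)
print(check_unique_all(x))
```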

Return type:

A NamedTuple (values, indices, inverse_indices, counts) with the properties listed under Returns.

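See also

  • jax.numpy.unique(): general function for computing unique values, of which unique_all is a special case.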

Examples

Here we compute the unique values in a 1D array:

>>> import jax.numpy as jnp
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)

The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:

>>> result.values
Array([1, 3, 4], dtype=int32)

The indices attribute contains the index of the first occurrence of each unique value within the input array:

>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)

The inverse_indices attribute contains, for each element of the input, its index within values:

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

The counts attribute contains the counts of each unique value in the input:

>>> result.counts
Array([2, 2, 1], dtype=int32)

For examples of the size and fill_value arguments, see jax.numpy.unique().
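As one illustration of combining these outputs (a usage sketch, not part of the API; most_frequent is a hypothetical helper), values and counts together give the most frequent element of an array. Because values is sorted and jnp.argmax returns the first maximum, ties resolve to the smallest tied value.

```python
import jax.numpy as jnp

def most_frequent(x):
    """Return the most common element of x; ties resolve to the smallest value."""
    result = jnp.unique_all(x)
    # values is sorted, and argmax returns the first maximal count, so ties
    # go to the smallest tied value.
    return result.values[jnp.argmax(result.counts)]

print(most_frequent(jnp.array([5, 5, 2, 5, 7])))
```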