jax.numpy.unique_counts
- jax.numpy.unique_counts(x, /, *, size=None, fill_value=None)
Return unique values from x, along with counts.
JAX implementation of numpy.unique_counts(). This is equivalent to calling jax.numpy.unique() with return_counts and equal_nan set to True.
Because the size of the output of unique_counts is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_counts to be used in such contexts.
- Parameters:
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
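As a minimal sketch of the static-size requirement described above, the function below wraps jnp.unique_counts in jax.jit and marks size as a static argument via functools.partial. The name padded_unique_counts and the fill_value of -1 are illustrative choices, not part of the JAX API.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` is marked static so the output shape (size,) is known at trace time;
# without a static size, unique_counts cannot be traced under jit.
@partial(jax.jit, static_argnames="size")
def padded_unique_counts(x, size):
    # Illustrative fill_value: padded slots in `values` are set to -1.
    return jnp.unique_counts(x, size=size, fill_value=-1)


x = jnp.array([3, 4, 1, 3, 1])
values, counts = padded_unique_counts(x, size=5)
print(values)  # three unique values followed by two padded -1 entries
print(counts)
```

Because size is static, each distinct value passed for it triggers a separate compilation.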
- Returns:
A tuple (values, counts), with the following properties:
values: an array of shape (n_unique,) (or (size,) when size is specified) containing the unique values from x.
counts: an array of the same shape as values, containing the number of occurrences of each unique value in x.
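The short sketch below checks the return-value properties listed above on the same input used in the Examples, and compares the result with jax.numpy.unique() called with return_counts=True and equal_nan=True, which the description states is equivalent.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])

# The result is a NamedTuple, so it can be unpacked like a plain tuple.
values, counts = jnp.unique_counts(x)

# Equivalent call through the general-purpose jnp.unique.
ref_values, ref_counts = jnp.unique(x, return_counts=True, equal_nan=True)

# Every element of x belongs to exactly one unique value,
# so (without size padding) the counts sum to x.size.
total = int(counts.sum())
print(values, counts, total)
```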
See also
- jax.numpy.unique(): general function for computing unique values.
- jax.numpy.unique_values(): compute only values.
- jax.numpy.unique_inverse(): compute only values and inverse.
- jax.numpy.unique_all(): compute values, indices, inverse_indices, and counts.
Examples
Here we compute the unique values in a 1D array:
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_counts(x)
The result is a NamedTuple with two named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
For examples of the size and fill_value arguments, see jax.numpy.unique().
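A minimal sketch of the size and fill_value arguments applied directly to unique_counts, reusing the input from the example above. A size smaller than the number of unique values keeps only the first size sorted values; a larger size pads values with fill_value, or with the minimum unique value when fill_value is omitted.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # unique values: 1, 3, 4

# Fewer slots than unique values: only the first two sorted uniques are kept.
truncated = jnp.unique_counts(x, size=2)

# More slots than unique values, no fill_value: padding uses the minimum unique value.
padded_default = jnp.unique_counts(x, size=5)

# More slots than unique values, with an explicit fill_value for the padding.
padded_zero = jnp.unique_counts(x, size=5, fill_value=0)

print(truncated.values, truncated.counts)
print(padded_default.values)
print(padded_zero.values)
```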