jax.numpy.argpartition
jax.numpy.argpartition(a, kth, axis=-1)
Returns indices that partially sort an array.
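JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in its treatment of NaN entries: NaNs with the sign (negative) bit set are sorted to the beginning of the array.

- Parameters:
  - a (ArrayLike): array to be partitioned.
  - kth (int): static integer index about which to partition the array.
  - axis (int): static integer axis along which to partition the array; default is -1.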
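- Returns:
  Indices which partition a along axis around its pivot value, the value that would sit at index kth if a were fully sorted along axis (that is, take(sort(a, axis=axis), kth, axis=axis)). Entries before position kth are indices of values less than or equal to the pivot, and entries after kth are indices of values greater than or equal to the pivot.
- Return type:
  Array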
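The guarantee above can be checked directly. The following is a minimal sketch, assuming a small hypothetical 2-D array partitioned along axis=1; it uses jnp.take_along_axis to gather values in partitioned order and jnp.sort to compute the pivot of each row, then confirms the ordering property on both sides of kth.

```python
import jax.numpy as jnp

# Hypothetical 2-D input; each row is partitioned independently along axis=1.
a = jnp.array([[6, 8, 4, 3, 1, 9, 7, 5, 2, 3],
               [0, 5, 5, 2, 7, 1, 9, 4, 8, 6]])
kth = 4

idx = jnp.argpartition(a, kth, axis=1)
# Gather the values of each row in partitioned order.
parted = jnp.take_along_axis(a, idx, axis=1)

# The pivot of each row is its value at 0-based sorted rank kth.
pivot = jnp.sort(a, axis=1)[:, kth]

# Everything before kth is <= the pivot; everything after kth is >= it.
before_ok = bool(jnp.all(parted[:, :kth] <= pivot[:, None]))
after_ok = bool(jnp.all(parted[:, kth + 1:] >= pivot[:, None]))
print(parted[:, kth], pivot, before_ok, after_ok)
```

Note that the second row contains a repeated 5, so "less than or equal" and "greater than or equal" (rather than strict inequalities) are the safe way to state the property when ties are present.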
Note
The JAX version requires the kth argument to be a static integer rather than a general array. It is implemented via two calls to jax.lax.top_k(). If you only need the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
See also
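Both points in the note can be illustrated together. This is a sketch, not the library's internal implementation: the helper names smallest_k_via_argpartition and smallest_k_via_top_k are hypothetical. Because kth must be static, the value is marked with static_argnames when jitting. The top_k variant negates the input so that a top-k search returns the bottom k, which assumes a signed or floating-point dtype (negation wraps for unsigned integers and for the most negative signed integer).

```python
import functools

import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
k = 4

# kth must be a static Python int, so mark k static when jitting.
@functools.partial(jax.jit, static_argnames="k")
def smallest_k_via_argpartition(x, k):
    # The first k partitioned indices point to the k smallest values.
    return jnp.argpartition(x, k)[:k]

@functools.partial(jax.jit, static_argnames="k")
def smallest_k_via_top_k(x, k):
    # top_k returns the k largest entries along the last axis, so negating
    # x turns it into a bottom-k search (assumes signed or float dtype).
    _, idx = jax.lax.top_k(-x, k)
    return idx

i1 = smallest_k_via_argpartition(x, k)
i2 = smallest_k_via_top_k(x, k)
# Both select the same index set here, possibly in different orders.
print(sorted(i1.tolist()), sorted(i2.tolist()))
```

If values tie across the boundary between the k smallest and the rest, the two approaches may legitimately pick different indices; in this input the boundary values (3 and 4) are distinct, so the sets coincide.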
- jax.numpy.partition(): direct partial sort
- jax.numpy.argsort(): full indirect sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
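To make the relationship between the first three functions concrete, the sketch below (reusing the array from the examples) compares a full indirect sort, an indirect partition, and a direct partition. It only checks what is guaranteed: all three agree on the value at position kth and on the multiset of values on each side of it, while the order within each side of the partial versions is left unspecified.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

full_order = jnp.argsort(x)          # full indirect sort
part_idx = jnp.argpartition(x, kth)  # indirect partial sort
part_vals = jnp.partition(x, kth)    # direct partial sort (returns values)

# All three agree on the value at position kth.
pivot = x[full_order[kth]]
print(pivot, x[part_idx[kth]], part_vals[kth])

# They also agree on which values fall on each side of kth, once each
# side is sorted (the partial versions do not order within a side).
ref = jnp.sort(x)
same_left = bool(
    jnp.array_equal(jnp.sort(part_vals[:kth]), ref[:kth])
    and jnp.array_equal(jnp.sort(x[part_idx[:kth]]), ref[:kth])
)
same_right = bool(
    jnp.array_equal(jnp.sort(part_vals[kth + 1:]), ref[kth + 1:])
    and jnp.array_equal(jnp.sort(x[part_idx[kth + 1:]]), ref[kth + 1:])
)
print(same_left, same_right)
```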
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
The result is a sequence of indices that partially sort the input. All indices before kth point to values smaller than the pivot value, and all indices after kth point to values larger than the pivot value:

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]
Notice that within smallest_values and within largest_values, the order of the returned entries is arbitrary and implementation-dependent.
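Because that order is arbitrary, a caller who needs the smallest values in ascending order can sort only the selected slice afterward, which touches kth elements instead of the whole array. A minimal sketch reusing the example input; the names first and first_sorted are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

idx = jnp.argpartition(x, kth)
first = idx[:kth]  # indices of the kth smallest values, in arbitrary order

# Sort just the selected slice rather than the full array.
first_sorted = first[jnp.argsort(x[first])]
print(x[first_sorted])
```

The same pattern applies to largest_values by slicing idx[kth + 1:] instead.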