jax.numpy.partition
- jax.numpy.partition(a, kth, axis=-1)
Returns a partially-sorted copy of an array.
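JAX implementation of numpy.partition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs which have the negative bit set are sorted to the beginning of the array.
- Parameters:
  - a: array to be partially sorted.
  - kth: static integer index about which to partition the array.
  - axis: static integer axis along which to partition the array; default is -1.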
- Returns:
A copy of a partitioned at the kth value along axis. The entries before kth are values less than or equal to take(a, kth, axis), and the entries after kth are values greater than or equal to take(a, kth, axis).
- Return type:
Array
Note
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you're only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
See also
- jax.numpy.sort(): full sort
- jax.numpy.argpartition(): indirect partial sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
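The Note's remark that the partition is built from two calls to jax.lax.top_k() can be illustrated with the minimal sketch below. It works along the last axis only and assumes a signed integer or floating-point dtype, because the smallest values are found by negating the input. `partition_via_top_k` is a hypothetical helper written for illustration, not part of JAX, and the library's actual code may differ (for example in axis or NaN handling).

```python
import jax.numpy as jnp
from jax import lax


def partition_via_top_k(a, kth):
    """Hypothetical helper (not a JAX API): partition `a` about index `kth`
    along its last axis using two lax.top_k calls.

    Assumes kth is a Python int with 0 <= kth < a.shape[-1], and a signed
    integer or floating dtype (the input is negated to find the smallest values).
    """
    n = a.shape[-1]
    # The kth + 1 smallest values in ascending order: top_k of -a, negated back.
    bottom = -lax.top_k(-a, kth + 1)[0]
    if n - kth - 1 == 0:
        # Nothing lies after the pivot; the bottom half is the whole result.
        return bottom
    # The remaining n - kth - 1 largest values, in descending order.
    top = lax.top_k(a, n - kth - 1)[0]
    return jnp.concatenate([bottom, top], axis=-1)


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
print(partition_via_top_k(x, 4))  # [1 2 3 3 4 9 8 7 6 5]

# If only the 3 largest values are needed, a single top_k call suffices.
values, indices = lax.top_k(x, 3)
print(values)  # [9 8 7]
```

The last two lines follow the Note's advice: when only the largest k values are needed, one jax.lax.top_k() call avoids computing the other half of the partition.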
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> x_partitioned = jnp.partition(x, kth)
>>> x_partitioned
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
The result is a partially-sorted copy of the input. All values before kth are no larger than the pivot value, and all values after kth are no smaller than it:
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [9 8 7 6 5]
Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
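The same guarantee holds along other axes. As a sketch (the 2-D input below is made up for illustration), this partitions each column with axis=0 and checks, using jnp.take and jnp.sort, that entries before kth are no larger than the pivot, entries after kth are no smaller, and the pivot equals the value a full sort places at index kth.

```python
import jax.numpy as jnp

# Illustrative 2-D input; each column is partitioned independently (axis=0).
x = jnp.array([[5, 1, 7],
               [2, 9, 3],
               [8, 4, 6],
               [1, 6, 2]])
kth = 1
axis = 0

p = jnp.partition(x, kth, axis=axis)

# The pivot row: the values sitting at index kth along the partitioned axis.
pivot = jnp.take(p, kth, axis=axis)

# Rows before kth are <= the pivot; rows after kth are >= the pivot.
before_ok = jnp.all(p[:kth] <= pivot)
after_ok = jnp.all(p[kth + 1:] >= pivot)

# The pivot matches what a full sort along the same axis puts at index kth.
pivot_ok = jnp.array_equal(pivot, jnp.take(jnp.sort(x, axis=axis), kth, axis=axis))

print(bool(before_ok), bool(after_ok), bool(pivot_ok))  # True True True
```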