jax.numpy.searchsorted
jax.numpy.searchsorted(a, v, side='left', sorter=None, *, method='scan')
Perform a binary search within a sorted array.
JAX implementation of numpy.searchsorted().
This returns the indices within a sorted array a at which the values in v can be inserted while maintaining its sort order.
- Parameters:
  - a (ArrayLike) – one-dimensional array, assumed to be in sorted order unless sorter is specified.
  - v (ArrayLike) – N-dimensional array of query values.
  - side (str) – 'left' (default) or 'right'; specifies whether insertion indices will be to the left or the right in case of ties.
  - sorter (ArrayLike | None) – optional array of indices specifying the sort order of a. If specified, then the algorithm assumes that a[sorter] is in sorted order.
  - method (str) – one of 'scan' (default), 'scan_unrolled', 'sort' or 'compare_all'. See Note below.
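The binary search and the side parameter described above can be sketched as a fixed-iteration, vectorized search. This is an illustrative version written for this page, not the code behind jnp.searchsorted; the function name binary_search is hypothetical, and the sketch assumes a is a non-empty, sorted 1-D array. Each element of v keeps its own half-open interval [lo, hi), and side only changes whether ties move the search to the right.

```python
import jax.numpy as jnp
from jax import lax


def binary_search(a, v, side="left"):
    """Hypothetical helper: insertion indices of v into sorted, non-empty 1-D a.

    Illustrative sketch only; not the implementation used by jnp.searchsorted.
    """
    a = jnp.asarray(a)
    v = jnp.asarray(v)
    n = a.shape[0]
    lo = jnp.zeros(v.shape, dtype=jnp.int32)
    hi = jnp.full(v.shape, n, dtype=jnp.int32)

    def step(_, carry):
        lo, hi = carry
        mid = (lo + hi) // 2
        # Clamp the probe so the gather stays in bounds once lo == hi == n.
        probe = a[jnp.minimum(mid, n - 1)]
        # 'left' moves past elements strictly less than v; 'right' also moves past ties.
        go_right = probe < v if side == "left" else probe <= v
        # Queries whose interval is already empty keep their current answer.
        active = lo < hi
        lo = jnp.where(active & go_right, mid + 1, lo)
        hi = jnp.where(active & ~go_right, mid, hi)
        return lo, hi

    # Each step shrinks hi - lo to at most half, so n.bit_length() steps suffice.
    lo, _ = lax.fori_loop(0, n.bit_length(), step, (lo, hi))
    return lo
```

For example, binary_search(jnp.array([1, 2, 2, 3, 4, 5, 5]), 2, side="right") gives 3, matching jnp.searchsorted with side='right'. Using a static trip count keeps the loop shape fixed, which is convenient under jax.jit, at the cost of always running the worst-case number of steps.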
- Returns:
  Array of insertion indices of shape v.shape.
- Return type:
  Array
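Because the result has shape v.shape, a 2-D batch of queries yields a 2-D array of indices. For a sorted a, each 'left' index also equals the number of elements of a strictly less than the query, and each 'right' index the number less than or equal to it. The snippet below checks both properties against a brute-force broadcasted comparison; the query values are arbitrary, and the brute-force check is only for illustration since it does len(a) comparisons per query.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])  # sorted 1-D array
v = jnp.array([[0, 3], [8, 2]])       # 2-D batch of queries

left = jnp.searchsorted(a, v)
right = jnp.searchsorted(a, v, side="right")

# Brute force: compare every query against every element of a, then count.
n_less = jnp.sum(a < v[..., None], axis=-1)
n_less_equal = jnp.sum(a <= v[..., None], axis=-1)

print(left.shape)      # (2, 2)
print(left.tolist())   # [[0, 3], [7, 1]]
print(right.tolist())  # [[0, 4], [7, 3]]
assert bool((left == n_less).all()) and bool((right == n_less_equal).all())
```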
Note
The method argument controls the algorithm used to compute the insertion indices:
- 'scan' (the default) tends to be more performant on CPU, particularly when a is very large.
- 'scan_unrolled' is more performant on GPU at the expense of additional compile time.
- 'sort' is often more performant on accelerator backends like GPU and TPU, particularly when v is very large.
- 'compare_all' tends to be the most performant when a is very small.
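Since method selects an algorithm rather than changing the result, one way to pick a backend-appropriate method is to compile each candidate and confirm they agree before timing them. A minimal sketch, assuming a JAX version that accepts all four method names listed above; functools.partial fixes method so that jax.jit only traces the array arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])
vals = jnp.array([0, 3, 8, 1.5, 2])

results = {}
for method in ("scan", "scan_unrolled", "sort", "compare_all"):
    # Bind the method up front; the compiled function takes only (a, v).
    search = jax.jit(partial(jnp.searchsorted, method=method))
    results[method] = search(a, vals)

# The methods differ in speed, not in output.
for method, idx in results.items():
    print(method, idx.tolist())  # every method gives [0, 3, 7, 1, 1]
```

Timing each compiled function (for example with block_until_ready() on the result) on the target device and with realistic sizes of a and v is what actually decides between them; the Note above only describes typical trends.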
Examples
Searching for a single value:
>>> a = jnp.array([1, 2, 2, 3, 4, 5, 5])
>>> jnp.searchsorted(a, 2)
Array(1, dtype=int32)
>>> jnp.searchsorted(a, 2, side='right')
Array(3, dtype=int32)
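To see why these indices are called insertion indices, the value can be spliced into a at the returned position and the result checked for order. With ties, both the 'left' and the 'right' index keep the array sorted; they differ only in where the new copy lands among the equal elements. The splice via jnp.concatenate is just one convenient way to do the insertion for illustration.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 2, 3, 4, 5, 5])
x = 2

positions = {}
for side in ("left", "right"):
    i = int(jnp.searchsorted(a, x, side=side))
    # Splice x into a at the returned position.
    b = jnp.concatenate([a[:i], jnp.array([x]), a[i:]])
    # The spliced array is still non-decreasing.
    assert bool(jnp.all(jnp.diff(b) >= 0))
    positions[side] = i
    print(side, i, b.tolist())
# left 1 [1, 2, 2, 2, 3, 4, 5, 5]
# right 3 [1, 2, 2, 2, 3, 4, 5, 5]
```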
Searching for a batch of values:
>>> vals = jnp.array([0, 3, 8, 1.5, 2])
>>> jnp.searchsorted(a, vals)
Array([0, 3, 7, 1, 1], dtype=int32)
Optionally, the sorter argument can be used to find insertion indices into an array sorted via jax.numpy.argsort():
>>> a = jnp.array([4, 3, 5, 1, 2])
>>> sorter = jnp.argsort(a)
>>> jnp.searchsorted(a, vals, sorter=sorter)
Array([0, 2, 5, 1, 1], dtype=int32)
The result is equivalent to passing the sorted array:
>>> jnp.searchsorted(jnp.sort(a), vals)
Array([0, 2, 5, 1, 1], dtype=int32)