jax.numpy.bincount
- jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)
Count the number of occurrences of each value in an integer array.
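JAX implementation of numpy.bincount().
For an array of non-negative integers x, this function returns an array counts of size x.max() + 1 (or minlength, if that is larger), such that counts[i] contains the number of occurrences of the value i in x. If weights is given, counts[i] instead contains the sum of the weights of those occurrences.
The JAX version has a few differences from the NumPy version: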
- In NumPy, passing an array x with negative entries will result in an error. In JAX, negative values are clipped to zero.
- JAX adds an optional length parameter which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit(). In this case, values greater than or equal to length are dropped.
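To make the semantics above concrete, here is a minimal brute-force sketch that restates them directly: clip negatives to zero, count matches against each bin index, and silently drop anything that has no bin. The helper name bincount_reference is hypothetical and is not part of JAX; it exists only to check the behavior against jnp.bincount.

```python
import jax.numpy as jnp

# bincount_reference is a hypothetical helper written for illustration; it is
# not part of JAX and is far less efficient than jnp.bincount.
def bincount_reference(x, length):
    x = jnp.maximum(x, 0)            # negative values are clipped to bin 0
    bins = jnp.arange(length)        # bin indices 0, 1, ..., length - 1
    # counts[i] is the number of entries equal to i; entries >= length match
    # no bin and are therefore dropped.
    return (x[:, None] == bins[None, :]).sum(axis=0)

x = jnp.array([-2, 0, 1, 1, 4, 5, 7])
print(bincount_reference(x, length=5))  # [2 2 0 0 1]
print(jnp.bincount(x, length=5))        # [2 2 0 0 1]
```

Note that the value 5 is dropped along with 7: with length=5 the output only has bins 0 through 4.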
- Parameters:
x (ArrayLike) – 1-dimensional array of non-negative integers.
weights (ArrayLike | None) – optional array of weights associated with x. If not specified, the weight for each entry will be 1.
minlength (int) – the minimum length of the output counts array.
length (int | None) – the length of the output counts array. Must be specified statically for bincount to be used with jax.jit() and other JAX transformations.
- Returns:
An array of counts or summed weights reflecting the number of occurrences of values in x.
- Return type:
Array
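As a small illustration of the minlength and weights parameters (the input values here are arbitrary), minlength only ever pads the output, and floating-point weights turn each bin into a sum rather than a count:

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 1, 3])

# minlength pads the output with zero counts up to the requested size ...
print(jnp.bincount(x, minlength=6))   # [1 2 0 1 0 0]
# ... but never truncates: the output still has x.max() + 1 = 4 entries.
print(jnp.bincount(x, minlength=2))   # [1 2 0 1]

# With floating-point weights, each bin holds the sum of the weights of the
# entries that fall into it: 0.5, 0.25 + 0.25, nothing, and 2.0.
w = jnp.array([0.5, 0.25, 0.25, 2.0])
print(jnp.bincount(x, weights=w))
```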
Examples
Basic bincount:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)
Weighted bincount:
>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)
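Weighted bincount is a convenient building block for per-group reductions. A minimal sketch, using made-up group ids and measurements, that computes per-group means by dividing weighted sums by counts:

```python
import jax.numpy as jnp

# Hypothetical data: a group id and a measurement for each sample.
group = jnp.array([0, 1, 1, 2, 2, 2])
value = jnp.array([4.0, 1.0, 3.0, 2.0, 2.0, 5.0])

sums = jnp.bincount(group, weights=value, length=3)   # per-group totals
counts = jnp.bincount(group, length=3)                # per-group sizes
# Guard against empty groups so that no division by zero occurs.
means = sums / jnp.maximum(counts, 1)
print(means)  # [4. 2. 3.]
```

Passing the same length to both calls keeps the two arrays aligned even if the largest group id happens to be absent from the data.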
Specifying a static length makes this jit-compatible:
>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)
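Inside a larger jitted function, the usual pattern is to pass a length that is a plain Python constant, so the output shape is known at trace time. In this sketch, NUM_CLASSES and class_histogram are hypothetical names chosen for illustration:

```python
import jax
import jax.numpy as jnp

NUM_CLASSES = 4  # hypothetical class count; a plain Python int, hence static

@jax.jit
def class_histogram(labels):
    # length is a compile-time constant here, so the output shape does not
    # depend on the traced values in labels.
    return jnp.bincount(labels, length=NUM_CLASSES)

labels = jnp.array([0, 2, 2, 3, 3, 3])
print(class_histogram(labels))  # [1 0 2 3]
```

Without a static length, jitting bincount should fail at trace time, because the output size would then depend on the traced value of x.max().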
Any negative numbers are clipped to the first bin, and numbers greater than or equal to the specified length are dropped:
>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)
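The same static length also lets bincount work under jax.vmap, since every batch element then produces an output of the same shape. A minimal sketch computing one histogram per row of an arbitrary small integer matrix:

```python
import jax
import jax.numpy as jnp

# Each row is an independent sample of small non-negative integers.
rows = jnp.array([[0, 1, 1, 3],
                  [2, 2, 2, 0]])

# A static length gives every row's histogram the same shape, which vmap
# needs in order to stack the per-row results into one array.
row_counts = jax.vmap(lambda r: jnp.bincount(r, length=4))(rows)
print(row_counts)
# [[1 2 0 1]
#  [1 0 3 0]]
```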