jax.numpy.size
jax.numpy.size(a, axis=None)
Return the number of elements along a given axis.
JAX implementation of numpy.size(). Unlike np.size, this function raises a TypeError if the input is a collection such as a list or tuple.
- Parameters:
a (ArrayLike | SupportsSize | SupportsShape) – array-like object, or any object with a size attribute when axis is not specified, or with a shape attribute when axis is specified.
axis (int | Sequence[int] | None) – optional integer or sequence of integers indicating which axis or axes to count elements along. None (the default) returns the total number of elements.
- Returns:
An integer specifying the number of elements in a.
- Return type:
int
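The following is a minimal sketch of the input and return contracts described above, assuming a JAX version that matches this page. SizeOnly and ShapeOnly are hypothetical classes introduced only for illustration: one exposes just a size attribute (consulted when axis is None), the other just a shape attribute (consulted when axis is given). The jitted row_means function, also hypothetical, relies on the result being an ordinary Python int, since array shapes are static while JAX traces a function.

```python
import numpy as np
import jax
import jax.numpy as jnp


class SizeOnly:
    """Hypothetical container that exposes only a `size` attribute."""
    def __init__(self, size):
        self.size = size


class ShapeOnly:
    """Hypothetical container that exposes only a `shape` attribute."""
    def __init__(self, shape):
        self.shape = shape


# axis=None: only the `.size` attribute is needed.
n_total = jnp.size(SizeOnly(12))              # 12
# Explicit axis: the `.shape` attribute is consulted instead.
n_rows = jnp.size(ShapeOnly((4, 3)), axis=0)  # 4
# NumPy arrays carry both attributes.
n_cols = jnp.size(np.zeros((4, 3)), axis=1)   # 3


@jax.jit
def row_means(x):
    # Shapes are static during tracing, so `n` is a plain Python int here,
    # not a traced value.
    n = jnp.size(x, axis=1)
    return jnp.sum(x, axis=1) / n


means = row_means(jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))  # [2. 5.]
```

Because n is a Python int at trace time, it can also be passed where JAX needs a static value, such as the shape argument of jnp.zeros inside the same jitted function.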
Examples
Size for arrays:
>>> x = jnp.arange(10)
>>> jnp.size(x)
10
>>> y = jnp.ones((2, 3))
>>> jnp.size(y)
6
>>> jnp.size(y, axis=1)
3
>>> jnp.size(y, axis=(1,))
3
>>> jnp.size(y, axis=(0, 1))
6
This also works for scalars:
>>> jnp.size(3.14)
1
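Scalars are accepted, but Python collections are not. The sketch below, assuming a JAX version that matches this page, shows the usual workaround: convert the list explicitly with jnp.asarray before asking for its size. The error is caught and reported rather than asserted, so the sketch does not depend on the exact message text.

```python
import jax.numpy as jnp

data = [[1, 2, 3], [4, 5, 6]]  # a plain nested Python list

try:
    jnp.size(data)  # documented above to raise TypeError for lists and tuples
except TypeError as err:
    print(f"TypeError: {err}")

# Convert explicitly; the result is a jax.Array with .size and .shape.
arr = jnp.asarray(data)
total = jnp.size(arr)         # 6
cols = jnp.size(arr, axis=1)  # 3
```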
For arrays, this can also be accessed via the jax.Array.size property:
>>> y.size
6
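For jax.Array inputs the property and the function agree. The function form is still needed when the input may be a Python scalar, which has no size attribute. A small check, relying only on the behavior documented on this page:

```python
import jax.numpy as jnp

y = jnp.ones((2, 3))
# For a jax.Array, the property and the function return the same count.
assert y.size == jnp.size(y) == 6

s = 3.14
# A Python float has no .size attribute, so only the function form works here.
has_attr = hasattr(s, "size")  # False
scalar_size = jnp.size(s)      # 1
```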