jax.numpy.ndim
jax.numpy.ndim(a)
Return the number of dimensions of an array.
JAX implementation of numpy.ndim(). Unlike np.ndim, this function raises a TypeError if the input is a collection such as a list or tuple.
Parameters:
a (ArrayLike | SupportsNdim) – array-like object, or any object with an ndim attribute.
Returns:
An integer specifying the number of dimensions of a.
Return type:
int
Examples
Number of dimensions for arrays:
>>> x = jnp.arange(10)
>>> jnp.ndim(x)
1
>>> y = jnp.ones((2, 3))
>>> jnp.ndim(y)
2
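As an informal sketch (not part of the official reference), the result for an array is simply the length of its shape, and NumPy arrays are accepted as array-like inputs. Since lists and tuples are not accepted, a nested list can be converted explicitly with jnp.asarray first; the helper name `dims_of_nested_list` is hypothetical and only used for illustration.

```python
import numpy as np
import jax.numpy as jnp

# ndim always equals the number of entries in the shape tuple.
z = jnp.zeros((4, 5, 6))
assert jnp.ndim(z) == len(jnp.shape(z)) == 3

# NumPy arrays are array-like, so they are accepted directly.
n = np.ones((2, 2))
print(jnp.ndim(n))  # prints 2

def dims_of_nested_list(values):
    """Hypothetical helper: convert a (nested) list to a JAX array, then count dimensions."""
    return jnp.ndim(jnp.asarray(values))

print(dims_of_nested_list([[1, 2], [3, 4]]))  # prints 2
```

Converting explicitly keeps the intent visible at the call site instead of relying on implicit list handling.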
This also works for scalars:
>>> jnp.ndim(3.14)
0
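A slightly broader sketch of the scalar case: Python scalars, NumPy scalars, and zero-dimensional JAX arrays should all report 0 dimensions. The particular list of inputs below is an illustrative choice, not an exhaustive specification of what counts as a scalar.

```python
import numpy as np
import jax.numpy as jnp

# Each of these inputs is zero-dimensional.
scalars = [
    7,                # Python int
    True,             # Python bool
    np.float32(1.5),  # NumPy scalar
    jnp.array(2.0),   # 0-d JAX array
]

for s in scalars:
    # Print the input's type name alongside its dimension count (0 for every item).
    print(type(s).__name__, jnp.ndim(s))
```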
For arrays, this can also be accessed via the jax.Array.ndim property:
>>> x.ndim
1
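Because the number of dimensions is part of an array's shape, it is known at trace time. The sketch below (the function name `flatten_matrices` is hypothetical) relies on this and uses x.ndim in an ordinary Python if statement inside jax.jit; the same pattern would fail for a condition that depends on array contents.

```python
import jax
import jax.numpy as jnp

@jax.jit
def flatten_matrices(x):
    # Inside jit, x is a tracer, but x.ndim is still a plain Python int,
    # so it can drive ordinary Python control flow.
    if x.ndim == 2:
        return x.reshape(-1)  # flatten 2-D inputs to 1-D
    return x                  # leave inputs of other ranks unchanged

print(flatten_matrices(jnp.ones((2, 3))).shape)  # prints (6,)
print(flatten_matrices(jnp.ones(4)).shape)       # prints (4,)
```

Each distinct input shape triggers a separate trace, so the branch is resolved once per shape rather than at run time.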