jax.numpy.shape
- jax.numpy.shape(a)
Return the shape of an array.
JAX implementation of numpy.shape(). Unlike np.shape, this function raises a TypeError if the input is a collection such as a list or tuple.
- Parameters:
a (ArrayLike | SupportsShape) – array-like object, or any object with a shape attribute.
- Returns:
A tuple of integers representing the shape of a.
- Return type:
Examples
Shape for arrays:
>>> x = jnp.arange(10)
>>> jnp.shape(x)
(10,)
>>> y = jnp.ones((2, 3))
>>> jnp.shape(y)
(2, 3)
This also works for scalars:
>>> jnp.shape(3.14)
()
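The description notes that, unlike np.shape, this function rejects collections such as lists and tuples. The following is a minimal sketch of that difference, not an excerpt from the JAX test suite. The variable names (nested, rejected, and so on) are illustrative only, and whether the TypeError is actually raised may depend on the installed JAX version, so the sketch records the outcome rather than assuming it.

```python
import numpy as np
import jax.numpy as jnp

nested = [[1, 2, 3], [4, 5, 6]]

# NumPy converts the nested list to an array before reading its shape.
numpy_shape = np.shape(nested)

# Per the description, jnp.shape is expected to reject the list with a TypeError.
try:
    jnp.shape(nested)
    rejected = False
except TypeError:
    rejected = True

# Converting explicitly with jnp.asarray works regardless of that behavior.
converted_shape = jnp.shape(jnp.asarray(nested))

print(numpy_shape, rejected, converted_shape)
```

Converting with jnp.asarray before calling jnp.shape keeps the list-to-array step explicit, which is presumably the reason the function does not do it implicitly.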
For arrays, this can also be accessed via the jax.Array.shape property:
>>> x.shape
(10,)
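The parameter description says that any object with a shape attribute is accepted, not only arrays. As a hedged illustration, the sketch below passes a jax.ShapeDtypeStruct, which carries shape and dtype metadata without holding data. Choosing this object is an assumption made for the example; the page itself does not name it, and older JAX releases may not accept it.

```python
import jax
import jax.numpy as jnp

# A ShapeDtypeStruct describes an array (shape and dtype) without allocating one.
spec = jax.ShapeDtypeStruct((4, 5), jnp.float32)

# jnp.shape reads the shape attribute from the object.
spec_shape = jnp.shape(spec)

print(spec_shape)
```

This can be convenient when code handles abstract array descriptions and concrete arrays through the same path.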