jax.numpy.transpose
- jax.numpy.transpose(a, axes=None)
Return a transposed version of an N-dimensional array.
JAX implementation of numpy.transpose(), implemented in terms of jax.lax.transpose().
- Parameters:
a (ArrayLike) – input array
axes (Sequence[int] | None) – optionally specify the permutation using a length-a.ndim sequence of integers i satisfying 0 <= i < a.ndim. Defaults to range(a.ndim)[::-1], i.e. reverses the order of all axes.
- Returns:
transposed copy of the array.
- Return type:
Array
See also
- jax.Array.transpose(): equivalent function via an Array method.
- jax.Array.T: equivalent function via an Array property.
- jax.numpy.matrix_transpose(): transpose the last two axes of an array. This is suitable for working with batched 2D matrices.
- jax.numpy.swapaxes(): swap any two axes in an array.
- jax.numpy.moveaxis(): move an axis to another position in the array.
Note
Unlike numpy.transpose(), jax.numpy.transpose() will return a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this doesn’t have performance impacts in practice.
Examples
For a 1D array, the transpose is the identity:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 4])
>>> jnp.transpose(x)
Array([1, 2, 3, 4], dtype=int32)
For a 2D array, the transpose is a matrix transpose:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.transpose(x)
Array([[1, 3],
       [2, 4]], dtype=int32)
For an N-dimensional array, the transpose reverses the order of the axes:
>>> x = jnp.zeros(shape=(3, 4, 5))
>>> jnp.transpose(x).shape
(5, 4, 3)
The axes argument can be specified to change this default behavior:
>>> jnp.transpose(x, (0, 2, 1)).shape
(3, 5, 4)
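As a rough illustration of how the axes permutation is interpreted, the sketch below checks that output axis i is taken from input axis axes[i]. The variable names (expected_shape, shape_matches, element_matches) are made up for this example and are not part of the JAX API.

```python
import jax.numpy as jnp

# A small array with distinct values so individual elements can be tracked.
x = jnp.arange(3 * 4 * 5).reshape(3, 4, 5)
axes = (0, 2, 1)
y = jnp.transpose(x, axes)

# Output axis i takes its length from input axis axes[i].
expected_shape = tuple(x.shape[ax] for ax in axes)
shape_matches = y.shape == expected_shape

# For axes=(0, 2, 1), the last two indices trade places: y[i, j, k] == x[i, k, j].
element_matches = bool(y[1, 4, 2] == x[1, 2, 4])
```

Both shape_matches and element_matches should evaluate to True here.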
Since swapping the last two axes is a common operation, it can be done via its own API, jax.numpy.matrix_transpose():
>>> jnp.matrix_transpose(x).shape
(3, 5, 4)
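For a 3D input, the related functions listed under "See also" can express the same swap of the last two axes. The following is a minimal sketch comparing them; the names a, b, c, d and all_equal are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(3 * 4 * 5).reshape(3, 4, 5)

a = jnp.matrix_transpose(x)        # transpose the last two axes
b = jnp.swapaxes(x, -1, -2)        # swap the last two axes explicitly
c = jnp.transpose(x, (0, 2, 1))    # full permutation with axis 0 left in place
d = jnp.moveaxis(x, -1, -2)        # move the last axis one position earlier

# All four spellings should produce the same (3, 5, 4) array.
all_equal = bool(
    jnp.array_equal(a, b) and jnp.array_equal(a, c) and jnp.array_equal(a, d)
)
```

matrix_transpose and swapaxes(x, -1, -2) do not depend on the number of leading batch dimensions, whereas the explicit permutation (0, 2, 1) only applies to 3D inputs.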
For convenience, transposes may also be performed using the jax.Array.transpose() method or the jax.Array.T property:
>>> x = jnp.array([[1, 2],
...                [3, 4]])
>>> x.transpose()
Array([[1, 3],
       [2, 4]], dtype=int32)
>>> x.T
Array([[1, 3],
       [2, 4]], dtype=int32)
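The note above about copies under JIT can be illustrated with a small sketch in which the transpose is traced and compiled together with the operation that consumes it. The function name gram is hypothetical, and whether the compiler actually avoids materializing the transposed copy depends on the backend and the surrounding computation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def gram(x):
    # The transpose is part of the same compiled computation as the matmul,
    # so the compiler is free to fuse it rather than build a separate copy.
    return jnp.transpose(x) @ x

x = jnp.arange(6.0).reshape(3, 2)
g = gram(x)  # 2x2 matrix equal to x.T @ x
```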