jax.numpy.ravel
- jax.numpy.ravel(a, order='C', *, out_sharding=None)
Flatten array into a 1-dimensional shape.
JAX implementation of numpy.ravel(), implemented in terms of jax.lax.reshape().
ravel(arr, order=order) is equivalent to reshape(arr, -1, order=order).
- Parameters:
a (ArrayLike) – array to be flattened.
order (str) –
'F' or 'C', specifying whether the reshape should use column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A" or order="K".
- Returns:
flattened copy of input array.
- Return type:
Array
Notes
Unlike numpy.ravel(), jax.numpy.ravel() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this typically has no performance impact.
See also
jax.Array.ravel(): equivalent functionality via an array method.
jax.numpy.reshape(): general array reshape.
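As a quick check of the equivalence with reshape described at the top of this page, the following sketch compares jnp.ravel against jnp.reshape(x, -1, order=...) for both supported orders on a small 3-D array (the array is an arbitrary choice), and confirms that the unsupported orders "A" and "K" are rejected. The exact exception type raised for an unsupported order is an assumption here and may differ between JAX versions, so the sketch catches both NotImplementedError and ValueError.

```python
import jax.numpy as jnp

# Arbitrary example input: a 3-D array, so C and F orders differ visibly.
x = jnp.arange(24).reshape(2, 3, 4)

for order in ("C", "F"):
    flat = jnp.ravel(x, order=order)
    # Documented equivalence: ravel(arr, order) == reshape(arr, -1, order).
    expected = jnp.reshape(x, -1, order=order)
    assert flat.shape == (x.size,)
    assert bool(jnp.array_equal(flat, expected))

# Orders "A" and "K" are not supported, so JAX raises instead of falling back.
# The exception type is assumed and may vary by version; catch both.
rejected = []
for order in ("A", "K"):
    try:
        jnp.ravel(x, order=order)
    except (NotImplementedError, ValueError):
        rejected.append(order)

print(rejected)
```

With Fortran order the first index varies fastest, so the second element of jnp.ravel(x, order="F") here is x[1, 0, 0] (value 12) rather than x[0, 0, 1].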
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
By default, ravel in C-style, row-major order:
>>> jnp.ravel(x)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
Optionally, ravel in Fortran-style, column-major order:
>>> jnp.ravel(x, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
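Because order only controls the sequence in which elements are read out, flattening and then reshaping with the same order should round-trip back to the original array, while mixing orders should not. A small sketch using the same 2x3 array as the examples on this page:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

flat_f = jnp.ravel(x, order="F")  # column-major: [1, 4, 2, 5, 3, 6]

# Reshaping back with the *same* order restores the original layout.
restored = jnp.reshape(flat_f, x.shape, order="F")
assert bool(jnp.array_equal(restored, x))

# Mixing orders does not: reading column-major and refilling row-major
# (the default order="C") produces a different 2x3 array.
mixed = jnp.reshape(flat_f, x.shape)
print(mixed)
```

This is why code that flattens with order="F" (for example, to interoperate with column-major data) should also pass order="F" when restoring the shape.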
For convenience, the same functionality is available via the jax.Array.ravel() method:
>>> x.ravel()
Array([1, 2, 3, 4, 5, 6], dtype=int32)
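The Notes section contrasts copy and view semantics. The sketch below makes that concrete: writing through NumPy's ravel result modifies the original array, whereas the JAX result can only be updated functionally with .at[...].set(), which leaves both the input and the flattened array unchanged. It also calls jnp.ravel inside a jax.jit-compiled function; flatten_and_scale is a hypothetical name used only for illustration, and whether a particular copy is actually elided is left to the compiler.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: ravel of a contiguous array returns a view, so writes propagate.
a = np.array([[1, 2, 3],
              [4, 5, 6]])
view = np.ravel(a)
view[0] = 100
assert a[0, 0] == 100

# JAX: ravel returns a new array, and JAX arrays are immutable anyway.
x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
flat = jnp.ravel(x)
# .at[...].set() produces an updated copy; neither flat nor x changes.
updated = flat.at[0].set(100)
assert int(x[0, 0]) == 1
assert int(flat[0]) == 1
assert int(updated[0]) == 100


@jax.jit
def flatten_and_scale(arr):
    # Hypothetical helper: under jit, the compiler may avoid materializing
    # the intermediate flattened array when possible.
    return jnp.ravel(arr) * 2


print(flatten_and_scale(x))
```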