jax.lax.collapse
jax.lax.collapse(operand, start_dimension, stop_dimension=None)
Collapses dimensions of an array into a single dimension.
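For example, if operand is an array with shape (2, 3, 4), then collapse(operand, 0, 2).shape == (6, 4). The elements of the collapsed dimension are laid out major-to-minor, i.e., with the lowest-numbered dimension as the slowest-varying dimension.
- Parameters:
  - operand (Array): an input array.
  - start_dimension (int): the first dimension to collapse (inclusive).
  - stop_dimension (int | None): the dimension at which collapsing stops (exclusive). If None, all dimensions from start_dimension through the last dimension are collapsed.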
- Returns:
  An array where dimensions [start_dimension, stop_dimension) have been collapsed (raveled) into a single dimension.
- Return type:
  Array
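The snippet below is a minimal sketch, assuming JAX is installed and importable as shown. It checks the shape example above, the major-to-minor ordering of the collapsed dimension, and the half-open range [start_dimension, stop_dimension). It also assumes that leaving stop_dimension at its default of None collapses every dimension from start_dimension onward. The comparisons against x.reshape only illustrate what the result looks like. They are not a claim about how collapse is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

# A (2, 3, 4) array whose values 0..23 are stored in row-major (major-to-minor) order,
# so x[i, j, k] == 12 * i + 4 * j + k.
x = jnp.arange(24).reshape(2, 3, 4)

# Collapse dimensions [0, 2), i.e. axes 0 and 1, into a single axis of size 2 * 3 = 6.
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Major-to-minor layout: axis 0 varies slowest, so row k of y is x[k // 3, k % 3].
assert all(bool((y[k] == x[k // 3, k % 3]).all()) for k in range(6))

# The same values as a plain row-major reshape of the leading two axes.
assert bool((y == x.reshape(6, 4)).all())

# Assumption: stop_dimension=None collapses everything from start_dimension to the end.
z = lax.collapse(x, 1)
print(z.shape)  # (2, 12)
assert bool((z == x.reshape(2, 12)).all())
```

Because the collapsed axis is ordered major-to-minor, collapsing adjacent dimensions produces the same element order as a row-major reshape of those dimensions.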