jax.lax.reduce_window
- jax.lax.reduce_window(operand, init_value, computation, window_dimensions, window_strides=None, padding='VALID', base_dilation=None, window_dilation=None)
Reduction over padded windows.
Wraps XLA’s ReduceWindowWithGeneralPadding operator.
- Parameters:
operand (Any) – input array or tree of arrays.
init_value (Any) – value or tree of values. Tree structure must match that of operand.
computation (Callable) – binary function used to combine values within each window. Both of its inputs and its output must be trees with the same structure as operand.
window_dimensions (core.Shape) – sequence of integers specifying the window size.
window_strides (Sequence[int] | None) – optional sequence of integers specifying the strides, of the same length as window_dimensions. Default (None) indicates a unit stride in each window dimension.
padding (str | Sequence[tuple[int, int]]) – string or sequence of integer tuples specifying the type of padding to use (default: “VALID”). If a string, must be one of “VALID”, “SAME”, or “SAME_LOWER”. See the jax.lax.padtype_to_pads() utility.
base_dilation (Sequence[int] | None) – optional sequence of integers for base dilation values, of the same length as window_dimensions. Default (None) indicates unit dilation in each window dimension.
window_dilation (Sequence[int] | None) – optional sequence of integers for window dilation values, of the same length as window_dimensions. Default (None) indicates unit dilation in each window dimension.
- Returns:
A tree of arrays with the same structure as operand.
- Return type:
Any
Example
Here is a simple example of a windowed product over pairs in a 1-dimensional array:
>>> import jax
>>> x = jax.numpy.arange(10, dtype='float32')
>>> x
Array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], dtype=float32)
>>> initial = jax.numpy.float32(1)
>>> jax.lax.reduce_window(x, initial, jax.lax.mul, window_dimensions=(2,))
Array([ 0.,  2.,  6., 12., 20., 30., 42., 56., 72.], dtype=float32)