jax.lax.scatter
- jax.lax.scatter(operand, scatter_indices, updates, dimension_numbers, *, indices_are_sorted=False, unique_indices=False, mode=None)
Scatter-update operator.
Wraps XLA’s Scatter operator, where updates replace values in operand.
If multiple updates are performed to the same index of operand, they may be applied in any order.
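To make the ordering caveat concrete, here is a small sketch (not part of the original page) in which two updates target the same index of a 1-D operand. With scatter() only one of the two values survives, and which one is not specified. lax.scatter_add, a sibling operator with the same signature, combines colliding updates by addition instead. The variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

# Scalar updates into a 1-D operand: each row of `indices` is one index
# vector into dimension 0 of `operand`.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,))

operand = jnp.zeros(4)
indices = jnp.array([[1], [1]])    # both updates target operand[1]
updates = jnp.array([10.0, 20.0])

# Replace semantics: operand[1] ends up as 10.0 or 20.0; which one is unspecified.
replaced = lax.scatter(operand, indices, updates, dnums)

# Accumulate semantics: both updates are added into operand[1].
accumulated = lax.scatter_add(operand, indices, updates, dnums)
```

Code that relies on duplicate indices should therefore use an accumulating variant such as scatter_add rather than scatter().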
scatter() is a low-level operator with complicated semantics, and most JAX users will never need to call it directly. Instead, you should prefer using jax.numpy.ndarray.at() for more familiar NumPy-style indexing syntax.
- Parameters:
operand (ArrayLike) – an array to which the scatter should be applied
scatter_indices (ArrayLike) – an array that gives the indices in operand to which each update in updates should be applied.
updates (ArrayLike) – the updates that should be scattered onto operand.
dimension_numbers (ScatterDimensionNumbers) – a lax.ScatterDimensionNumbers object that describes how dimensions of operand, scatter_indices, updates and the output relate.
indices_are_sorted (bool) – whether scatter_indices is known to be sorted. If true, may improve performance on some backends.
unique_indices (bool) – whether the elements to be updated in operand are guaranteed to not overlap with each other. If true, may improve performance on some backends. JAX does not check this promise: if the updated elements overlap when unique_indices is True, the behavior is undefined.
mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds: when set to ‘clip’, indices are clamped so that the slice is within bounds, and when set to ‘fill’ or ‘drop’, out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to ‘promise_in_bounds’ is implementation-defined.
- Returns:
An array containing the values of operand and the scattered updates.
- Return type:
Array
Examples
As mentioned above, you should basically never use scatter() directly, and should instead perform scatter-style operations using NumPy-style indexing expressions via jax.numpy.ndarray.at.
Here is an example of updating entries in an array using jax.numpy.ndarray.at, which lowers to an XLA Scatter operation:
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.ones(5)
>>> indices = jnp.array([1, 2, 4])
>>> values = jnp.array([2.0, 3.0, 4.0])
>>> x.at[indices].set(values)
Array([1., 2., 3., 1., 4.], dtype=float32)
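One way to check that this expression lowers to a scatter is to trace it with jax.make_jaxpr and look for the scatter primitive in the printed jaxpr. This is a sketch for inspection only; the exact jaxpr text (parameters, extra index-normalization ops) varies between JAX versions, so it only searches for the primitive name.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(5)
indices = jnp.array([1, 2, 4])
values = jnp.array([2.0, 3.0, 4.0])

# Trace the indexed update without executing it; the jaxpr lists the primitives used.
jaxpr = jax.make_jaxpr(lambda x, i, v: x.at[i].set(v))(x, indices, values)
print(jaxpr)  # includes a `scatter[...]` equation alongside index-handling ops
```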
This syntax also supports several of the optional arguments to scatter(), for example:
>>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds')
Array([1., 2., 3., 1., 4.], dtype=float32)
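The mode argument can be exercised through the same syntax. In this illustrative sketch (not from the original page), index 7 is out of bounds for a length-5 array: ‘clip’ clamps it to the last valid position, while ‘drop’ discards that update and leaves the in-bounds update intact.

```python
import jax.numpy as jnp

x = jnp.ones(5)
indices = jnp.array([1, 7])      # 7 is past the end of x
values = jnp.array([2.0, 9.0])

# 'clip': index 7 is clamped to 4, so x[4] receives 9.0.
clipped = x.at[indices].set(values, mode='clip')   # [1., 2., 1., 1., 9.]

# 'drop': the out-of-bounds update is discarded; only x[1] changes.
dropped = x.at[indices].set(values, mode='drop')   # [1., 2., 1., 1., 1.]
```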
By comparison, here is the equivalent function call using scatter() directly, which is not something typical users should ever need to do:
>>> lax.scatter(x, indices[:, None], values,
...             dimension_numbers=lax.ScatterDimensionNumbers(
...                 update_window_dims=(),
...                 inserted_window_dims=(0,),
...                 scatter_dims_to_operand_dims=(0,)),
...             indices_are_sorted=True,
...             mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
Array([1., 2., 3., 1., 4.], dtype=float32)
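The dimension numbers in that call describe scalar updates. As a further sketch (an illustrative example, not from the original page), the call below replaces whole rows of a 2-D operand: update_window_dims=(1,) marks dimension 1 of updates as a window running along a row, inserted_window_dims=(0,) says each window has implicit size 1 along operand dimension 0, and scatter_dims_to_operand_dims=(0,) maps the single index component to operand dimension 0. The result is compared against the equivalent .at expression; names such as row_ids are illustrative.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((4, 3))
row_ids = jnp.array([0, 2])
rows = jnp.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),            # dim 1 of `rows` spans one operand row
    inserted_window_dims=(0,),          # each window covers a single row of operand
    scatter_dims_to_operand_dims=(0,))  # the index vector addresses operand dim 0

# Indices need a trailing index-vector dimension of length 1, hence [:, None].
out = lax.scatter(operand, row_ids[:, None], rows, dnums)
expected = operand.at[row_ids].set(rows)
```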