jax.numpy.place
- jax.numpy.place(arr, mask, vals, *, inplace=True)
Update array elements based on a mask.
JAX implementation of numpy.place().
The semantics of numpy.place() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.
- Parameters:
arr (Array | ndarray | bool | number | bool | int | float | complex) – array into which values will be placed.
mask (Array | ndarray | bool | number | bool | int | float | complex) – boolean mask with the same size as arr.
vals (Array | ndarray | bool | number | bool | int | float | complex) – values to be inserted into arr at the locations indicated by mask. If too many values are supplied, they will be truncated. If not enough values are supplied, they will be repeated.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
- Returns:
A copy of arr with masked values set to entries from vals.
- Return type:
Array
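A quick check of the copy semantics described above. This is a minimal sketch that assumes, as in the JAX versions we have used, that leaving inplace at its default of True raises a ValueError (the exact error message may differ between versions):

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=int)
mask = jnp.array([True, False, True, False])

# inplace=False: a new array is returned; x itself is left untouched.
y = jnp.place(x, mask, 7, inplace=False)
print(y)  # [7 0 7 0]
print(x)  # [0 0 0 0]

# Leaving inplace at its default (True) is rejected, because JAX arrays are immutable.
raised = False
try:
    jnp.place(x, mask, 7)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

In practice this means every call site passes inplace=False explicitly and uses the returned array.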
See also
jax.numpy.put(): put elements into an array at numerical indices.
jax.numpy.ndarray.at(): array updates using NumPy-style indexing.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.zeros((3, 5), dtype=int)
>>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
>>> mask
Array([[ True, False, False,  True, False],
       [False,  True, False, False,  True],
       [False, False,  True, False, False]], dtype=bool)
Placing a scalar value:
>>> jnp.place(x, mask, 1, inplace=False)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)
In this case, jnp.place is similar to the masked array update syntax:
>>> x.at[mask].set(1)
Array([[1, 0, 0, 1, 0],
       [0, 1, 0, 0, 1],
       [0, 0, 1, 0, 0]], dtype=int32)
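For a scalar value, the same result can also be written with jnp.where. Unlike boolean-mask indexing with .at[mask], jnp.where does not need a concrete mask, so it also works when the mask is a traced value inside jax.jit. A minimal sketch; place_scalar is a hypothetical helper name, not part of JAX:

```python
import jax
import jax.numpy as jnp

@jax.jit
def place_scalar(x, mask):
    # Select 1 wherever mask is True and keep x elsewhere.
    # The output shape equals x.shape, so no concrete mask is required under jit.
    return jnp.where(mask, 1, x)

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
print(place_scalar(x, mask))
```

This covers only the scalar case; placing values from an array needs the repetition rule that jnp.place applies.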
jnp.place differs when placing values from an array. The array is repeated to fill the masked entries:
>>> vals = jnp.array([1, 3, 5])
>>> jnp.place(x, mask, vals, inplace=False)
Array([[1, 0, 0, 3, 0],
       [0, 5, 0, 0, 1],
       [0, 0, 3, 0, 0]], dtype=int32)
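The repeat-or-truncate rule can be made explicit with a small re-implementation. This is a hypothetical sketch (place_sketch is not a JAX function, and this is not claimed to be how JAX implements jnp.place); it assumes vals is non-empty. The k-th True entry of mask, in row-major order, receives vals[k % len(vals)]:

```python
import jax.numpy as jnp

def place_sketch(arr, mask, vals):
    """Hypothetical illustration of the jnp.place fill rule; assumes vals is non-empty."""
    arr = jnp.asarray(arr)
    flat_mask = jnp.asarray(mask, dtype=bool).ravel()
    flat_vals = jnp.asarray(vals).ravel()
    # Running count of True entries: the k-th True position gets index k (0-based).
    k = jnp.cumsum(flat_mask.astype(jnp.int32)) - 1
    # Cycle through vals (repeats a short vals; extra vals are never reached).
    # Positions where mask is False are discarded by jnp.where below.
    fill = flat_vals[k % flat_vals.size].astype(arr.dtype)
    return jnp.where(flat_mask, fill, arr.ravel()).reshape(arr.shape)

x = jnp.zeros((3, 5), dtype=int)
mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
print(place_sketch(x, mask, jnp.array([1, 3, 5])))  # vals repeated

short_mask = jnp.array([True, False, True, False, False])
print(place_sketch(jnp.zeros(5, dtype=int), short_mask, jnp.array([7, 8, 9])))  # [7 0 8 0 0]
```

Every step keeps a static shape (there is no boolean indexing), so a helper written this way can also be traced by jax.jit.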