jax.numpy.where
- jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
Select elements from two arrays based on a condition.
JAX implementation of numpy.where().
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.
The three-term version of jnp.where lowers to jax.lax.select().
- Parameters:
  - condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
  - x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
  - y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
  - size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().
  - fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
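The broadcasting requirements above, and the note that the three-term form lowers to jax.lax.select(), can be illustrated with a minimal sketch. The shapes and values are arbitrary choices for illustration. The manual path relies only on lax.select expecting operands that already share one shape and dtype, so the broadcasting that jnp.where performs implicitly is spelled out by hand.

```python
import jax.numpy as jnp
from jax import lax

# condition has shape (3, 1); x has shape (4,); y is a Python scalar.
condition = jnp.array([[True], [False], [True]])
x = jnp.arange(4.0)   # float32 array of shape (4,)
y = -1.0              # scalar, broadcast to the full result shape

result = jnp.where(condition, x, y)
print(result.shape)   # (3, 4)

# lax.select does not broadcast, so every operand is first expanded to the
# common shape (and given a common dtype) before calling it directly.
shape = jnp.broadcast_shapes(condition.shape, x.shape)
manual = lax.select(
    jnp.broadcast_to(condition, shape),
    jnp.broadcast_to(x, shape),
    jnp.full(shape, y, dtype=x.dtype),
)
print(bool(jnp.array_equal(result, manual)))  # True
```

Rows where condition is True take the values of x; the middle row is filled with y.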
- Returns:
An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
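A short sketch of the dtype rule under Returns, using arbitrary int32 and float32 inputs: the result dtype follows jnp.result_type(x, y), and the boolean condition plays no part in it. The behavior of the Python scalar at the end assumes JAX's default type promotion settings, under which Python scalars are weakly typed.

```python
import jax.numpy as jnp

condition = jnp.array([True, False, True])
x = jnp.array([1, 2, 3], dtype=jnp.int32)
y = jnp.array([0.5, 0.5, 0.5], dtype=jnp.float32)

out = jnp.where(condition, x, y)
# The output dtype is the promoted type of x and y, not of condition.
print(out.dtype == jnp.result_type(x, y))  # True
print(out.tolist())                        # [1.0, 0.5, 3.0]

# A Python scalar is weakly typed, so it does not widen the int32 input.
print(jnp.where(condition, x, 0).dtype)    # int32
```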
Notes
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
Examples
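The gradient caveat from the Notes can be reproduced with a minimal sketch. Here jnp.sqrt of a negative input stands in for any branch that evaluates to NaN; the function names f and safe_f and the substitute value 1.0 are illustrative choices. The workaround shown, often called the "double where" trick, replaces unsafe inputs before the risky computation so that the discarded branch never produces a NaN.

```python
import jax
import jax.numpy as jnp

def f(x):
    # sqrt(x) is NaN for x < 0, even though where() discards that branch.
    return jnp.where(x >= 0, jnp.sqrt(x), 0.0)

def safe_f(x):
    # "Double where": substitute a harmless input before calling sqrt, so the
    # unselected branch never produces a NaN in the first place.
    safe_x = jnp.where(x >= 0, x, 1.0)
    return jnp.where(x >= 0, jnp.sqrt(safe_x), 0.0)

print(f(-1.0))                 # 0.0 (the forward value is fine)
print(jax.grad(f)(-1.0))       # nan
print(jax.grad(safe_f)(-1.0))  # 0.0
print(jax.grad(safe_f)(4.0))   # 0.25
```

For non-negative inputs both versions compute the same value and gradient; they differ only in the branch that where() discards.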
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():
>>> import jax.numpy as jnp
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
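The size and fill_value parameters matter mainly in this one-argument form under jax.jit(), where the number of matching elements is not known at trace time and the output shape must be static. A minimal sketch, with first_three_above as a hypothetical helper name and size=3, fill_value=-1 as arbitrary choices:

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_above(x, threshold):
    # Under jit the output shape must be static, so size fixes the length and
    # fill_value pads the result when fewer than `size` elements match.
    (idx,) = jnp.where(x > threshold, size=3, fill_value=-1)
    return idx

x = jnp.arange(10)
print(first_three_above(x, 4).tolist())  # [5, 6, 7]  (extra matches dropped)
print(first_three_above(x, 8).tolist())  # [9, -1, -1]  (padded with fill_value)
```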
When x and y are provided, where selects between them based on the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
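When both x and y are arrays, where picks each output element from one or the other. The sketch below uses arbitrary values; for NaN-free inputs, selecting the larger of two arrays this way agrees with jnp.maximum().

```python
import jax.numpy as jnp

a = jnp.array([1.0, 5.0, 3.0, -2.0])
b = jnp.array([4.0, 2.0, 3.0, -1.0])

# Take from a where it is strictly larger, otherwise from b.
larger = jnp.where(a > b, a, b)
print(larger.tolist())                                   # [4.0, 5.0, 3.0, -1.0]
print(bool(jnp.array_equal(larger, jnp.maximum(a, b))))  # True
```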