jax.numpy.invert

jax.numpy.invert(x, /)

Compute the bitwise inversion of an input.

JAX implementation of numpy.invert(). This function provides the implementation of the ~ operator for JAX arrays.

Parameters:

x (ArrayLike) – input array; must have a boolean or integer dtype.
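One way to see the dtype restriction is to pass a floating-point array. The sketch below assumes that JAX rejects such input with a TypeError, which is what current versions raise from the underlying bitwise primitive. The helper name try_invert is hypothetical and exists only for this illustration.

```python
import jax.numpy as jnp

def try_invert(x):
    """Hypothetical helper: return jnp.invert(x), or None if the dtype is rejected."""
    try:
        return jnp.invert(x)
    except TypeError:
        # Assumption: floating-point inputs have no bitwise inversion defined,
        # so JAX raises TypeError instead of returning a result.
        return None

ok = try_invert(jnp.array([1, 2], dtype='int32'))   # integer dtype: accepted
ok_bool = try_invert(jnp.array([True, False]))      # boolean dtype: accepted
bad = try_invert(jnp.array([1.0, 2.0]))             # float dtype: rejected
```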

Returns:

An array of the same shape and dtype as x, with the bits inverted.

Return type:

Array
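Because the result keeps the input's dtype, signed integers follow two's-complement arithmetic: flipping every bit of x yields -x - 1. The sketch below checks that identity, along with the preserved shape and dtype, on a small int8 array using only standard jax.numpy functions.

```python
import jax.numpy as jnp

x = jnp.arange(-2, 3, dtype='int8')   # values: -2, -1, 0, 1, 2
y = jnp.invert(x)

# Shape and dtype are preserved.
assert y.shape == x.shape and y.dtype == x.dtype

# In two's complement, inverting all bits equals negating and subtracting one.
same = jnp.array_equal(y, -x - 1)
print(y.tolist(), bool(same))
```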

See also

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5, dtype='uint8')
>>> print(x)
[0 1 2 3 4]
>>> print(jnp.invert(x))
[255 254 253 252 251]
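For an unsigned 8-bit input, flipping all eight bits is the same as subtracting from the all-ones value 0xFF (255), or XOR-ing with it. The sketch below checks both equivalences; it illustrates the arithmetic behind the output above and is not meant to describe how JAX implements invert() internally.

```python
import jax.numpy as jnp

x = jnp.arange(5, dtype='uint8')
inv = jnp.invert(x)

all_ones = jnp.uint8(0xFF)               # every bit set for uint8
via_sub = all_ones - x                   # 255 - x never underflows for uint8
via_xor = jnp.bitwise_xor(x, all_ones)   # XOR with a 1 bit flips that bit

print(bool(jnp.array_equal(inv, via_sub)), bool(jnp.array_equal(inv, via_xor)))
```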

This function implements the unary ~ operator for JAX arrays:

>>> print(~x)
[255 254 253 252 251]
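Since ~ is backed by this function, it also works on traced arrays inside transformed code. A minimal sketch, assuming only jax.jit: the hypothetical function double_invert applies ~ twice, which restores the original bits and keeps the uint8 dtype.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double_invert(a):
    # Hypothetical example function: each ~ flips every bit, so two cancel out.
    return ~(~a)

x = jnp.arange(5, dtype='uint8')
result = double_invert(x)
print(bool(jnp.array_equal(result, x)))
```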

Because invert() operates bitwise on its input, its output may be easier to interpret when both arrays are printed in binary:

>>> with jnp.printoptions(formatter={'int': lambda v: format(v, '#010b')}):
...   print(f"{x  = }")
...   print(f"{~x = }")
x  = Array([0b00000000, 0b00000001, 0b00000010, 0b00000011, 0b00000100], dtype=uint8)
~x = Array([0b11111111, 0b11111110, 0b11111101, 0b11111100, 0b11111011], dtype=uint8)
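The binary listing above can also be checked position by position. This sketch extracts bit k of every element with a right shift and a mask, then confirms that each bit of ~x is the complement of the same bit of x, so the two bits always sum to 1.

```python
import jax.numpy as jnp

x = jnp.arange(5, dtype='uint8')
inv = ~x

# Bit positions 0 (least significant) through 7, as a column for broadcasting.
shifts = jnp.arange(8, dtype='uint8')[:, None]
x_bits = (x[None, :] >> shifts) & 1      # shape (8, 5): bit k of each element
inv_bits = (inv[None, :] >> shifts) & 1

# Every bit position is flipped, so each pair of bits sums to exactly 1.
all_flipped = jnp.all(x_bits + inv_bits == 1)
print(bool(all_flipped))
```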

For boolean inputs, invert() is equivalent to logical_not():

>>> x = jnp.array([True, False, True, True, False])
>>> jnp.invert(x)
Array([False,  True, False, False,  True], dtype=bool)
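To make the boolean equivalence concrete, and to show where it stops holding, the sketch below compares invert() with jnp.logical_not() on a boolean array and then on an integer array, where logical_not() tests for zero while invert() flips bits.

```python
import jax.numpy as jnp

b = jnp.array([True, False, True, True, False])
bool_match = jnp.array_equal(jnp.invert(b), jnp.logical_not(b))

# For integers the two differ: logical_not checks for zero, invert flips bits.
i = jnp.array([0, 1, 2], dtype='int32')
bitwise = jnp.invert(i)        # -1, -2, -3
logical = jnp.logical_not(i)   # True, False, False
print(bool(bool_match), bitwise.tolist(), logical.tolist())
```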