jax.numpy.sqrt
jax.numpy.sqrt(x, /)
Calculates element-wise non-negative square root of the input array.
JAX implementation of numpy.sqrt.

Parameters:
x (ArrayLike) – input array or scalar.
Returns:
An array containing the non-negative square root of the elements of x.
Return type:
Array
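A minimal sketch of the element-wise behavior described above, assuming JAX's default 32-bit mode (with 64-bit mode enabled, the integer scalar would be promoted to float64 instead of float32). The variable names are illustrative only.

```python
import jax.numpy as jnp

# Element-wise: each entry of the input array is square-rooted independently.
x = jnp.array([0.0, 1.0, 4.0, 9.0])
roots = jnp.sqrt(x)
print(roots)  # [0. 1. 2. 3.]

# A Python scalar is accepted too; integer input is promoted to a floating result.
s = jnp.sqrt(16)
print(s, s.dtype)  # 4.0 float32 (in the default 32-bit mode)
```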
Note
For real-valued negative inputs, jnp.sqrt produces a nan output. For negative inputs of complex dtype, jnp.sqrt produces a complex output.
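The sketch below contrasts the two cases from the note: the same negative values give nan when the array is real-valued, but give the principal (imaginary) root once the array is cast to a complex dtype. Choosing complex64 here is an assumption for illustration; any complex dtype behaves the same way.

```python
import jax.numpy as jnp

neg = jnp.array([-4.0, -1.0, 9.0])

# Real-valued input: negative entries give nan, non-negative entries are unaffected.
real_roots = jnp.sqrt(neg)
print(jnp.isnan(real_roots))  # [ True  True False]

# Casting to a complex dtype first yields the principal complex root instead.
complex_roots = jnp.sqrt(neg.astype(jnp.complex64))
print(complex_roots)  # principal roots: 2j, 1j and 3 (as complex64)
```

Casting explicitly keeps the output dtype predictable: real input always maps to real output, so a complex result has to be requested through the input dtype.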
See also
jax.numpy.square(): Calculates the element-wise square of the input.
jax.numpy.power(): Calculates the element-wise base x1 exponential of x2.
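As a quick check of how these related functions fit together, the following sketch compares jnp.sqrt with jnp.power(x, 0.5) and confirms that jnp.square undoes jnp.sqrt for non-negative input. The comparison uses jnp.allclose because the results can differ by floating-point rounding; the sample values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([0.25, 2.0, 10.0])

# sqrt(x) agrees with raising x to the power 0.5, up to floating-point rounding.
print(jnp.allclose(jnp.sqrt(x), jnp.power(x, 0.5)))  # True

# For non-negative x, squaring the square root recovers x, again up to rounding.
print(jnp.allclose(jnp.square(jnp.sqrt(x)), x))  # True
```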
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)