jax.numpy.roots

jax.numpy.roots(p, *, strip_zeros=True)

Returns the roots of a polynomial given the coefficients p.

JAX implementation of numpy.roots().

Parameters:
  • p (ArrayLike) – Rank-1 array of polynomial coefficients.

  • strip_zeros (bool) – Default True. If True, leading zeros in the coefficients are stripped, matching the behavior of numpy.roots(). If False, leading zeros are kept and each undefined root is represented by a NaN value in the output. strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations, because stripping makes the output shape depend on the coefficient values.
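A minimal sketch of the strip_zeros=False requirement described above, assuming a CPU backend. The wrapper name roots_jit is hypothetical and chosen only for illustration. With strip_zeros=False the output always has len(p) - 1 entries, so the function can be traced by jax.jit().

```python
import jax
import jax.numpy as jnp

# Coefficients with a leading zero: 0*x**2 + 1*x + 2.
coeffs = jnp.array([0.0, 1.0, 2.0])

# roots_jit is a hypothetical wrapper name. strip_zeros is fixed to False
# inside the function, so the output shape depends only on the input shape.
@jax.jit
def roots_jit(p):
    return jnp.roots(p, strip_zeros=False)

result = roots_jit(coeffs)

# One entry per possible root; the undefined root is reported as NaN.
print(result.shape)
print(jnp.isnan(result))
```

Callers that need only the defined roots can mask the NaN entries after the jitted call returns, since boolean masking also produces a value-dependent shape.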

Returns:

An array containing the roots of the polynomial.

Return type:

Array

Note

Unlike numpy.roots(), jnp.roots always returns the roots in a complex array, regardless of the values of the roots.
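As a rough illustration of this note, the sketch below uses a polynomial whose roots are all real, x**2 - 3x + 2 = (x - 1)(x - 2). The variable names are illustrative, and taking the real part afterwards is a caller-side choice, not something jnp.roots does.

```python
import jax.numpy as jnp

# x**2 - 3x + 2 factors as (x - 1)(x - 2), so both roots are real.
coeffs = jnp.array([1.0, -3.0, 2.0])
r = jnp.roots(coeffs)

# The result has a complex dtype even though no root has an imaginary part.
print(r.dtype)

# If the roots are known to be real, the caller can drop the imaginary part.
# Sorting is done here only because the order of the roots is not relied upon.
real_roots = jnp.sort(r.real)
print(real_roots)
```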

Examples

>>> import jax.numpy as jnp
>>> coeffs = jnp.array([0, 1, 2])

The default behavior matches numpy.roots() and strips leading zeros:

>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

With strip_zeros=False, extra roots are set to NaN:

>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)