jax.numpy.roots
- jax.numpy.roots(p, *, strip_zeros=True)
Returns the roots of a polynomial given the coefficients p.
JAX implementation of numpy.roots().
- Parameters:
p (ArrayLike) – Rank-1 array of polynomial coefficients, ordered from the highest-degree term to the constant term.
strip_zeros (bool) – bool, default=True. If True, leading zeros in the coefficients are stripped, as in numpy.roots(). If False, leading zeros are not stripped, and the undefined roots they introduce are represented by NaN values in the output. strip_zeros must be set to False for the function to be compatible with jax.jit() and other JAX transformations.
- Returns:
An array containing the roots of the polynomial.
- Return type:
Array
Note
Unlike np.roots, jnp.roots always returns the roots in a complex array, even when every root is real.
See also
- jax.numpy.poly(): Finds the polynomial coefficients of the given sequence of roots.
- jax.numpy.polyfit(): Least squares polynomial fit to data.
- jax.numpy.polyval(): Evaluate a polynomial at specific values.
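To illustrate the note, and how jnp.roots relates to the functions listed above, here is a minimal sketch that compares NumPy and JAX output dtypes for a polynomial whose roots are all real, then cross-checks the roots with jax.numpy.polyval() and jax.numpy.poly(). It assumes the default JAX configuration (64-bit mode disabled); the coefficients are chosen purely for illustration.

```python
import numpy as np
import jax.numpy as jnp

# (x - 1)(x - 2)(x - 3) = x**3 - 6x**2 + 11x - 6: all three roots are real.
p = [1.0, -6.0, 11.0, -6.0]

r_np = np.roots(p)               # NumPy returns a real array when every imaginary part is zero
r_jnp = jnp.roots(jnp.array(p))  # JAX always returns a complex array
print(r_np.dtype, r_jnp.dtype)   # float64 complex64

# polyval evaluates p at each root; the residuals should be close to zero.
residuals = jnp.polyval(jnp.array(p), r_jnp)

# poly rebuilds the monic coefficients from the roots; they should match p
# up to floating-point rounding.
p_rebuilt = jnp.poly(r_jnp)
print(jnp.abs(residuals).max())
print(p_rebuilt)
```

With 64-bit mode disabled, jnp.roots works in complex64, so expect residuals on the order of single-precision rounding error rather than exact zeros.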
Examples
>>> import jax.numpy as jnp
>>> coeffs = jnp.array([0, 1, 2])
The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
With strip_zeros=False, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
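The NaN padding is what makes strip_zeros=False usable under JAX transformations: the output always has len(p) - 1 entries, whatever the coefficient values. The following is a minimal sketch, assuming a recent JAX release, that composes jnp.roots with jax.jit() and jax.vmap() to solve a small batch of polynomials at once; the helper name roots_fixed_shape is hypothetical.

```python
import jax
import jax.numpy as jnp


def roots_fixed_shape(p):
    # Hypothetical helper: keep leading zeros so the output always has
    # len(p) - 1 entries, independent of the coefficient values.
    return jnp.roots(p, strip_zeros=False)


# vmap maps over rows (one polynomial per row); jit compiles the batched call.
batched_roots = jax.jit(jax.vmap(roots_fixed_shape))

coeffs = jnp.array([
    [1.0, -3.0, 2.0],  # x**2 - 3x + 2 = (x - 1)(x - 2)
    [0.0, 1.0, 2.0],   # leading zero: effectively x + 2
])
r = batched_roots(coeffs)
print(r.shape)  # (2, 2): two slots per row
print(r)
```

In the second row the leading zero leaves one root undefined, so that slot holds NaN, just as in the single-polynomial example with strip_zeros=False.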