jax.numpy.polyval
- jax.numpy.polyval(p, x, *, unroll=16)
Evaluates the polynomial at specific values.
JAX implementation of numpy.polyval().
For the 1D-polynomial coefficients p of length M, the function returns the value:
\[p_0 x^{M - 1} + p_1 x^{M - 2} + \cdots + p_{M - 1}\]
- Parameters:
p (ArrayLike) – An array of polynomial coefficients of shape (M,).
x (ArrayLike) – A number or an array of numbers.
unroll (int) – A number used to control the number of unrolled steps with lax.scan. It must be specified statically.
- Returns:
An array of the same shape as x.
- Return type:
Array
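The Note below says jnp.polyval is built on lax.scan. As a minimal sketch of how the formula above can be evaluated that way, assuming Horner's scheme as the recurrence (one multiply-add per coefficient, highest degree first), here is a hypothetical helper, horner_polyval, that is not JAX's actual implementation:

```python
import jax.numpy as jnp
from jax import lax


def horner_polyval(p, x):
    """Hypothetical helper: evaluate p[0]*x**(M-1) + ... + p[M-1] via Horner's scheme.

    Illustrative sketch only; it is not the library implementation.
    """
    # Promote to float32, mirroring the float32 results in the Examples.
    p = jnp.asarray(p, dtype=jnp.float32)
    x = jnp.asarray(x, dtype=jnp.float32)

    def step(acc, coeff):
        # One Horner step: acc <- acc * x + coeff
        return acc * x + coeff, None

    # Start from zeros shaped like x and fold in the coefficients in order.
    result, _ = lax.scan(step, jnp.zeros_like(x), p)
    return result


p = jnp.array([2, 5, 1])
x = jnp.array([[2, 1, 5], [3, 4, 7], [1, 3, 5]])
print(horner_polyval(p, x))
print(jnp.polyval(p, x))
```

Horner's scheme never forms explicit powers of x, so each scan step depends only on the previous accumulator. That sequential dependency is what the unroll parameter trades off against compilation time.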
Note
The unroll parameter is JAX-specific. It does not affect correctness but can have a major impact on performance when evaluating high-order polynomials. The parameter controls the number of unrolled steps with lax.scan inside the jnp.polyval implementation. Consider setting unroll=128 (or even higher) to improve runtime performance on accelerators, at the cost of increased compilation time.
See also
- jax.numpy.polyfit(): Least-squares polynomial fit.
- jax.numpy.poly(): Finds the coefficients of a polynomial with given roots.
- jax.numpy.roots(): Computes the roots of a polynomial for given coefficients.
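A short sketch of how polyval combines with two of the functions listed above, assuming their default arguments and an arbitrary sample grid: polyfit recovers coefficients (highest degree first, the order polyval expects) from samples, and evaluating a polynomial at the output of roots should give values near zero, up to float32 rounding.

```python
import jax.numpy as jnp

# Samples of the example polynomial 2x^2 + 5x + 1 on an arbitrary grid.
true_p = jnp.array([2.0, 5.0, 1.0])
xs = jnp.linspace(-1.0, 1.0, 11)
ys = jnp.polyval(true_p, xs)

# Least-squares fit of a degree-2 polynomial; coefficients come back
# highest degree first, so they can be passed straight to polyval.
fit_p = jnp.polyfit(xs, ys, 2)
print(fit_p)
print(jnp.polyval(fit_p, xs) - ys)  # residuals close to zero

# Roots of the polynomial (returned as a complex array); evaluating the
# polynomial there should give values close to zero.
r = jnp.roots(true_p)
print(jnp.polyval(true_p, r))
```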
Examples
>>> import jax.numpy as jnp
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
If x is a 2D array, polyval returns a 2D array with the same shape as x:
>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)
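Because unroll must be specified statically, a function that forwards unroll into jnp.polyval under jax.jit has to mark that argument static, for example with static_argnames. The sketch below uses an arbitrary degree-255 test polynomial and a hypothetical wrapper, eval_poly; both unroll settings should produce matching values, since unroll affects only performance.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Arbitrary high-order test polynomial: 256 coefficients (degree 255),
# kept small so the values stay modest for |x| <= 1.
p = jnp.full((256,), 0.01)
x = jnp.linspace(-1.0, 1.0, 1000)


# Hypothetical wrapper: unroll is a Python int, so it must be static under jit.
@partial(jax.jit, static_argnames="unroll")
def eval_poly(p, x, unroll):
    return jnp.polyval(p, x, unroll=unroll)


y_default = eval_poly(p, x, unroll=16)    # the default unroll value
y_unrolled = eval_poly(p, x, unroll=128)  # more unrolling, longer compile
```

Each distinct unroll value triggers a separate compilation. Any speedup from a larger value depends on the backend and the polynomial degree, so it is worth timing the compiled function on the target hardware before settling on a value.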