jax.jvp
- jax.jvp(fun, primals, tangents, has_aux=False)
Computes a (forward-mode) Jacobian-vector product of fun.
- Parameters:
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
has_aux (bool) – Optional. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Returns:
If has_aux is False, returns a (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out. If has_aux is True, returns a (primals_out, tangents_out, aux) tuple where aux is the auxiliary data returned by fun.
- Return type:
tuple[Any, …]
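The three-element return described for has_aux=True can be sketched as follows. The function f_with_aux and the dictionary key "input_squared" are hypothetical names chosen for illustration; only jax.jvp and its has_aux flag come from the signature above.

```python
import jax
import jax.numpy as jnp


def f_with_aux(x):
    # Hypothetical function: the first element is the differentiated output,
    # the second is auxiliary data that is passed through unchanged.
    y = jnp.sin(x)
    aux = {"input_squared": x ** 2}
    return y, aux


# With has_aux=True the result is a (primals_out, tangents_out, aux) tuple.
primals_out, tangents_out, aux = jax.jvp(
    f_with_aux, (0.1,), (0.2,), has_aux=True
)

print(primals_out)   # value of sin at 0.1
print(tangents_out)  # derivative of sin at 0.1, scaled by the tangent 0.2
print(aux)           # auxiliary data; no tangent is computed for it
```

Only the first element of the pair returned by f_with_aux contributes to tangents_out; the auxiliary data comes back alongside it without a matching tangent.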
For example:
>>> import jax
>>>
>>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(primals)
0.09983342
>>> print(tangents)
0.19900084
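For a function of several arguments, primals and tangents each hold one entry per positional parameter, with matching shapes. The sketch below uses a hypothetical two-argument function f and checks the result against the derivative worked out by hand; the variable names are illustrative and not part of the API.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical two-argument function used only for illustration.
    return x * jnp.sin(y)


x = jnp.array([1.0, 2.0])
y = jnp.array([0.5, 1.5])

# One tangent per positional argument, each shaped like its primal.
dx = jnp.array([1.0, 0.0])
dy = jnp.array([0.0, 1.0])

primals_out, tangents_out = jax.jvp(f, (x, y), (dx, dy))

# Hand-derived directional derivative: df = sin(y) * dx + x * cos(y) * dy
expected = jnp.sin(y) * dx + x * jnp.cos(y) * dy

print(primals_out)
print(tangents_out)
print(jnp.allclose(tangents_out, expected))
```

The tangents tuple here pushes a single direction (dx, dy) through f, so tangents_out has the same shape as primals_out rather than the shape of a full Jacobian.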