jax.vjp
- jax.vjp(fun: Callable[..., T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[AxisName] = ()) → tuple[T, Callable]
- jax.vjp(fun: Callable[..., tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[AxisName] = ()) → tuple[T, Callable, U]
Compute a (reverse-mode) vector-Jacobian product of fun.

grad() is implemented as a special case of vjp().

- Parameters:
  - fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  - primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The number of primals should be equal to the number of positional parameters of fun. Each primal value should be an array, a scalar, or a pytree (standard Python containers) thereof.
  - has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- Returns:
  If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun. vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same number and shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.
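The has_aux=True return convention described above can be sketched concretely. Here loss_with_stats and the names inside it are illustrative, not part of the API; only jax.vjp itself comes from the documentation:

```python
import jax
import jax.numpy as jnp

def loss_with_stats(w):
    # First element: the value to be differentiated;
    # second element: auxiliary data that bypasses differentiation.
    residual = jnp.sin(w) - 0.5
    return residual ** 2, {"residual": residual}

# With has_aux=True, vjp returns a third element carrying the aux data.
loss, f_vjp, aux = jax.vjp(loss_with_stats, 0.3, has_aux=True)

# The cotangent passed to f_vjp has the same shape as the differentiated
# output (a scalar here), and f_vjp returns one cotangent per primal.
(w_bar,) = f_vjp(jnp.ones_like(loss))
```

Note that aux is returned unchanged, so it is a convenient channel for diagnostics computed inside fun without affecting the gradient.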
>>> import jax
>>>
>>> def f(x, y):
...     return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
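Since grad() is a special case of vjp(), a toy scalar gradient function can be sketched by pulling back a unit cotangent. The helper my_grad below is hypothetical, shown only to illustrate the relationship:

```python
import jax
import jax.numpy as jnp

def my_grad(f):
    # Hypothetical sketch of grad via vjp: differentiate a scalar-output
    # function of one argument by pulling back the cotangent 1.0.
    def grad_f(x):
        y, f_vjp = jax.vjp(f, x)
        (x_bar,) = f_vjp(jnp.ones_like(y))  # cotangents for each primal
        return x_bar
    return grad_f

dsin = my_grad(jnp.sin)  # should agree with jax.grad(jnp.sin)
```

For non-scalar outputs, pulling back one basis cotangent per output component recovers the rows of the Jacobian, which is essentially how jax.jacrev is built on top of vjp.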