jax.custom_vjp.defvjp
- custom_vjp.defvjp(fwd, bwd, symbolic_zeros=False, optimize_remat=False)
Define a custom VJP rule for the function represented by this instance.
- Parameters:
  - fwd (Callable[..., tuple[ReturnValue, Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
  - bwd (Callable[..., tuple[Any, ...]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments: the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent, which has the same structure as the primal function output. The output of bwd must be a tuple whose length equals the number of arguments of the primal function; its elements may be arrays or nested tuples/lists/dicts thereof, matching the structure of the corresponding primal input arguments.
  - symbolic_zeros (bool) – boolean, determining whether to indicate symbolic zeros to the fwd and bwd rules. Enabling this option allows custom derivative rules to detect when certain inputs, and when certain output cotangents, are not involved in differentiation. If True:
    - fwd must accept, in place of each leaf value x in the pytree comprising an argument to the original function, an object (of type jax.custom_derivatives.CustomVJPPrimal) with two attributes: value and perturbed. The value field is the original primal argument, and perturbed is a boolean that indicates whether the argument is involved in differentiation (i.e., if it is False, then the corresponding Jacobian “column” is zero).
    - bwd will be passed objects representing static symbolic zeros in its cotangent argument in correspondence with unperturbed values; otherwise, only standard JAX types (e.g. array-likes) are passed.
    Setting this option to True allows these rules to detect whether certain inputs and outputs are not involved in differentiation, but at the cost of special handling. For instance:
    - The signature of fwd changes, and the objects it is passed cannot be output from the rule directly.
    - The bwd rule is passed objects that are not entirely array-like, and that cannot be passed to most jax.numpy functions.
    - Any custom pytree nodes involved in the primal function’s arguments must accept, in their unflattening functions, the two-field record objects that are given as input leaves to the fwd rule.
    Default False.
  - optimize_remat (bool) – boolean, an experimental flag to enable an automatic optimization when this function is used under jax.remat(). This will be most useful when the fwd rule is an opaque call such as a Pallas kernel or a custom call. Default False.
- Returns:
None.
- Return type:
None
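As an illustration of the structural contract described for fwd and bwd, the sketch below assumes a hypothetical function affine whose first argument is a dict. Its bwd rule returns a tuple with one cotangent per primal argument, and the first cotangent is a dict with the same keys as the corresponding input.

```python
import jax
import jax.numpy as jnp

# Hypothetical primal function: one dict argument and one array argument.
@jax.custom_vjp
def affine(params, x):
    return params["w"] * x + params["b"]

def affine_fwd(params, x):
    # Same input signature as `affine`; returns (primal output, residuals).
    return affine(params, x), (params["w"], x)

def affine_bwd(res, g):
    w, x = res
    # One cotangent per primal argument, in order; the first mirrors the dict.
    params_bar = {"w": g * x, "b": g}
    x_bar = g * w
    return (params_bar, x_bar)

affine.defvjp(affine_fwd, affine_bwd)

params = {"w": jnp.float32(3.0), "b": jnp.float32(1.0)}
x = jnp.float32(2.0)
params_grad, x_grad = jax.grad(affine, argnums=(0, 1))(params, x)
print(params_grad, x_grad)
```

Because bwd must mirror each argument's structure, returning a flat tuple of three scalars here would not match the primal inputs.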
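The parameter descriptions cover the case with no nondiff_argnums. For the other case, the following sketch assumes the convention documented for jax.custom_vjp with nondiff_argnums: fwd keeps the primal signature, bwd receives the non-differentiable arguments first, then the residuals and the output cotangent, and bwd returns cotangents only for the differentiable arguments. The names apply, apply_fwd, and apply_bwd are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical example: argument 0 (a Python callable) is non-differentiable.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def apply(fn, x):
    return fn(x)

def apply_fwd(fn, x):
    # fwd takes the same arguments as the primal function, in the same order.
    return fn(x), x  # residual: the input x

def apply_bwd(fn, x, g):
    # Non-differentiable args come first, then residuals, then the cotangent.
    _, fn_vjp = jax.vjp(fn, x)
    # fn_vjp returns a 1-tuple: a cotangent for the differentiable `x` only.
    return fn_vjp(g)

apply.defvjp(apply_fwd, apply_bwd)

grad_at_one = jax.grad(lambda x: apply(jnp.sin, x))(1.0)
print(grad_at_one)  # approximately cos(1.0)
```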
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.custom_vjp
... def f(x, y):
...   return jnp.sin(x) * y
...
>>> def f_fwd(x, y):
...   return f(x, y), (jnp.cos(x), jnp.sin(x), y)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, y = res
...   return (cos_x * g * y, sin_x * g)
...
>>> f.defvjp(f_fwd, f_bwd)

>>> x = jnp.float32(1.0)
>>> y = jnp.float32(2.0)
>>> with jnp.printoptions(precision=2):
...   print(jax.value_and_grad(f)(x, y))
(Array(1.68, dtype=float32), Array(1.08, dtype=float32))
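One way to sanity-check a hand-written bwd is to compare it with the gradient JAX derives for the same computation without a custom rule. The sketch below does this for f from the example above, with respect to both arguments; f_reference is a hypothetical helper introduced only for the comparison.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

def f_reference(x, y):
    # Same math without a custom rule; JAX derives its VJP automatically.
    return jnp.sin(x) * y

x, y = jnp.float32(1.0), jnp.float32(2.0)
custom = jax.grad(f, argnums=(0, 1))(x, y)
reference = jax.grad(f_reference, argnums=(0, 1))(x, y)
for c, r in zip(custom, reference):
    print(c, r)
```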
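A minimal sketch of symbolic_zeros=True, assuming the behavior described for that parameter: fwd receives CustomVJPPrimal wrappers and records their perturbed bits, and bwd checks whether its cotangent is a jax.custom_derivatives.SymbolicZero before doing arithmetic with it. The function mul and the seen_flags list are hypothetical, used only to make the flags observable.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

seen_flags = []  # hypothetical log of the `perturbed` bits fwd observes

@jax.custom_vjp
def mul(x, y):
    return x * y

def mul_fwd(x, y):
    # With symbolic_zeros=True, each leaf arrives as a CustomVJPPrimal.
    seen_flags.append((x.perturbed, y.perturbed))
    # Store plain values as residuals; the wrapper objects themselves
    # must not be returned from the rule.
    return mul(x.value, y.value), (x.value, y.value)

def mul_bwd(res, g):
    x, y = res
    if isinstance(g, SymbolicZero):
        # Static zero cotangent: avoid jnp arithmetic on it, return zeros.
        return (jnp.zeros_like(x), jnp.zeros_like(y))
    return (g * y, g * x)

mul.defvjp(mul_fwd, mul_bwd, symbolic_zeros=True)

x, y = jnp.float32(3.0), jnp.float32(2.0)
dx = jax.grad(mul)(x, y)  # differentiate with respect to x only
print(dx, seen_flags[-1])
```

Differentiating only with respect to x leaves y unperturbed, so the recorded flags are expected to be (True, False).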