jax.hessian
- jax.hessian(fun, argnums=0, has_aux=False, holomorphic=False)
Hessian of fun as a dense array.

- Parameters:
fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns:
A function with the same arguments as fun, that evaluates the Hessian of fun.
- Return type:
Callable
>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]
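The argnums and has_aux parameters described above can be combined in one call. The following is a minimal sketch, not taken from the original documentation: the function loss and the arrays w and x are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function (an assumption for illustration): a scalar
# loss of two array arguments that also returns auxiliary data.
def loss(w, x):
    residual = jnp.dot(w, x) - 1.0
    return residual ** 2, residual  # (output to differentiate, aux data)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# argnums=1 differentiates twice with respect to the second positional
# argument (x); has_aux=True tells hessian that loss returns a pair.
hess_x, aux = jax.hessian(loss, argnums=1, has_aux=True)(w, x)

print(hess_x.shape)  # (2, 2)
print(aux)           # the residual, passed through without differentiation
```

For this particular loss the Hessian with respect to x works out analytically to 2 * outer(w, w), which gives a quick way to check the result.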
hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                         [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                         [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                         [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}
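The tree product can be easier to see by looking at leaf shapes instead of values. The sketch below reuses f from the example above; the variable names x, hess, and shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
x = {"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}

hess = jax.hessian(f)(x)

# Replace every array leaf with its shape, keeping the nested dict structure:
# the output key "c" on the outside, then two levels of input keys.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, hess)
print(shapes)

# Each leaf combines the output leaf shape (2,) with two input leaf shapes
# (2,) and (2,), so every block has shape (2, 2, 2).
print(hess["c"]["a"]["b"].shape)
```

The mixed blocks hess["c"]["a"]["b"] and hess["c"]["b"]["a"] hold the same second derivatives with the last two axes swapped.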
Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree().
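The shape rule and the flattening suggestion can both be checked on small inputs. This is a sketch under stated assumptions: the functions h and energy and the params dictionary are hypothetical and not part of the original documentation. It relies on jax.flatten_util.ravel_pytree, which returns a flat 1D array together with a function that rebuilds the original pytree.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Array in, array out: output shape (3,) and input shape (2,) give (3, 2, 2).
h = lambda x: jnp.stack([x[0] * x[1], x[0] ** 2, jnp.sin(x[1])])
h_shape = jax.hessian(h)(jnp.array([1.0, 2.0])).shape
print(h_shape)  # (3, 2, 2)

# Hypothetical scalar function of a pytree of parameters.
def energy(params):
    return jnp.sum(params["a"] ** 2) + jnp.prod(params["b"]) * params["a"][0]

params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0, 4.0, 5.0])}

# Flatten the pytree to one vector and keep the function that undoes it.
flat_params, unravel = ravel_pytree(params)

# Differentiating with respect to the flat vector yields one dense matrix.
dense_hess = jax.hessian(lambda v: energy(unravel(v)))(flat_params)
print(dense_hess.shape)  # (5, 5)

# The pytree Hessian holds the same numbers, split into blocks by key.
block_hess = jax.hessian(energy)(params)
print(block_hess["a"]["b"].shape)  # (2, 3)
```

Working with the flat vector is convenient when a single matrix is needed, for example for an eigenvalue decomposition, while the pytree form keeps the per-parameter blocks separate.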