jax.lax.custom_root
- jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)
Differentiably solve for the roots of a function.
This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
- Parameters:
f (Callable) – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.
initial_guess (Any) – initial guess for a zero of f.
solve (Callable[[Callable, Any], Any]) –
function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) = 0. In other words, the following is assumed to be true (but not checked):
    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
tangent_solve (Callable[[Callable, Any], Any]) –
function to solve the tangent system. Should take two positional arguments: a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess. It should return a solution x such that g(x) = y:
  - For scalar y, use lambda g, y: y / g(1.0).
  - For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).
has_aux (bool) – whether the solve function returns auxiliary data, such as solver diagnostics, as a second output.
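The following is a minimal sketch of how f, solve and tangent_solve fit together. It computes sqrt(a) as the root of f(x) = x**2 - a, with a closed over so that gradients flow to it.

All helper names (bisection_solve, scalar_tangent_solve, vector_tangent_solve, sqrt_scalar, sqrt_vector) are hypothetical. Bisection is only an illustrative choice of solve. It evaluates f without differentiating it and assumes f acts elementwise with a sign change on [0, 100]. It uses initial_guess only for its shape and dtype.

The two tangent solvers follow the scalar and vector recipes listed for tangent_solve. jnp.linalg.solve stands in for np.linalg.solve so the code stays traceable.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bisection_solve(func, x0, low=0.0, high=100.0, num_iters=100):
    """Hypothetical `solve`: elementwise bisection on the bracket [low, high].

    Assumes func acts elementwise and is increasing on the bracket with a
    sign change, as func(x) = x**2 - a is for 0 < a < high**2.
    x0 only supplies the shape and dtype of the solution.
    """
    lo = jnp.full_like(x0, low)
    hi = jnp.full_like(x0, high)

    def body(_, state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        above = func(mid) > 0  # where True, the root lies in [lo, mid]
        return jnp.where(above, lo, mid), jnp.where(above, mid, hi)

    lo, hi = lax.fori_loop(0, num_iters, body, (lo, hi))
    return 0.5 * (lo + hi)


def scalar_tangent_solve(g, y):
    # g is f linearized at the root, so g(t) = g(1.0) * t for scalar t.
    return y / g(1.0)


def vector_tangent_solve(g, y):
    # Materialize the Jacobian of the linear map g (fine when y is small).
    return jnp.linalg.solve(jax.jacobian(g)(y), y)


def sqrt_scalar(a):
    f = lambda x: x ** 2 - a  # `a` is closed over; gradients flow to it
    return lax.custom_root(f, jnp.float32(1.0), bisection_solve,
                           scalar_tangent_solve)


def sqrt_vector(a):
    f = lambda x: x ** 2 - a  # elementwise, so bisection per entry is valid
    guess = jnp.ones(jnp.shape(a), dtype=jnp.float32)
    return lax.custom_root(f, guess, bisection_solve, vector_tangent_solve)


scalar_value = sqrt_scalar(2.0)
scalar_grad = jax.grad(sqrt_scalar)(2.0)  # analytically 1 / (2 * sqrt(2))

a = jnp.array([4.0, 9.0])
vector_value = sqrt_vector(a)
vector_grad = jax.grad(lambda a: sqrt_vector(a).sum())(a)  # 1 / (2 * sqrt(a))
```

In this sketch the gradient comes from the implicit function theorem, not from differentiating through the loop. As a result, the fori_loop inside bisection_solve never needs to be differentiable; jax.grad relies only on f and the tangent solver.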
- Returns:
The result of calling solve(f, initial_guess), with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
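For reference, the implicit-differentiation rule follows from differentiating the root condition. The notation here is introduced for illustration only:

- theta stands for the closed-over variables of f.
- x*(theta) stands for the root returned by solve.

The derivation assumes the partial derivative of f with respect to x is invertible at the root, as the implicit function theorem requires. Applying that inverse is the system tangent_solve is asked to solve, with g the linearization of f at the root.

```latex
% Illustrative notation: \theta = closed-over variables of f, x^*(\theta) = root.
\begin{aligned}
  & f\bigl(x^*(\theta), \theta\bigr) = 0 \quad \text{for all } \theta \text{ near } \theta_0 \\
  \Longrightarrow\quad & \partial_x f(x^*, \theta)\, \frac{\partial x^*}{\partial \theta}
      + \partial_\theta f(x^*, \theta) = 0 \\
  \Longrightarrow\quad & \frac{\partial x^*}{\partial \theta}
      = -\bigl[\partial_x f(x^*, \theta)\bigr]^{-1}\, \partial_\theta f(x^*, \theta) \\
  % Forward-mode (JVP) form for a tangent \dot{\theta}:
  & \dot{x}^* = -\bigl[\partial_x f(x^*, \theta)\bigr]^{-1}
      \bigl(\partial_\theta f(x^*, \theta)\, \dot{\theta}\bigr) \\
  % Example: f(x, a) = x^2 - a gives \partial_x f = 2x^*, \partial_a f = -1:
  & \frac{\partial x^*}{\partial a} = -\frac{1}{2x^*}\cdot(-1) = \frac{1}{2x^*}
\end{aligned}
```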