jax.experimental.sparse.grad
jax.experimental.sparse.grad(fun, argnums=0, has_aux=False, **kwargs)
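Sparse-aware version of jax.grad().

Arguments and return values are the same as jax.grad(), but when taking the gradient with respect to a jax.experimental.sparse array, the gradient is computed in the subspace defined by the array's sparsity pattern. That is, only the stored elements of the array are treated as variables, and the gradient is returned as a sparse array with the same sparsity pattern as the input.

Examples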
>>> import jax.numpy as jnp
>>> from jax.experimental import sparse
>>> X = sparse.BCOO.fromdense(jnp.arange(6.))
>>> y = jnp.ones(6)
>>> sparse.grad(lambda X, y: X @ y)(X, y)
BCOO(float32[6], nse=5)
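To make the "subspace" statement concrete, the following is a minimal sketch that compares sparse.grad with jax.grad applied to the densified matrix. The matrix M, the vector v, and the function f are illustrative names only. The sketch assumes that the sparse gradient keeps the input's stored indices, which is consistent with the nse shown in the example output above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative 2x3 matrix; fromdense stores only its 4 nonzero entries.
M = jnp.array([[1., 0., 2.],
               [0., 3., 4.]])
X = sparse.BCOO.fromdense(M)
v = jnp.array([1., 2., 3.])

def f(A):
    # Works for both a BCOO and a dense array: A @ v is a dense vector.
    return jnp.sum((A @ v) ** 2)

# Dense gradient: 2 * outer(M @ v, v), nonzero in every position here.
dense_g = jax.grad(f)(M)

# Sparse gradient: only entries at X's stored indices are computed.
sparse_g = sparse.grad(f)(X)

# Densifying the sparse gradient gives the dense gradient masked by the
# sparsity pattern of M; the unstored zeros receive no gradient.
masked = jnp.where(M != 0, dense_g, 0.)
print(jnp.allclose(sparse_g.todense(), masked))  # True
print(sparse_g.nse)  # 4, the same as X.nse
```

Note that the dense gradient is nonzero at positions (0, 1) and (1, 0), where M is zero; sparse.grad simply does not produce entries there, because those positions are not part of the array's stored elements.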
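Because the arguments mirror jax.grad(), argnums and has_aux can be used in the same way. The sketch below differentiates with respect to both a sparse and a dense argument and returns auxiliary data. The function loss and its inputs are hypothetical; the sketch assumes that gradients with respect to dense arguments come back as ordinary dense arrays while gradients with respect to sparse arguments come back as BCOO.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative sparse vector with two stored entries (positions 1 and 3).
X = sparse.BCOO.fromdense(jnp.array([0., 1., 0., 2.]))
y = jnp.array([1., 2., 3., 4.])

def loss(X, y):
    out = X @ y  # scalar: 1*2 + 2*4 = 10
    # The second element is auxiliary data, returned next to the gradients.
    return out, {"out": out}

# argnums=(0, 1) differentiates with respect to both arguments;
# has_aux=True makes the call return (gradients, aux), as with jax.grad.
(gX, gy), aux = sparse.grad(loss, argnums=(0, 1), has_aux=True)(X, y)

print(type(gX).__name__)  # BCOO: gradient w.r.t. the sparse argument
print(gX.todense())       # y at X's stored positions, zeros elsewhere
print(gy)                 # dense gradient w.r.t. y, equal to X.todense()
print(aux["out"])         # 10.0
```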