jax.experimental.sparse.JAXSparse
- class jax.experimental.sparse.JAXSparse(args, *, shape)
Base class for high-level JAX sparse objects.
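JAXSparse is a base class, so in practice you work with a concrete sparse format built on top of it. The sketch below assumes the BCOO (batched coordinate) format from the same module as that concrete subclass. It builds a BCOO matrix from a small dense array and checks that it is a JAXSparse instance. The matrix values are arbitrary illustration data.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Arbitrary example matrix with three nonzero entries.
dense = jnp.array([[0., 1., 0.],
                   [2., 0., 0.],
                   [0., 0., 3.]])

# BCOO is one concrete subclass of JAXSparse.
M = sparse.BCOO.fromdense(dense)

print(isinstance(M, sparse.JAXSparse))  # True
print(type(M).__name__)                 # BCOO

# Converting back to a dense array recovers the original values.
print(jnp.array_equal(M.todense(), dense))  # True
```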
Methods
- __init__(args, *, shape)
- block_until_ready()
- sum(*args, **kwargs)
- transpose([axes])
- tree_flatten()
- tree_unflatten(aux_data, children)

Attributes
- T
- ndim
- size
- data
- shape
- nse
- dtype
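The following minimal sketch shows several of the listed attributes and methods on a concrete object. It again assumes the BCOO subclass and JAX's default 32-bit floats; with `jax_enable_x64` turned on, the printed dtype would differ. The matrix, the vector, and the `matvec` helper are hypothetical illustration data and code, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Arbitrary 2x3 example matrix with two nonzero entries.
dense = jnp.array([[0., 1., 0.],
                   [2., 0., 0.]])
M = sparse.BCOO.fromdense(dense)

# Shape-related attributes.
print(M.shape)    # (2, 3)
print(M.ndim)     # 2
print(M.size)     # 6  (product of the shape)
print(M.nse)      # 2  (number of specified elements)
print(M.dtype)    # float32
# T is shorthand for transpose() with no axes argument.
print(M.T.shape)  # (3, 2)

# tree_flatten splits the object into array children plus static aux data;
# tree_unflatten rebuilds an equivalent object from those pieces.
children, aux_data = M.tree_flatten()
M2 = sparse.BCOO.tree_unflatten(aux_data, children)
print(jnp.array_equal(M2.todense(), dense))  # True

# Wait for the underlying device buffers to finish computing.
M.block_until_ready()

# Because these objects are pytrees, they can be passed to jit-compiled code.
@jax.jit
def matvec(mat, vec):
    # Hypothetical helper: sparse matrix times dense vector.
    return mat @ vec

v = jnp.array([1., 2., 3.])
print(matvec(M, v))  # [2. 2.]
```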