jax.ShapeDtypeStruct
- class jax.ShapeDtypeStruct(shape, dtype, *, sharding=None, weak_type=False, vma=None, is_ref=False)
A container for the shape, dtype, and other static attributes of an array.
ShapeDtypeStruct is often used in conjunction with jax.eval_shape().
- Parameters:
shape – a sequence of integers representing an array shape
dtype – a dtype-like object
sharding – (optional) a jax.Sharding object
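As a rough illustration of the pairing with jax.eval_shape() mentioned above, the sketch below builds two ShapeDtypeStruct placeholders and asks JAX for the output shape and dtype of a function without allocating or computing any real arrays. The function `predict` and the shapes chosen are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical function, used only to have something to trace.
    return jnp.tanh(x @ w)


# Placeholders that describe the inputs; no array data is allocated.
w_spec = jax.ShapeDtypeStruct((128, 10), jnp.float32)
x_spec = jax.ShapeDtypeStruct((32, 128), jnp.float32)

# eval_shape traces predict abstractly and returns a ShapeDtypeStruct
# describing the result.
out = jax.eval_shape(predict, w_spec, x_spec)
print(out.shape)  # (32, 10)
print(out.dtype)  # float32
```

Because only shapes and dtypes are propagated, this kind of check is cheap even when the described arrays would be too large to materialize.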
Methods
__init__(shape, dtype, *[, sharding, ...])
update(**kwargs)
Attributes
shape
dtype
sharding
weak_type
vma
is_ref
format
ndim
size