jax.Array.repeat
- abstract Array.repeat(repeats, axis=None, *, total_repeat_length=None, out_sharding=None)
Construct an array from repeated elements.
Refer to jax.numpy.repeat() for the full documentation.
- Parameters:
self (Array)
repeats (ArrayLike)
axis (int | None)
total_repeat_length (int | None)
out_sharding (NamedSharding | PartitionSpec | None)
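A minimal sketch of how the parameters above might be used, assuming the method mirrors jax.numpy.repeat(). The array values and the helper name `repeat_fixed` are illustrative only, and `out_sharding` is not exercised here.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])

# axis=None (the default): the array is flattened, then each element is repeated.
flat = x.repeat(2)

# axis=0: each row is repeated twice.
rows = x.repeat(2, axis=0)

# Per-element repeat counts along axis 1. total_repeat_length fixes the
# size of the output along that axis (here 1 + 2 = 3).
ragged = x.repeat(jnp.array([1, 2]), axis=1, total_repeat_length=3)

# Under jit, repeats is a traced value, so the output length cannot be
# inferred from it; total_repeat_length supplies a static size instead.
# `repeat_fixed` is a hypothetical helper name used only for this sketch.
@jax.jit
def repeat_fixed(values, counts):
    return values.repeat(counts, total_repeat_length=4)

jitted = repeat_fixed(jnp.array([1, 2, 3]), jnp.array([1, 2, 1]))

print(flat)    # [1 1 2 2 3 3 4 4]
print(rows)    # [[1 2] [1 2] [3 4] [3 4]]
print(ragged)  # [[1 2 2] [3 4 4]]
print(jitted)  # [1 2 2 3]
```

Passing `total_repeat_length` is mainly useful when the repeat counts are not known at trace time, since JAX needs static output shapes inside compiled functions.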
- Return type: