jax.numpy.arange
- jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)
Create an array of evenly-spaced values.
JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().
Similar to Python’s range() function, this can be called with a few different positional signatures:
- jnp.arange(stop): generate values from 0 to stop, stepping by 1.
- jnp.arange(start, stop): generate values from start to stop, stepping by 1.
- jnp.arange(start, stop, step): generate values from start to stop, stepping by step.
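A minimal sketch, assuming integer arguments so that results can be compared exactly with Python's built-in range(). It checks each positional signature, plus a negative step, which follows the same rule of an inclusive start and an exclusive stop:

```python
import jax.numpy as jnp

# Each positional signature of jnp.arange mirrors Python's built-in range():
# start is inclusive, stop is exclusive, and step defaults to 1.
cases = [
    (5,),         # arange(stop)
    (2, 7),       # arange(start, stop)
    (1, 10, 3),   # arange(start, stop, step)
    (5, 0, -2),   # a negative step counts down; stop is still excluded
]

for args in cases:
    got = jnp.arange(*args).tolist()   # convert to a Python list of ints
    expected = list(range(*args))
    print(args, got, expected)
    assert got == expected
```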
As with Python’s range() function, the start value is inclusive and the stop value is exclusive.
- Parameters:
start (ArrayLike | DimSize) – start of the interval, inclusive.
stop (ArrayLike | DimSize | None) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).
step (ArrayLike | None) – optional step size for the interval. Default = 1.
dtype (DTypeLike | None) – optional dtype for the returned array; if not specified it will be determined via type promotion of start, stop, and step.
device (xc.Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding | P | None) – (optional) NamedSharding or P to which the created array will be committed. Use the out_sharding argument when using explicit sharding (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).
- Returns:
Array of evenly-spaced values from start to stop, separated by step.
- Return type:
Array
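Assuming jnp.arange follows the same length rule as numpy.arange, the returned array holds ceil((stop - start) / step) values, or none when that quantity is zero or negative. expected_length below is a hypothetical helper written for this sketch, not part of JAX; the cases use steps that are exact in binary floating point so that rounding cannot shift the count.

```python
import math
import jax.numpy as jnp

def expected_length(start, stop, step=1):
    # Hypothetical helper: number of values in [start, stop) spaced by step,
    # clamped at zero for empty ranges.
    return max(0, math.ceil((stop - start) / step))

cases = [(0, 10, 3), (1, 6, 1), (0, 2, 0.5), (5, 0, -2), (3, 3, 1), (4, 1, 1)]
for start, stop, step in cases:
    n = jnp.arange(start, stop, step).shape[0]
    print((start, stop, step), n)
    assert n == expected_length(start, stop, step)
```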
Note
Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers and scaling it to the desired range. For example, instead of this:
jnp.arange(-1, 1, 0.01, dtype='bfloat16')
it can be more accurate to generate a sequence of integers and scale them:
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
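To make the difference measurable, the sketch below compares both approaches from the note against a float64 NumPy reference. Only the scaled-integer result is checked against a bound (half a bfloat16 ulp just below 1 is about 0.002); the error of the direct call is printed rather than asserted, since its exact values depend on how the floating-point step is accumulated.

```python
import numpy as np
import jax.numpy as jnp

# Float64 reference for the 200 values -1.00, -0.99, ..., 0.99.
ref = np.arange(-100, 100) * 0.01

# Integer range first, then scale and cast to bfloat16.
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')
scaled_err = np.max(np.abs(np.asarray(scaled.astype(jnp.float32), dtype=np.float64) - ref))

# Direct call with a floating-point step in bfloat16.
direct = jnp.arange(-1, 1, 0.01, dtype='bfloat16')
n = min(direct.shape[0], ref.shape[0])  # guard against a length mismatch
direct_err = np.max(np.abs(np.asarray(direct.astype(jnp.float32), dtype=np.float64)[:n] - ref[:n]))

print(f"scaled: {scaled.shape[0]} values, max abs error {scaled_err:.4f}")
print(f"direct: {direct.shape[0]} values, max abs error {direct_err:.4f}")
```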
Examples
Single-argument version specifies only the stop value:
>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)
Passing a floating-point stop value leads to a floating-point result:
>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)
Two-argument version specifies start and stop, with step=1:
>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)
Three-argument version specifies start, stop, and step:
>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
See also
- jax.numpy.linspace(): generate a fixed number of evenly-spaced values.
- jax.lax.iota(): directly generate integer sequences in XLA.
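A short sketch relating jnp.arange to the two functions listed above: jax.lax.iota(dtype, size) yields the same values as the single-argument form, while jnp.linspace fixes the number of samples and, by default, includes the endpoint that jnp.arange excludes. The step 0.25 is chosen because it is exact in binary floating point.

```python
import jax
import jax.numpy as jnp

# lax.iota(dtype, size) produces 0, 1, ..., size - 1, like jnp.arange(size).
a = jnp.arange(5)
b = jax.lax.iota(jnp.int32, 5)
print(a, b)

# linspace fixes the sample count (stop included by default);
# arange fixes the spacing (stop excluded).
lin = jnp.linspace(0.0, 1.0, 5)   # 0.0, 0.25, 0.5, 0.75, 1.0
ar = jnp.arange(0.0, 1.0, 0.25)   # 0.0, 0.25, 0.5, 0.75
print(lin)
print(ar)
```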