jax.numpy.arange
jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None, out_sharding=None)
Create an array of evenly-spaced values.

JAX implementation of numpy.arange(), implemented in terms of jax.lax.iota().

Similar to Python’s range() function, this can be called with a few different positional signatures:

- jnp.arange(stop): generate values from 0 to stop, stepping by 1.
- jnp.arange(start, stop): generate values from start to stop, stepping by 1.
- jnp.arange(start, stop, step): generate values from start to stop, stepping by step.
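A quick way to see the correspondence with range() is to compare each positional form against the matching range() call. This is a minimal sketch: the tuples in `cases` are arbitrary illustrative inputs, and the round trip through NumPy is only there to obtain plain Python lists for comparison.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative argument tuples for the stop, (start, stop) and
# (start, stop, step) forms.
cases = [(5,), (2, 7), (1, 10, 3)]

for args in cases:
    # Convert the JAX array to a NumPy array, then to a Python list.
    from_jax = np.asarray(jnp.arange(*args)).tolist()
    from_range = list(range(*args))
    print(args, from_jax, from_range)
```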
As with Python’s range() function, the start value is inclusive and the stop value is exclusive.

- Parameters:
start (ArrayLike | DimSize) – start of the interval, inclusive.
stop (ArrayLike | DimSize | None) – optional end of the interval, exclusive. If not specified, then (start, stop) = (0, start).
step (ArrayLike | None) – optional step size for the interval. Default = 1.
dtype (DTypeLike | None) – optional dtype for the returned array; if not specified, it is determined via type promotion of start, stop, and step.
device (xc.Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed.
out_sharding (NamedSharding | P | None) – (optional) NamedSharding or P to which the created array will be committed. Use the out_sharding argument if using explicit sharding (https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html).
- Returns:
Array of evenly-spaced values from start to stop, separated by step.
- Return type:
Array
Note

Using arange with a floating-point step argument can lead to unexpected results due to accumulation of floating-point errors, especially with lower-precision data types like float8_* and bfloat16. To avoid precision errors, consider generating a range of integers and scaling it to the desired range. For example, instead of this:

jnp.arange(-1, 1, 0.01, dtype='bfloat16')

it can be more accurate to generate a sequence of integers and scale them:

(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
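To see the difference on a given installation, both variants can be compared against a float64 reference of the intended values. This is a minimal sketch; `max_abs_error` is a hypothetical helper defined here for illustration, and it truncates to the shorter length in case the float-step result does not come out with exactly 200 elements. No specific error values are claimed, since they depend on how the low-precision step is rounded.

```python
import numpy as np
import jax.numpy as jnp

# Float-step version: the step is applied directly in low precision.
stepped = jnp.arange(-1, 1, 0.01, dtype='bfloat16')

# Integer-then-scale version: exact integers, one multiply, one final cast.
scaled = (jnp.arange(-100, 100) * 0.01).astype('bfloat16')

# Float64 reference for the intended values -1.00, -0.99, ..., 0.99.
reference = np.arange(-100, 100) * 0.01

def max_abs_error(x):
    """Hypothetical helper: max |x - reference| over the overlapping prefix."""
    x64 = np.asarray(x).astype(np.float64)
    n = min(len(x64), len(reference))  # guard in case the lengths differ
    if n == 0:
        return float('inf')
    return float(np.max(np.abs(x64[:n] - reference[:n])))

print("float step: ", stepped.shape, max_abs_error(stepped))
print("int + scale:", scaled.shape, max_abs_error(scaled))
```

The scaled version's error is bounded by a single bfloat16 rounding of each value, which is why it stays close to the reference.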
Examples

Single-argument version specifies only the stop value:

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

Passing a floating-point stop value leads to a floating-point result:

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

Two-argument version specifies start and stop, with step=1:

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

Three-argument version specifies start, stop, and step:

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)
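The examples above all count upward. The sketch below, using illustrative inputs, shows three related cases that follow numpy.arange() semantics: a negative step counting down, a stop that cannot be reached in the step direction (giving an empty array), and a non-integer stop, which is still excluded.

```python
import numpy as np
import jax.numpy as jnp

# A negative step counts down; stop is still excluded.
down = jnp.arange(5, 0, -1)

# If stop is not reachable from start in the step direction, the result is empty.
empty = jnp.arange(5, 0)

# A non-integer stop is excluded too: values run up to the last whole step below 4.5.
frac = jnp.arange(4.5)

print(down, empty.shape, frac)
```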
See also

- jax.numpy.linspace(): generate a fixed number of evenly-spaced values.
- jax.lax.iota(): directly generate integer sequences in XLA.
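The practical difference between these functions is what you fix: arange fixes the step and excludes stop, linspace fixes the number of values and includes the endpoint by default, and lax.iota takes a dtype and a size. A minimal sketch with illustrative inputs:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# arange: fixed step, stop excluded -> 4 values.
by_step = jnp.arange(0.0, 1.0, 0.25)

# linspace: fixed count, endpoint included by default -> 5 values.
by_count = jnp.linspace(0.0, 1.0, 5)

# lax.iota: integer sequence 0..size-1 of the given dtype.
ids = lax.iota(jnp.int32, 4)

print(by_step, by_count, ids)
```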