jax.numpy.linspace
- jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)
Return evenly-spaced numbers within an interval.
JAX implementation of numpy.linspace().
- Parameters:
start (Array | ndarray | bool | number | int | float | complex) – scalar or array of starting values.
stop (Array | ndarray | bool | number | int | float | complex) – scalar or array of stop values.
num (int) – number of values to generate. Default: 50.
endpoint (bool) – if True (default), include the stop value in the result. If False, exclude the stop value.
retstep (bool) – if True, return a (result, step) tuple, where step is the interval between adjacent values in result.
dtype – optional dtype of the output array. If omitted, an inexact (floating-point or complex) dtype is inferred from start and stop.
axis (int) – integer axis along which to generate the linspace. Defaults to zero.
device (Device | Sharding | None) – optional Device or Sharding to which the created array will be committed.
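A quick way to see how endpoint and retstep interact is to compare the returned step with the expected spacing. This is a minimal sketch that assumes only jax.numpy is available: with endpoint=True the interval is split into num - 1 gaps, and with endpoint=False into num gaps.

```python
import jax.numpy as jnp

start, stop, num = 0.0, 10.0, 5

# endpoint=True (default): num values with num - 1 gaps; the last value is stop.
vals, step = jnp.linspace(start, stop, num, retstep=True)
assert jnp.allclose(step, (stop - start) / (num - 1))
assert float(vals[-1]) == stop

# endpoint=False: stop is excluded, so the interval is split into num gaps.
vals_open, step_open = jnp.linspace(start, stop, num, endpoint=False, retstep=True)
assert jnp.allclose(step_open, (stop - start) / num)
assert float(vals_open[-1]) < stop
```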
- Returns:
An array values, or a tuple (values, step) if retstep is True, where values is an array of evenly-spaced values from start to stop, and step is the interval between adjacent values.
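To make the meaning of values and step concrete, here is a hypothetical scalar-only helper, linspace_ref, that builds the sequence as start + i * step. It is an illustration under that assumption, not how JAX computes the result internally, so it is compared with jnp.linspace using a tolerance rather than exact equality.

```python
import jax.numpy as jnp

def linspace_ref(start, stop, num, endpoint=True):
    """Hypothetical scalar-only reference returning (values, step).

    Assumes num >= 2 when endpoint=True (otherwise there are zero gaps).
    """
    # Number of gaps between consecutive values.
    div = num - 1 if endpoint else num
    step = (stop - start) / div
    values = start + jnp.arange(num, dtype=jnp.float32) * step
    return values, step

ref_vals, ref_step = linspace_ref(0.0, 10.0, 9)
jax_vals, jax_step = jnp.linspace(0.0, 10.0, 9, retstep=True)
assert jnp.allclose(ref_vals, jax_vals)
assert jnp.allclose(ref_step, jax_step)
```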
See also
jax.numpy.arange(): Generate N evenly-spaced values given a starting point and a step.
jax.numpy.logspace(): Generate logarithmically-spaced values.
jax.numpy.geomspace(): Generate geometrically-spaced values.
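The related functions can be read in terms of linspace. In this sketch, logspace with its default base of 10 is compared with 10 raised to a linspace of exponents, geomspace takes the endpoint values rather than their exponents, and arange with an explicit step matches linspace with endpoint=False when the step divides the interval evenly. Results are compared with jnp.allclose because the two routes need not round identically.

```python
import jax.numpy as jnp

# logspace(a, b, n) spaces the exponents evenly: 10 ** linspace(a, b, n).
exps = jnp.linspace(0.0, 3.0, 4)
assert jnp.allclose(jnp.logspace(0.0, 3.0, 4), 10.0 ** exps)  # 1, 10, 100, 1000

# geomspace takes the endpoints themselves rather than their exponents.
assert jnp.allclose(jnp.geomspace(1.0, 1000.0, 4), jnp.logspace(0.0, 3.0, 4))

# arange is parameterized by a step, linspace by a count of values.
assert jnp.allclose(jnp.arange(0.0, 10.0, 2.5),
                    jnp.linspace(0.0, 10.0, 4, endpoint=False))
```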
Examples
List of 5 values between 0 and 10:
>>> import jax.numpy as jnp
>>> jnp.linspace(0, 10, 5)
Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
List of 8 values between 0 and 10, excluding the endpoint:
>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32)
List of 9 values between 0 and 10 and the step size between them:
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32)
>>> step
Array(1.25, dtype=float32)
Multi-dimensional linspace:
>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0. , 5. ],
       [ 1.25, 6.25],
       [ 2.5 , 7.5 ],
       [ 3.75, 8.75],
       [ 5. , 10. ]], dtype=float32)
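The axis parameter controls where the new sample dimension is placed when start and stop are arrays. Reusing the inputs from the multi-dimensional example, this sketch shows that axis=-1 (equivalently axis=1 here) puts the samples along the last axis, giving the transpose of the default axis=0 result.

```python
import jax.numpy as jnp

start = jnp.array([0, 5])
stop = jnp.array([5, 10])

rows = jnp.linspace(start, stop, 5)           # axis=0: samples along the first axis
cols = jnp.linspace(start, stop, 5, axis=-1)  # samples along the last axis

print(rows.shape, cols.shape)  # (5, 2) (2, 5)
assert jnp.allclose(rows, cols.T)
```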