jax.numpy.isscalar
- jax.numpy.isscalar(element)
Return True if the input is a scalar.
JAX implementation of numpy.isscalar(). JAX’s implementation differs from NumPy’s in that it considers zero-dimensional arrays to be scalars; see the Note below for more details.
- Parameters:
element (Any) – input object to check; any type is valid input.
- Returns:
True if element is a scalar value or an array-like object with zero dimensions, False otherwise.
- Return type:
bool
Note
JAX and NumPy differ in their representation of scalar values. NumPy has special scalar objects (e.g. np.int32(0)) which are distinct from zero-dimensional arrays (e.g. np.array(0)), and numpy.isscalar() returns True for the former and False for the latter.
JAX does not define special scalar objects, but rather represents scalars as zero-dimensional arrays. As such, jax.numpy.isscalar() returns True for both scalar objects (e.g. 0.0 or np.float32(0.0)) and array-like objects with zero dimensions (e.g. jnp.array(0.0), np.array(0.0)).
One reason for the different conventions in isscalar is to maintain JIT-invariance, i.e. the property that the result of a function should not change when it is JIT-compiled. Because scalar inputs are cast to zero-dimensional JAX arrays at JIT boundaries, the semantics of numpy.isscalar() are such that the result changes under JIT:
>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
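To see why numpy.isscalar() changes its answer, one can record what a jitted function actually receives when called with a Python float. The sketch below uses a hypothetical function named record that stores a few attributes of its argument while JAX traces it; it relies only on the standard jax.jit tracing behavior.

```python
import jax
import numpy as np

seen = {}


def record(x):
    # This body runs at trace time, so x is a zero-dimensional JAX tracer,
    # not the Python float 1.0 that was passed in.
    seen["is_python_float"] = isinstance(x, float)
    seen["ndim"] = x.ndim
    seen["np_isscalar"] = np.isscalar(x)
    return x


jax.jit(record)(1.0)
print(seen)  # {'is_python_float': False, 'ndim': 0, 'np_isscalar': False}
```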
By treating zero-dimensional arrays as scalars, jax.numpy.isscalar() avoids this issue:
>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
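JIT-invariance can be checked directly by comparing an eager call with a jitted call on the same input. The helper is_jit_invariant below is hypothetical and only meant as a sketch; it converts both results with bool() because jax.jit returns a JAX array rather than a Python bool.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jit_invariant(f, x):
    """Hypothetical helper: does f(x) agree with jax.jit(f)(x)?"""
    eager = bool(f(x))            # plain Python call
    jitted = bool(jax.jit(f)(x))  # traced call; x becomes a tracer
    return eager == jitted


print(is_jit_invariant(np.isscalar, 1.0))   # False
print(is_jit_invariant(jnp.isscalar, 1.0))  # True
```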
Examples
In JAX, both scalars and zero-dimensional array-like objects are considered scalars:
>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True
Arrays with one or more dimensions are not considered scalars:
>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
Compare this to numpy.isscalar(), which returns True for scalar-typed objects and False for all arrays, even those with zero dimensions:
>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False
In JAX, as in NumPy, objects which are not array-like are not considered scalars:
>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False
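Because the answer does not change under JIT, jax.numpy.isscalar() can serve as an argument check that behaves the same in eager and jitted code. The function scale below is a hypothetical example, not part of JAX; under jax.jit its factor argument arrives as a zero-dimensional tracer, which still passes the check.

```python
import jax
import jax.numpy as jnp


def scale(x, factor):
    """Hypothetical function that only accepts a scalar `factor`."""
    # A zero-dimensional tracer (under jit) and a Python float (eagerly)
    # both pass this check; lists and arrays with ndim >= 1 do not.
    if not jnp.isscalar(factor):
        raise TypeError(f"factor must be a scalar, got {type(factor).__name__}")
    return x * factor


x = jnp.arange(3.0)
print(scale(x, 2.0))           # [0. 2. 4.]
print(jax.jit(scale)(x, 2.0))  # [0. 2. 4.]
try:
    scale(x, [2.0])            # a list is not treated as a scalar
except TypeError as err:
    print(err)                 # factor must be a scalar, got list
```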