jax.numpy.ones

jax.numpy.ones(shape, dtype=None, *, device=None, out_sharding=None)

Create an array full of ones.

JAX implementation of numpy.ones().
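Because this function mirrors numpy.ones(), comparing the two directly makes the relationship concrete. The sketch below assumes the default configuration with the X64 flag disabled. Under that setting both libraries produce the same values, but JAX defaults to float32 where NumPy defaults to float64.

```python
import numpy as np
import jax.numpy as jnp

# The same shape argument is accepted by both libraries.
np_arr = np.ones((2, 3))
jnp_arr = jnp.ones((2, 3))

# The values agree element-wise...
print(np.array_equal(np_arr, np.asarray(jnp_arr)))  # True

# ...but the default float dtypes differ (assumes X64 is disabled, the default).
print(np_arr.dtype, jnp_arr.dtype)  # float64 float32
```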

Parameters:
  • shape (Any) – int or sequence of ints specifying the shape of the created array.

  • dtype (str | type[Any] | dtype | SupportsDType | None) – optional dtype for the created array; defaults to float32, or to float64 when the X64 flag is enabled (see Default dtypes and the X64 flag).

  • device (Device | Sharding | None) – (optional) Device or Sharding to which the created array will be committed. This argument exists for compatibility with the Python Array API standard.

  • out_sharding (NamedSharding | PartitionSpec | None) – (optional) PartitionSpec or NamedSharding representing the sharding of the created array (see explicit sharding for more details). This argument exists for consistency with other array creation routines across JAX. Specifying both out_sharding and device will result in an error.
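The parameters above can be exercised together in a short sketch. It uses whichever device `jax.devices()[0]` reports as the target for `device`, which is an illustrative choice rather than a requirement. `out_sharding` is not shown because it needs a device mesh configured for explicit sharding.

```python
import jax
import jax.numpy as jnp

# shape: an int gives a 1-D array; a sequence of ints gives an N-D array.
a = jnp.ones(3)
b = jnp.ones((2, 4))
print(a.shape, b.shape)  # (3,) (2, 4)

# dtype: omit it for the default float dtype, or pass one explicitly.
c = jnp.ones((2,), dtype=jnp.int32)
print(c.dtype)  # int32

# device: commit the result to a specific device. Here that is simply the
# first device JAX reports on the current backend.
dev = jax.devices()[0]
d = jnp.ones(2, device=dev)
print(d.devices() == {dev})  # True
```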

Returns:

Array of the specified shape and dtype, with the given device/sharding if specified.

Return type:

Array
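A quick check of the returned value can confirm the points above: the result is a `jax.Array` whose shape and dtype match the request, and every element equals one. The shape `(2, 3)` and dtype `int8` below are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

out = jnp.ones((2, 3), dtype=jnp.int8)

# The result is a jax.Array with the requested shape and dtype.
print(isinstance(out, jax.Array))  # True
print(out.shape, out.dtype)        # (2, 3) int8

# Every element is one, so the sum equals the number of elements.
print(int(out.sum()))              # 6
```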

Examples

>>> import jax.numpy as jnp
>>> jnp.ones(4)
Array([1., 1., 1., 1.], dtype=float32)
>>> jnp.ones((2, 3), dtype=bool)
Array([[ True,  True,  True],
       [ True,  True,  True]], dtype=bool)