jax.nn.initializers.ones
jax.nn.initializers.ones(key, shape, dtype=<class 'numpy.float64'>, out_sharding=None)
An initializer that returns a constant array full of ones.
The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
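Because the key is ignored, two different PRNG keys should produce identical arrays. The sketch below checks this and shows one typical use, filling a bias vector for a layer with 4 output units. It assumes a recent JAX release where jax.random.key is available; the names a, b, and bias are illustrative only and not part of the API.

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.ones

# Two different PRNG keys: ones() ignores the key, so the results match.
a = init(jax.random.key(0), (2, 3), jnp.float32)
b = init(jax.random.key(1), (2, 3), jnp.float32)
print(jnp.array_equal(a, b))  # True

# Illustrative use: a bias vector for a layer with 4 output units.
bias = init(jax.random.key(0), (4,), jnp.float32)
print(bias.shape, bias.dtype)  # (4,) float32
```

Passing dtype explicitly, as above, keeps the result independent of whether 64-bit types are enabled.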