jax.numpy.fromfunction
- jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)
Create an array from a function applied over indices.
JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(), so unlike in NumPy the function logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
- Parameters:
function (Callable[..., Array]) – a function that takes N dynamic scalars and outputs a scalar.
shape (Any) – a length-N tuple of integers specifying the output shape.
dtype (DTypeLike) – optionally specify the dtype of the index inputs passed to function. Defaults to floating-point.
kwargs – additional keyword arguments are passed statically to function.
- Returns:
An array of shape shape if function returns a scalar, or in general a pytree of arrays with leading dimensions shape, as determined by the output of function.
- Return type:
See also
jax.vmap(): the core transformation that the fromfunction() API is built on.
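As a rough illustration of how fromfunction() can be built on jax.vmap(), the sketch below wraps a scalar function in one vmap per axis and calls the result on 1-D index ranges. fromfunction_sketch is a hypothetical helper written for illustration; it is not part of JAX, is not claimed to match JAX's actual source, ignores **kwargs, and assumes float32 as the default index dtype (JAX's usual float with 64-bit mode disabled).

```python
import jax
import jax.numpy as jnp


def fromfunction_sketch(function, shape, dtype=jnp.float32):
    """Hypothetical re-implementation of jnp.fromfunction via nested jax.vmap.

    Assumption: reproduces the documented behavior, not JAX's exact source.
    """
    ndim = len(shape)
    # Wrap once per axis, last axis innermost. Each wrap maps argument
    # `axis` over its leading dimension and holds the others fixed (None).
    for axis in reversed(range(ndim)):
        in_axes = tuple(0 if k == axis else None for k in range(ndim))
        function = jax.vmap(function, in_axes=in_axes)
    # Each argument is the 1-D range of indices along its own axis.
    return function(*(jnp.arange(n, dtype=dtype) for n in shape))


table = fromfunction_sketch(jnp.multiply, (3, 6), dtype=jnp.int32)
digits = fromfunction_sketch(lambda i, j, k: 100 * i + 10 * j + k, (2, 3, 4))
```

Wrapping the last axis innermost is what makes out[i, j, ...] equal function(i, j, ...); wrapping in the opposite order would transpose the result.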
Examples
Generate a multiplication table of a given shape:
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)
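The multiplication table above can be cross-checked against explicit broadcasting of two index ranges. The sketch also shows the role of dtype: it sets the type of the index arguments, so omitting it yields a floating-point table.

```python
import jax.numpy as jnp

# Same multiplication table as above, built with fromfunction.
table = jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)

# Equivalent construction with explicit broadcasting: a (3, 1) column of row
# indices times a (6,) row of column indices broadcasts to shape (3, 6).
rows = jnp.arange(3)[:, None]
cols = jnp.arange(6)
broadcast_table = rows * cols

# With the default dtype, the index arguments (and this result) are floats.
float_table = jnp.fromfunction(jnp.multiply, shape=(3, 6))
```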
When function returns a non-scalar, the output will have leading dimensions given by shape:
>>> def f(x):
...     return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)
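To see how the grid shape and the per-call output shape combine, here is a variant with a two-dimensional shape. outer_row is an illustrative name, not part of any API.

```python
import jax.numpy as jnp


def outer_row(i, j):
    # i and j are scalars; each call returns a length-3 vector.
    return (i + j) * jnp.arange(3)


result = jnp.fromfunction(outer_row, shape=(2, 4))
# The grid dimensions (2, 4) come first, followed by the per-call shape (3,),
# so result.shape == (2, 4, 3) and result[i, j] == outer_row(i, j).
```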
function may return multiple results, in which case each is mapped independently:
>>> def f(x, y):
...     return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]
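Multiple results are one case of the pytree outputs mentioned under Returns. Because jax.vmap() maps over arbitrary pytree outputs, a function returning a dict should behave the same way, with every leaf gaining leading dimensions shape. The function stats and its keys are illustrative only.

```python
import jax.numpy as jnp


def stats(i, j):
    # Returns a dict (a pytree); each leaf is mapped over the grid independently.
    return {"sum": i + j, "pair": jnp.stack([i, j])}


out = jnp.fromfunction(stats, shape=(3, 5))
# out["sum"] has shape (3, 5); out["pair"] has shape (3, 5, 2),
# because each call produces a length-2 vector.
```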
The JAX implementation differs slightly from NumPy's implementation. In numpy.fromfunction(), the function is expected to explicitly operate element-wise on the full grid of input values:
>>> def f(x, y):
...     print(f"{x.shape = }\n{y.shape = }")
...     return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])
In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:
>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
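The difference becomes visible as soon as the function builds a non-scalar value. In the sketch below, the same function (pair, an illustrative name) is passed to both implementations: NumPy calls it once on two full (2, 3) index grids, so the stacked axis lands first, while JAX maps it over scalars and appends the per-call axis last.

```python
import numpy as np
import jax.numpy as jnp


def pair(x, y):
    # Packs the two index arguments into a new leading axis of length 2.
    return jnp.stack([x, y])


# JAX: pair sees scalars, so each call returns shape (2,) and the grid
# dimensions (2, 3) are prepended: the result has shape (2, 3, 2).
jax_out = jnp.fromfunction(pair, (2, 3))

# NumPy: pair is called once with two (2, 3) index grids, so stacking
# adds a new leading axis: the result has shape (2, 2, 3).
np_out = np.asarray(np.fromfunction(pair, (2, 3)))

# Moving NumPy's stacked axis to the end recovers the JAX layout.
same_layout = np.array_equal(np.asarray(jax_out), np.moveaxis(np_out, 0, -1))
```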