jax.numpy.load
- jax.numpy.load(file, *args, **kwargs)
Load JAX arrays from npy files.
JAX wrapper of numpy.load().

This function is a simple wrapper of numpy.load(), but in the case of .npy files created with numpy.save() or jax.numpy.save(), the output will be returned as a jax.Array, and bfloat16 data types will be restored. For .npz files, results will be returned as normal NumPy arrays.

This function requires concrete array inputs, and is not compatible with transformations like jax.jit() or jax.vmap().

- Parameters:
file (IO[bytes] | str | os.PathLike[Any]) – string, bytes, or path-like object containing the array data.
args (Any) – for additional arguments, see numpy.load()
kwargs (Any) – for additional arguments, see numpy.load()
- Returns:
the array stored in the file.
- Return type:
Array
See also
jax.numpy.save(): save an array to a file.
Examples
>>> import io
>>> import jax.numpy as jnp
>>> f = io.BytesIO()  # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
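The description above distinguishes .npy files (returned as a jax.Array, with bfloat16 restored) from .npz files (returned as plain NumPy data), and notes that loading is not compatible with jax.jit(). The following is a minimal sketch of both points, assuming a writable temporary directory; the file names and the `scale` helper are hypothetical and only for illustration. The load is done eagerly and only the resulting concrete array is passed into the jitted function.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def scale(v):
    # Hypothetical helper: the jitted function receives an already-loaded
    # concrete array, never a file or path.
    return 2 * v


with tempfile.TemporaryDirectory() as tmpdir:
    npy_path = os.path.join(tmpdir, "weights.npy")  # hypothetical file name
    npz_path = os.path.join(tmpdir, "bundle.npz")   # hypothetical file name

    # .npy written by jnp.save: jnp.load returns a jax.Array and the
    # bfloat16 dtype is restored.
    jnp.save(npy_path, jnp.array([1.0, 2.0, 3.0], dtype=jnp.bfloat16))
    weights = jnp.load(npy_path)

    # .npz archive: jnp.load hands back NumPy's archive object, and each
    # entry read from it is an ordinary NumPy array.
    np.savez(npz_path, a=np.arange(3), b=np.ones(2))
    with jnp.load(npz_path) as archive:
        a = archive["a"]
        b = archive["b"]

    # File I/O happens eagerly, outside the transformation; only the
    # loaded array is passed to the jitted function.
    doubled = scale(weights)

print(type(weights), weights.dtype)
print(type(a), a.dtype)
print(doubled)
```

If you need JAX arrays from an .npz archive, one option is to convert each entry explicitly, for example with jnp.asarray(archive["a"]).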