jax.dlpack.from_dlpack
- jax.dlpack.from_dlpack(external_array, device=None, copy=None)
Returns an Array representation of a DLPack tensor.
The returned Array shares memory with external_array if no device transfer or copy was requested.
- Parameters:
external_array – An array object that has __dlpack__ and __dlpack_device__ methods.
device (xla_client.Device | Sharding | None) – The (optional) Device, representing the device on which the returned array should be placed. If given, the result is committed to that device. If unspecified, the resulting array is unpacked onto the same device it originated from. Setting device to a device different from the source of external_array requires a copy, meaning copy must be set to either True or None.
copy (bool | None) – An (optional) boolean controlling whether a copy is performed. If copy=True, a copy is always performed, even when the array is unpacked onto the same device. If copy=False, a copy is never performed, and an error is raised if one would be required. If copy=None, a copy may be performed if it is needed for a device transfer.
- Returns:
A jax.Array
Note
While JAX arrays are always immutable, DLPack buffers cannot be marked as immutable, and processes external to JAX can mutate them in place. If a JAX Array is constructed from a DLPack buffer and the buffer is later modified in place, using the associated JAX array may lead to undefined behavior.
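Example

The following is a minimal sketch of how the device and copy parameters described above might be used. It assumes a JAX version that provides the signature shown on this page and a NumPy version whose arrays expose __dlpack__ and __dlpack_device__. The names host_array, shared, copied, and on_cpu are illustrative only.

```python
import numpy as np
import jax
import jax.dlpack

# Any object with __dlpack__ and __dlpack_device__ can be the source;
# a NumPy array is used here purely for illustration.
host_array = np.arange(4, dtype=np.float32)

# Default (copy=None, device=None): no copy or transfer is requested,
# so the result may share memory with host_array.
shared = jax.dlpack.from_dlpack(host_array)

# copy=True: a copy is always made, so the result does not alias host_array.
copied = jax.dlpack.from_dlpack(host_array, copy=True)

# device=...: place the result on an explicit device (here the first CPU device).
cpu = jax.devices("cpu")[0]
on_cpu = jax.dlpack.from_dlpack(host_array, device=cpu)

print(type(shared).__name__, shared.shape, shared.dtype)
```

Because the default call may alias the source buffer, passing copy=True is the conservative choice when host_array might be modified later.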