jax.device_put
- jax.device_put(x, device=None, *, src=None, donate=False, may_alias=None)
Transfers x to device.
- Parameters:
x – An array, scalar, or (nested) standard Python container thereof.
device (None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind) – The (optional) Device, Sharding, or a (nested) Sharding in a standard Python container (must be a tree prefix of x), representing the device(s) to which x should be transferred. If given, then the result is committed to the device(s).
src (None | xc.Device | Sharding | P | Format | Any | TransferToMemoryKind) – The (optional) Device, Sharding, or a (nested) Sharding in a standard Python container (must be a tree prefix of x), representing the device(s) on which x resides.
donate (bool | Any) – bool or a (nested) bool in a standard Python container (must be a tree prefix of x). If True, x can be overwritten and marked deleted in the caller. This is best effort: JAX will donate if possible, otherwise it won’t. In the future, the input buffer will always be deleted if donated.
may_alias (bool | None | Any) – bool, None, or a (nested) bool in a standard Python container (must be a tree prefix of x). If False, x will be copied. If True, x may be aliased depending on the runtime’s implementation.
- Returns:
A copy of x that resides on device.
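As a minimal sketch of the device argument described above, the snippet below places a small nested container first on a single Device and then on a Sharding. The container params, the mesh axis name "data", and the choice to replicate over all local devices are illustrative assumptions, not part of the API description.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
dev = devices[0]  # first available device (CPU if no accelerator is present)

# Hypothetical nested container of host arrays.
params = {
    "w": np.ones((4, 3), dtype=np.float32),
    "b": np.zeros(3, dtype=np.float32),
}

# A single Device is a tree prefix of `params`, so it applies to every leaf.
on_device = jax.device_put(params, dev)

# A Sharding is accepted in the same position. Here a 1-D mesh over all
# devices is used with an empty PartitionSpec, which replicates each leaf
# instead of splitting it along any axis.
mesh = Mesh(np.array(devices), ("data",))
replicated = NamedSharding(mesh, PartitionSpec())
on_mesh = jax.device_put(params, replicated)

print(on_device["w"].devices())
print(on_mesh["w"].sharding)
```

The container structure is preserved in both results; only the placement of the leaves changes.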
If the device parameter is None, then this operation behaves like the identity function if the operand is on any device already, otherwise it transfers the data to the default device, uncommitted.
For more details on data placement see the FAQ on data placement.
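The three placement cases described here can be sketched as follows. The array host is an illustrative assumption, and the comments restate the documented behavior rather than inspecting the committed state directly.

```python
import jax
import numpy as np

host = np.arange(4, dtype=np.float32)  # data that starts on the host

# No device given and the operand is not on a device yet:
# the data is transferred to the default device, uncommitted.
uncommitted = jax.device_put(host)

# Device given: the result is committed to that device.
dev = jax.devices()[0]
committed = jax.device_put(host, dev)

# No device given and the operand already lives on a device:
# the call behaves like the identity function.
same = jax.device_put(committed)

print(uncommitted.devices(), committed.devices(), same.devices())
```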
This function is always asynchronous, i.e. returns immediately without blocking the calling Python thread until any transfers are completed.
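Because the call returns before the transfer has necessarily finished, code that measures or depends on transfer completion usually waits explicitly. A minimal sketch, assuming a 1024 x 1024 array of ones as example data:

```python
import jax
import numpy as np

host = np.ones((1024, 1024), dtype=np.float32)

# Returns right away; the transfer may still be in flight.
on_device = jax.device_put(host)

# Wait for the transfer only when completion actually matters,
# for example before stopping a timer.
on_device = on_device.block_until_ready()

# Converting to a Python float also waits for the value to be available.
total = float(on_device.sum())
print(total)
```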