jax.devices
jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of
Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().
If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
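The following is a minimal usage sketch, assuming a recent JAX release in which jax.process_index() and jax.device_count() are available. It lists the devices of the default backend, checks the documented length relationship, filters local devices by process_index, and queries the 'cpu' backend explicitly, which is available on any machine.

```python
import jax

# All devices on the default backend ('gpu' or 'tpu' if available, else 'cpu').
devices = jax.devices()

# The returned list has the same length as device_count() for that backend.
assert len(devices) == jax.device_count()

# Local devices are those whose process_index matches the current process.
me = jax.process_index()
local = [d for d in devices if d.process_index == me]

# Devices from an explicitly named backend instead of the default one.
cpu_devices = jax.devices("cpu")

for d in devices:
    print(d.platform, d.id, d.process_index)
```

In a single-process program every device is local, so local and devices contain the same entries; the distinction only matters in multi-process (multi-host) runs.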