jax.lax.precv
- jax.lax.precv(token, out_shape, axis_name, perm)
Perform a collective recv according to the permutation perm. This function is an analog of the Recv HLO.
- Parameters:
token – a compiler token, either generated by a matching psend or lax.create_token(). This is used to enforce control dependencies between collectives.
out_shape – ShapeDtypeStruct(s) containing the dtype and shape of the result.
axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).
perm – list of pairs of ints, representing (source_index, destination_index) pairs that encode how the mapped axis named axis_name should be shuffled. The integer values are treated as indices into the mapped axis axis_name. No two pairs may share a source index, and no two pairs may share a destination index. For each index of the axis axis_name that is not a destination index in perm, the corresponding values in the result are filled with zeros of the appropriate type. The exact semantics are platform-specific; on GPU they correspond to an NCCL recv.
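The routing rules for perm can be modelled in plain Python. The sketch below is a hypothetical reference model, not JAX's implementation: `check_perm` and `simulate_precv` are illustrative names, each axis index's data is a Python list, and the `out_len`/`dtype` arguments stand in for a one-dimensional out_shape. It enforces the no-shared-source/destination rule, delivers each destination the data sent by its source, and zero-fills every index that no pair targets.

```python
from typing import Sequence

Pair = tuple[int, int]  # (source_index, destination_index)


def check_perm(perm: Sequence[Pair], axis_size: int) -> None:
    """Raise ValueError if perm breaks the rules described for precv."""
    sources = [s for s, _ in perm]
    dests = [d for _, d in perm]
    if len(set(sources)) != len(sources):
        raise ValueError("two pairs share a source index")
    if len(set(dests)) != len(dests):
        raise ValueError("two pairs share a destination index")
    # Extra sanity check assumed by this sketch: indices must lie on the axis.
    for s, d in perm:
        if not (0 <= s < axis_size and 0 <= d < axis_size):
            raise ValueError(f"pair {(s, d)} is outside an axis of size {axis_size}")


def simulate_precv(sent, perm, out_len, dtype=float):
    """Model what each index of the named axis receives.

    sent[i] is the list that index i handed to the matching send;
    out_len and dtype stand in for a 1-D out_shape (hypothetical model).
    """
    axis_size = len(sent)
    check_perm(perm, axis_size)
    # Invert perm: destination index -> source index.
    source_of = {d: s for s, d in perm}
    received = []
    for i in range(axis_size):
        if i in source_of:
            data = [dtype(v) for v in sent[source_of[i]]]
            if len(data) != out_len:
                raise ValueError("sent data does not match out_shape")
            received.append(data)
        else:
            # Index i is not a destination in perm: result is all zeros.
            received.append([dtype(0)] * out_len)
    return received


# Usage: four axis indices, each sending a length-2 vector.
sent = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
ring = [(i, (i + 1) % 4) for i in range(4)]
print(simulate_precv(sent, ring, out_len=2))
# [[4.0, 4.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(simulate_precv(sent, [(0, 1), (1, 2)], out_len=2))
# [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
```

With the ring permutation every index receives its left neighbour's data; with the partial permutation, indices 0 and 3 are not destinations, so their results are zero-filled.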
- Returns:
Array(s) with the shape and dtype given by out_shape.