jax.lax.precv

jax.lax.precv(token, out_shape, axis_name, perm)

Perform a collective recv according to the permutation perm.

This function is an analog of the Recv HLO.

Parameters:
  • token – a compiler token, either generated by a matching psend or created with lax.create_token(). The token is used to enforce control dependencies between collectives.

  • out_shape – ShapeDtypeStruct(s) containing the dtype and shape of the result.

  • axis_name – hashable Python object used to name a pmapped axis (see the jax.pmap() documentation for more details).

  • perm – list of pairs of ints, each a (source_index, destination_index) pair, that together encode how the mapped axis named axis_name should be shuffled. The integer values are treated as indices into the mapped axis axis_name. No two pairs may share a source index or a destination index. For each index of the axis axis_name that does not appear as a destination index in perm, the corresponding values in the result are filled with zeros of the appropriate type. These semantics are platform-specific; on GPU they correspond to NCCL recv.
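
The constraints on perm can be illustrated with a small pure-Python reference model. This is only a sketch of the documented index mapping, not of how any backend moves data: validate_perm and simulate_perm are hypothetical helper names that are not part of JAX, a plain list stands in for the values held along the mapped axis, and the bounds check on indices is an assumption added for illustration.

```python
from __future__ import annotations

from collections.abc import Sequence


def validate_perm(perm: Sequence[tuple[int, int]], axis_size: int) -> None:
    """Hypothetical helper: raise ValueError if perm breaks the documented rules."""
    sources = [src for src, _ in perm]
    dests = [dst for _, dst in perm]
    # No two pairs may share a source index or a destination index.
    if len(set(sources)) != len(sources):
        raise ValueError(f"duplicate source index in perm: {list(perm)}")
    if len(set(dests)) != len(dests):
        raise ValueError(f"duplicate destination index in perm: {list(perm)}")
    # Assumption for this sketch: every index must lie inside the mapped axis.
    for idx in sources + dests:
        if not 0 <= idx < axis_size:
            raise ValueError(f"index {idx} is outside an axis of size {axis_size}")


def simulate_perm(values: Sequence[float], perm: Sequence[tuple[int, int]]) -> list[float]:
    """Hypothetical model: values[i] stands for the data held at axis index i."""
    validate_perm(perm, len(values))
    # Indices that are not a destination in perm keep a zero of the right type.
    result = [0.0] * len(values)
    for src, dst in perm:
        result[dst] = values[src]
    return result


# Shift each value one position to the right; index 0 receives nothing.
received = simulate_perm([10.0, 20.0, 30.0, 40.0], perm=[(0, 1), (1, 2), (2, 3)])
print(received)  # [0.0, 10.0, 20.0, 30.0]
```

Index 0 is never a destination in this perm, so its slot in the result stays zero, which matches the zero-filling rule stated for perm.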

Returns:

Array(s) with the shape and dtype described by out_shape.
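
As a minimal sketch of how out_shape relates to the result, the snippet below builds a jax.ShapeDtypeStruct and a zero array with the same shape and dtype. It does not call precv. The zero array is only a stand-in for what an index that is not a destination in perm would hold, and the shape (2, 3) and dtype float32 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Describe the expected result: shape and dtype only, no data is allocated.
out_shape = jax.ShapeDtypeStruct((2, 3), jnp.float32)

# Stand-in for the zero-filled result at an index that receives nothing.
zero_result = jnp.zeros(out_shape.shape, out_shape.dtype)

print(out_shape.shape, out_shape.dtype)
print(zero_result.shape, zero_result.dtype)
```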