jax.pure_callback
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=Deprecated, vmap_method=None, **kwargs)
Calls a pure Python callback. Works under jit()/vmap()/etc. For more explanation, see External Callbacks.
pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed JAX arrays placed on a local CPU, and it should also return JAX arrays on CPU.
The callback is treated as functionally pure, meaning it has no side-effects and its output value depends only on its argument values. As a consequence, it may safely be called multiple times (e.g. when transformed by vmap() or pmap()), or not be called at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.
Warning
In the context of JAX transformations, Python exceptions should be considered side-effects: this means that intentionally raising an error within a pure_callback breaks the API contract, and the behavior of the resulting program is undefined.
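A minimal sketch of the basic pattern, assuming only JAX and NumPy are installed: a NumPy routine runs on the host inside a jit-compiled function. `np.sinc` is chosen purely for illustration (JAX also provides `jnp.sinc`), and the names `host_sinc` and `f` are hypothetical. Because JAX cannot trace into the Python callback, the output shape and dtype are declared up front with `jax.ShapeDtypeStruct`. The callback here returns a NumPy array, which pure_callback accepts in practice.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sinc(x):
    # Runs on the host: convert the incoming array to NumPy and call a
    # NumPy routine that JAX never traces.
    return np.sinc(np.asarray(x)).astype(np.float32)


@jax.jit
def f(x):
    # The result type must be declared because the callback is opaque to JAX.
    out_type = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    return jax.pure_callback(host_sinc, out_type, x)


x = jnp.linspace(-2.0, 2.0, 5, dtype=jnp.float32)
y = f(x)
```

Since the callback is pure, JAX is free to cache, repeat, or drop the call; no result of `f` should depend on the callback having run a particular number of times.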
When vmap-ed, the behavior depends on the value of vmap_method.
- Calling vmap() on a callback without an explicit vmap_method raises a NotImplementedError.
- vmap_method="sequential" uses map() to loop over the batched arguments, calling callback once for each batch element.
- vmap_method="sequential_unrolled" is like sequential, but the loop is unrolled.
- vmap_method="expand_dims" calls callback with new axes of size 1 added as the leading dimension of unbatched inputs.
- vmap_method="broadcast_all" behaves like expand_dims, but the inputs are tiled to the expected batched shape.
If necessary, the legacy behavior provided by the removed vectorized=True argument can be recovered using vmap_method="legacy_vectorized".
The current default behavior is to use vmap_method="sequential" when not specified, but this behavior is deprecated, and in the future, the default will be to raise a NotImplementedError unless vmap_method is explicitly specified.
- Parameters:
  - callback (Callable[..., Any]) – function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it may behave in unexpected ways, particularly under transformation. The callable will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches result_shape_dtypes.
  - result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes and whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.
  - *args (Any) – arguments to be passed to the callback function.
  - sharding (SingleDeviceSharding | None) – optional sharding that specifies the device from which the callback should be invoked.
  - vmap_method (str | None) – string specifying how the callback transforms under vmap() as described above.
  - **kwargs (Any) – keyword arguments to be passed to the callback function.
  - vectorized (bool | None | DeprecatedArg) – deprecated; the legacy vectorized=True behavior can be recovered with vmap_method="legacy_vectorized".
- Returns:
  a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.
- Return type:
  result
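A minimal sketch of a callback that returns a pytree (here a dict) instead of a single array, assuming NumPy on the host. result_shape_dtypes is a dict of jax.ShapeDtypeStruct leaves with the same keys, and the result comes back as a dict of jax.Array objects. The names `host_stats` and `summarize` are hypothetical, and passing `shift` as a JAX array through **kwargs (rather than as a plain Python value) is an assumption made so that every callback input is an array; non-array configuration is more safely bound with functools.partial.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_stats(x, *, shift):
    # Runs on the host. `shift` is forwarded from pure_callback's **kwargs;
    # here it is assumed to be an array like the positional arguments.
    shifted = np.asarray(x) + np.asarray(shift)
    # Keys, shapes, and dtypes must match `out_types` in summarize().
    return {
        "mean": np.asarray(np.mean(shifted), dtype=np.float32),
        "max": np.asarray(np.max(shifted), dtype=np.float32),
    }


@jax.jit
def summarize(x, shift):
    # One ShapeDtypeStruct leaf per output leaf, in the same pytree structure.
    out_types = {
        "mean": jax.ShapeDtypeStruct((), jnp.float32),
        "max": jax.ShapeDtypeStruct((), jnp.float32),
    }
    return jax.pure_callback(host_stats, out_types, x, shift=shift)


stats = summarize(jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32), jnp.float32(1.0))
```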
See also
- jax.experimental.io_callback(): callback designed for impure functions.
- jax.debug.callback(): callback designed for general-purpose debugging.
- jax.debug.print(): callback designed for printing.
Examples
The behavior of pure_callback under vmap() is controlled by the vmap_method argument as described above. Some explicit examples help demonstrate the semantics. Consider the following callback and the function that wraps it:
>>> import jax
>>> import jax.numpy as jnp
>>> def callback(x, y):
...   print(jnp.shape(x), jnp.shape(y))
...   return x + y
>>> def fun(x, y, *, vmap_method):
...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
...   dtype = jnp.result_type(x, y)
...   out_type = jax.ShapeDtypeStruct(shape, dtype)
...   return jax.pure_callback(callback, out_type, x, y,
...                            vmap_method=vmap_method)
Calling this with vmap_method="expand_dims" adds a new axis of size 1 to y:
>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)
By contrast, vmap_method="broadcast_all" adds an axis of size 4 to y:
>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
...          in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)