jax.tree_util.Partial
class jax.tree_util.Partial(func, *args, **kw)
A version of functools.partial that works in pytrees.
Use it for partial function evaluation in a way that is compatible with JAX’s transformations, e.g., Partial(func, *args, **kwargs). (You need to explicitly opt in to this behavior because we didn’t want to give functools.partial different semantics than normal function closures.)
For example, here is a basic usage of Partial in a manner similar to functools.partial:

>>> import jax.numpy as jnp
>>> from jax.tree_util import Partial
>>> add_one = Partial(jnp.add, 1)
>>> add_one(2)
Array(3, dtype=int32, weak_type=True)
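Because a Partial is a pytree, the generic tree utilities can look inside it. The sketch below assumes the flattening behavior of current JAX releases, where the bound positional and keyword arguments are the pytree's children and the wrapped function is kept as static structure; it is an illustration, not a specification of the internal layout. The name add_ten is purely illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

add_one = Partial(jnp.add, 1)

# The bound argument shows up as a leaf; the function itself does not.
leaves = jax.tree_util.tree_leaves(add_one)
print(leaves)  # [1]

# tree_map rebuilds a Partial with transformed bound arguments
# while keeping the wrapped function (jnp.add) unchanged.
add_ten = jax.tree_util.tree_map(lambda x: x * 10, add_one)
print(add_ten.func is jnp.add)  # True
print(add_ten(2))  # jnp.add(10, 2) -> 12
```

A plain functools.partial gets no such treatment, which is the practical difference the opt-in refers to.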
Pytree compatibility means that the resulting partial function can be passed as an argument to JAX-transformed functions, which is not possible with a standard functools.partial function:

>>> from jax import jit
>>> @jit
... def call_func(f, *args):
...   return f(*args)
...
>>> call_func(add_one, 2)
Array(3, dtype=int32, weak_type=True)
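One consequence of the split between static function and traced arguments is how jit caching behaves. The sketch below uses a hypothetical module-level counter, trace_count, that is only incremented while JAX traces call_func (Python side effects do not run on cached calls). It assumes default settings (no x64, no cache clearing between calls): changing only the bound value should reuse the existing trace, while swapping the wrapped function changes the pytree structure and should trigger a new trace.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

trace_count = 0  # hypothetical counter for illustration


@jax.jit
def call_func(f, *args):
    global trace_count
    trace_count += 1  # runs at trace time only, not when a cached trace is reused
    return f(*args)


r1 = call_func(Partial(jnp.add, 1), 2)       # first call: traces and compiles
r2 = call_func(Partial(jnp.add, 5), 2)       # new bound value, same func: cache hit
r3 = call_func(Partial(jnp.multiply, 5), 2)  # different func: retraces
print(trace_count)  # 2
print(r1, r2, r3)   # 3 7 10
```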
Passing zero arguments to Partial effectively wraps the original function, making it a valid argument to JAX-transformed functions:

>>> call_func(Partial(jnp.add), 1, 2)
Array(3, dtype=int32, weak_type=True)
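The contrast between a zero-argument Partial and a bare function can be checked directly. A minimal sketch, assuming the same call_func helper as in the doctests: the wrapper has no leaves at all (only static structure), while passing jnp.add itself to the jitted function fails with a TypeError because a function is not a valid array argument. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@jax.jit
def call_func(f, *args):
    return f(*args)


wrapped = Partial(jnp.add)
wrapped_leaves = jax.tree_util.tree_leaves(wrapped)
print(wrapped_leaves)  # [] : nothing to trace, only the static function

result = call_func(wrapped, 1, 2)
print(result)  # 3

raised_type_error = False
try:
    call_func(jnp.add, 1, 2)  # bare function passed as a traced argument
except TypeError as err:
    raised_type_error = True
    print("TypeError:", err)
```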
Had we passed jnp.add to call_func directly, it would have resulted in a TypeError.

Note that if the result of Partial is used in a context where the value is traced, all bound arguments are traced when passed to the partially evaluated function:

>>> print_zero = Partial(print, 0)
>>> print_zero()
0
>>> call_func(print_zero)
JitTracer<~int32[]>
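Tracing of bound arguments matters when the wrapped function uses Python control flow on them. In the sketch below, clipped_add is a hypothetical helper that branches on its first argument. Outside jit the bound value stays a Python int and the branch works; inside jit it becomes a tracer and the branch raises a concretization error (caught here as jax.errors.ConcretizationTypeError). If a value must stay static, one option consistent with the closure semantics mentioned above is to capture it in an ordinary closure and wrap that with a zero-argument Partial, so it lives in the static function rather than in the traced arguments.

```python
import jax
from jax.tree_util import Partial


def clipped_add(bound, x):
    # Python-level branching needs a concrete value for `bound`.
    if bound > 0:
        return x + bound
    return x


@jax.jit
def call_func(f, *args):
    return f(*args)


clip_one = Partial(clipped_add, 1)
eager = clip_one(2)  # no tracing: bound is the Python int 1
print(eager)  # 3

traced_error = False
try:
    call_func(clip_one, 2)  # bound becomes a tracer, so `if bound > 0` fails
except jax.errors.ConcretizationTypeError:
    traced_error = True

# Closing over the value keeps it out of the pytree leaves (static).
static_result = call_func(Partial(lambda x: clipped_add(1, x)), 2)
print(static_result)  # 3
```

Note that each new lambda object is a new static function, so it gets its own jit cache entry.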
Methods

__init__()

Attributes

args: tuple of arguments to future partial calls
func: function object to use in future partial calls
keywords: dictionary of keyword arguments to future partial calls
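The three attributes can be inspected like those of functools.partial, which Partial builds on. The sketch below uses a hypothetical affine function bound only by keyword. It assumes current behavior where keyword values are also pytree leaves (flattened in sorted key order, as for any dict) and are therefore traced under jit, just like positional bound arguments.

```python
import jax
from jax.tree_util import Partial


def affine(x, scale, shift=0.0):
    # Hypothetical example function.
    return x * scale + shift


p = Partial(affine, scale=2.0, shift=1.0)
print(p.func is affine)  # True
print(p.args)            # ()
print(p.keywords)        # {'scale': 2.0, 'shift': 1.0}
print(p(3.0))            # 7.0

# Keyword values are leaves too, so they are traced when p crosses jit.
kw_leaves = jax.tree_util.tree_leaves(p)
print(kw_leaves)  # [2.0, 1.0]
jitted = jax.jit(lambda f, x: f(x))(p, 3.0)
print(jitted)  # 7.0
```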