jax.tree_util.Partial#
- class jax.tree_util.Partial(func, *args, **kw)#
functools.partial 的一個版本,可在 pytree 中運作。
將其用於部分函數求值,使其與 JAX 的轉換相容,例如,
Partial(func, *args, **kwargs)
。(您需要明確選擇加入此行為,因為我們不希望 functools.partial 的語意與一般函數閉包不同。)
例如,以下是在類似於
functools.partial
的方式中使用Partial
的基本用法>>> import jax.numpy as jnp >>> add_one = Partial(jnp.add, 1) >>> add_one(2) Array(3, dtype=int32, weak_type=True)
Pytree 相容性表示產生的部分函數可以作為轉換後的 JAX 函數中的引數傳遞,這對於標準
functools.partial
函數來說是不可能的>>> from jax import jit >>> @jit ... def call_func(f, *args): ... return f(*args) ... >>> call_func(add_one, 2) Array(3, dtype=int32, weak_type=True)
將零個引數傳遞給
Partial
有效地包裝了原始函數,使其成為 JAX 轉換函數中的有效引數>>> call_func(Partial(jnp.add), 1, 2) Array(3, dtype=int32, weak_type=True)
如果我們直接將
jnp.add
傳遞給call_func
,則會導致TypeError
。請注意,如果在追蹤值的環境中使用
Partial
的結果,則在傳遞給部分求值函數時,會追蹤所有繫結的引數>>> print_zero = Partial(print, 0) >>> print_zero() 0 >>> call_func(print_zero) Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
- __init__()#
方法
__init__
()屬性
args
未來部分呼叫的引數元組
func
在未來部分呼叫中使用的函數物件
keywords
未來部分呼叫的關鍵字引數字典