jax.tree_util.Partial#

class jax.tree_util.Partial(func, *args, **kw)#

functools.partial 的一個版本,可在 pytree 中運作。

將其用於部分函數求值,使其與 JAX 的轉換相容,例如,Partial(func, *args, **kwargs)

(您需要明確選擇加入此行為,因為我們不希望 functools.partial 的語意與一般函數閉包不同。)

例如,以下是在類似於 functools.partial 的方式中使用 Partial 的基本用法

>>> import jax.numpy as jnp
>>> add_one = Partial(jnp.add, 1)
>>> add_one(2)
Array(3, dtype=int32, weak_type=True)

Pytree 相容性表示產生的部分函數可以作為轉換後的 JAX 函數中的引數傳遞,這對於標準 functools.partial 函數來說是不可能的

>>> from jax import jit
>>> @jit
... def call_func(f, *args):
...   return f(*args)
...
>>> call_func(add_one, 2)
Array(3, dtype=int32, weak_type=True)

將零個引數傳遞給 Partial 有效地包裝了原始函數,使其成為 JAX 轉換函數中的有效引數

>>> call_func(Partial(jnp.add), 1, 2)
Array(3, dtype=int32, weak_type=True)

如果我們直接將 jnp.add 傳遞給 call_func,則會導致 TypeError

請注意,如果在追蹤值的環境中使用 Partial 的結果,則在傳遞給部分求值函數時,會追蹤所有繫結的引數

>>> print_zero = Partial(print, 0)
>>> print_zero()
0
>>> call_func(print_zero)  
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace...>
__init__()#

方法

__init__()

屬性

args

未來部分呼叫的引數元組

func

在未來部分呼叫中使用的函數物件

keywords

未來部分呼叫的關鍵字引數字典