jax.numpy.vectorize#

jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)[原始碼]#

定義具有廣播的向量化函式。

vectorize() 是一個方便的包裝函式,用於定義具有廣播的向量化函式,風格類似於 NumPy 的 generalized universal functions。它允許定義自動在任何前導維度上重複的函式,而無需函式的實作擔心如何處理更高維度的輸入。

jax.numpy.vectorize() 具有與 numpy.vectorize 相同的介面,但它是自動批次處理轉換 (vmap()) 的語法糖,而不是 Python 迴圈。這應該效率更高,但實作必須以作用於 JAX 陣列的函式來撰寫。

參數:
  • pyfunc – 要向量化的函式。

  • excluded – 代表位置引數的可選整數集合,函式將不會針對這些引數進行向量化。這些引數將直接且未經修改地傳遞給 pyfunc

  • signature – 可選的廣義通用函式簽名,例如,(m,n),(n)->(m) 用於向量化矩陣-向量乘法。如果提供,則將使用(並預期傳回)形狀由相應核心維度大小給定的陣列來呼叫 pyfunc。預設情況下,pyfunc 假定將純量陣列作為輸入和輸出。

傳回:

給定函式的向量化版本。

範例

以下是一些關於如何使用 vectorize() 撰寫向量化線性代數常式之範例

>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector

這些函式僅撰寫為處理 1D 或 2D 陣列(assert 陳述式永遠不會被違反),但透過 vectorize,它們支援具有 NumPy 風格廣播的任意維度輸入,例如:

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))  
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'k') on vectorized function with excluded=frozenset()
and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)

請注意,這與 jnp.matmul 具有不同的語意

>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))  
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].