jax.numpy.fromfunction#

jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)[原始碼]#

從應用於索引的函數建立陣列。

JAX 實作的 numpy.fromfunction()。JAX 實作的不同之處在於它透過 jax.vmap() 進行分派,因此與 NumPy 不同,此函數在邏輯上對純量輸入進行操作,並且不需要顯式處理廣播輸入(請參閱下方的範例)。

參數:
  • function (Callable[..., Array]) – 一個接受 N 個動態純量並輸出一個純量的函數。

  • shape (Any) – 一個長度為 N 的整數元組,指定輸出形狀。

  • dtype (DTypeLike) – 選擇性指定輸入的 dtype。預設為浮點數。

  • kwargs – 額外的關鍵字引數會靜態傳遞至 function

回傳值:

如果 function 回傳純量,則為形狀為 shape 的陣列;否則,通常為具有前導維度 shape 的陣列 pytree,由 function 的輸出決定。

回傳類型:

Array

另請參閱

範例

產生給定形狀的乘法表

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)

function 回傳非純量時,輸出將具有 shape 的前導維度

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)

function 可能會回傳多個結果,在這種情況下,每個結果都會獨立映射

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]

JAX 實作與 NumPy 的實作略有不同。在 numpy.fromfunction() 中,預期函數會顯式地對輸入值的完整網格執行元素級操作

>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])

jax.numpy.fromfunction() 中,函數會透過 jax.vmap() 進行向量化,因此預期會對純量值進行操作

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)