jax.numpy.fromfunction#
- jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)[原始碼]#
從應用於索引的函數建立陣列。
JAX 實作的
numpy.fromfunction()
。JAX 實作的不同之處在於它透過jax.vmap()
進行分派,因此與 NumPy 不同,此函數在邏輯上對純量輸入進行操作,並且不需要顯式處理廣播輸入(請參閱下方的範例)。- 參數:
function (Callable[..., Array]) – 一個接受 N 個動態純量並輸出一個純量的函數。
shape (Any) – 一個長度為 N 的整數元組,指定輸出形狀。
dtype (DTypeLike) – 選擇性指定輸入的 dtype。預設為浮點數。
kwargs – 額外的關鍵字引數會靜態傳遞至
function
。
- 回傳值:
如果
function
回傳純量,則為形狀為shape
的陣列;否則,通常為具有前導維度shape
的陣列 pytree,由function
的輸出決定。- 回傳類型:
另請參閱
jax.vmap()
:fromfunction()
API 建構於其上的核心轉換。
範例
產生給定形狀的乘法表
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int) Array([[ 0, 0, 0, 0, 0, 0], [ 0, 1, 2, 3, 4, 5], [ 0, 2, 4, 6, 8, 10]], dtype=int32)
當
function
回傳非純量時,輸出將具有shape
的前導維度>>> def f(x): ... return (x + 1) * jnp.arange(3) >>> jnp.fromfunction(f, shape=(2,)) Array([[0., 1., 2.], [0., 2., 4.]], dtype=float32)
function
可能會回傳多個結果,在這種情況下,每個結果都會獨立映射>>> def f(x, y): ... return x + y, x * y >>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5)) >>> print(x_plus_y) [[0. 1. 2. 3. 4.] [1. 2. 3. 4. 5.] [2. 3. 4. 5. 6.]] >>> print(x_times_y) [[0. 0. 0. 0. 0.] [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.]]
JAX 實作與 NumPy 的實作略有不同。在
numpy.fromfunction()
中,預期函數會顯式地對輸入值的完整網格執行元素級操作>>> def f(x, y): ... print(f"{x.shape = }\n{y.shape = }") ... return x + y ... >>> np.fromfunction(f, (2, 3)) x.shape = (2, 3) y.shape = (2, 3) array([[0., 1., 2.], [1., 2., 3.]])
在
jax.numpy.fromfunction()
中,函數會透過jax.vmap()
進行向量化,因此預期會對純量值進行操作>>> jnp.fromfunction(f, (2, 3)) x.shape = () y.shape = () Array([[0., 1., 2.], [1., 2., 3.]], dtype=float32)