jax.numpy.kron#

jax.numpy.kron(a, b)[原始碼]#

計算兩個輸入陣列的 Kronecker 乘積。

JAX 版本的 numpy.kron()

Kronecker 乘積是對兩個任意大小的矩陣進行運算,產生一個區塊矩陣。第一個矩陣 a 的每個元素都乘以第二個矩陣 b 的整個矩陣。如果 a 的形狀為 (m, n) 且 b 的形狀為 (p, q),則結果矩陣的形狀將為 (m * p, n * q)。

參數:
  • a (ArrayLike) – 第一個輸入陣列,具有任何形狀。

  • b (ArrayLike) – 第二個輸入陣列,具有任何形狀。

返回:

一個新的陣列,表示輸入 ab 的 Kronecker 乘積。輸出的形狀是輸入形狀的元素乘積。

返回類型:

陣列 (Array)

參見

範例

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)