jax.numpy.kron#
- jax.numpy.kron(a, b)[原始碼]#
計算兩個輸入陣列的 Kronecker 乘積。
JAX 版本的
numpy.kron()
。Kronecker 乘積是對兩個任意大小的矩陣進行運算,產生一個區塊矩陣。第一個矩陣
a
的每個元素都乘以第二個矩陣b
的整個矩陣。如果a
的形狀為 (m, n) 且b
的形狀為 (p, q),則結果矩陣的形狀將為 (m * p, n * q)。- 參數:
a (ArrayLike) – 第一個輸入陣列,具有任何形狀。
b (ArrayLike) – 第二個輸入陣列,具有任何形狀。
- 返回:
一個新的陣列,表示輸入
a
和b
的 Kronecker 乘積。輸出的形狀是輸入形狀的元素乘積。- 返回類型:
參見
jax.numpy.outer()
:計算兩個陣列的外積。
範例
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> b = jnp.array([[5, 6], ... [7, 8]]) >>> jnp.kron(a, b) Array([[ 5, 6, 10, 12], [ 7, 8, 14, 16], [15, 18, 20, 24], [21, 24, 28, 32]], dtype=int32)