jax.numpy.cross#

jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[原始碼]#

計算兩個陣列的(批次)外積。

JAX 實作的 numpy.cross()

這會計算 2 維或 3 維的外積,

\[c = a \times b\]

在 3 維中,c 是一個長度為 3 的陣列。在 2 維中,c 是一個純量。

參數:
  • a – N 維陣列。a.shape[axisa] 指出外積的維度,且必須為 2 或 3。

  • b – N 維陣列。必須具有 b.shape[axisb] == a.shape[axisb],且 ab 的其他維度必須是可廣播相容的。

  • axisa (int) – 指定 a 的軸,沿該軸計算外積。

  • axisb (int) – 指定 b 的軸,沿該軸計算外積。

  • axisc (int) – 指定 c 的軸,外積結果將儲存在該軸上。

  • axis (int | None) – 如果指定,此參數會使用單一值覆寫 axisaaxisbaxisc

回傳值:

陣列 c,包含 ab 沿指定軸的(批次)外積。

參見

範例

2 維外積回傳一個純量

>>> a = jnp.array([1, 2])
>>> b = jnp.array([3, 4])
>>> jnp.cross(a, b)
Array(-2, dtype=int32)

3 維外積回傳一個長度為 3 的向量

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.cross(a, b)
Array([-3,  6, -3], dtype=int32)

對於多維輸入,外積預設沿著最後一個軸計算。以下是一個批次的 3 維外積,運算於輸入的列

>>> a = jnp.array([[1, 2, 3],
...                [3, 4, 3]])
>>> b = jnp.array([[2, 3, 2],
...                [4, 5, 6]])
>>> jnp.cross(a, b)
Array([[-5,  4, -1],
       [ 9, -6, -1]], dtype=int32)

指定 axis=0 使其成為批次的 2 維外積,運算於輸入的行

>>> jnp.cross(a, b, axis=0)
Array([-2, -2, 12], dtype=int32)

等效地,我們可以獨立指定輸入 ab 以及輸出 c 的軸

>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
Array([-2, -2, 12], dtype=int32)