jax.numpy.cross#
- jax.numpy.cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None)[原始碼]#
計算兩個陣列的(批次)外積。
JAX 實作的
numpy.cross()
。這會計算 2 維或 3 維的外積,
\[c = a \times b\]在 3 維中,
c
是一個長度為 3 的陣列。在 2 維中,c
是一個純量。- 參數:
- 回傳值:
陣列
c
,包含a
和b
沿指定軸的(批次)外積。
參見
jax.numpy.linalg.cross()
:一個陣列 API 相容函式,用於計算 3 維向量的外積。
範例
2 維外積回傳一個純量
>>> a = jnp.array([1, 2]) >>> b = jnp.array([3, 4]) >>> jnp.cross(a, b) Array(-2, dtype=int32)
3 維外積回傳一個長度為 3 的向量
>>> a = jnp.array([1, 2, 3]) >>> b = jnp.array([4, 5, 6]) >>> jnp.cross(a, b) Array([-3, 6, -3], dtype=int32)
對於多維輸入,外積預設沿著最後一個軸計算。以下是一個批次的 3 維外積,運算於輸入的列
>>> a = jnp.array([[1, 2, 3], ... [3, 4, 3]]) >>> b = jnp.array([[2, 3, 2], ... [4, 5, 6]]) >>> jnp.cross(a, b) Array([[-5, 4, -1], [ 9, -6, -1]], dtype=int32)
指定 axis=0 使其成為批次的 2 維外積,運算於輸入的行
>>> jnp.cross(a, b, axis=0) Array([-2, -2, 12], dtype=int32)
等效地,我們可以獨立指定輸入
a
和b
以及輸出c
的軸>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0) Array([-2, -2, 12], dtype=int32)