jax.numpy.outer#

jax.numpy.outer(a, b, out=None)[原始碼]#

計算兩個陣列的外積。

numpy.outer() 的 JAX 實作。

參數:
  • a (ArrayLike) – 第一個輸入陣列,若非 1D 陣列將會被展平。

  • b (ArrayLike) – 第二個輸入陣列,若非 1D 陣列將會被展平。

  • out (None) – JAX 不支援。

回傳值:

輸入 ab 的外積。回傳陣列的形狀為 (a.size, b.size)

回傳型別:

陣列 (Array)

另請參閱

範例

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)