jax.numpy.polymul#

jax.numpy.polymul(a1, a2, *, trim_leading_zeros=False)[原始碼]#

傳回兩個多項式的乘積。

JAX 實作的 numpy.polymul()

參數::
  • a1 (ArrayLike) – 多項式係數的 1D 陣列。

  • a2 (ArrayLike) – 多項式係數的 1D 陣列。

  • trim_leading_zeros (bool) – 預設為 False。如果為 True,則移除傳回值中的前導零,以符合 numpy 的結果。但會阻止函式在編譯碼中使用。由於浮點算術錯誤累積的差異,值被視為零的截止值可能會導致 NumPy 和 JAX 之間,甚至不同 JAX 後端之間的不一致結果。trim_leading_zeros=True 時,結果可能會導致不一致的輸出形狀。

傳回::

兩個多項式乘積的係數陣列。輸出的 dtype 始終提升為非精確型別。

傳回型別::

Array

注意

jax.numpy.polymul() 僅接受陣列作為輸入,不像 numpy.polymul() 也接受純量輸入。

另請參閱

範例

>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
array([10,  5,  6,  3,  0])
>>> jnp.polymul(x1, x2)
Array([ 0., 10.,  5.,  6.,  3.,  0.], dtype=float32)

如果 trim_leading_zeros=True,結果會與 np.polymul 的結果一致。

>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
Array([10.,  5.,  6.,  3.,  0.], dtype=float32)

對於 dtype 為 complex 的輸入陣列

>>> x3 = np.array([2., 1+2j, 1-2j])
>>> x4 = np.array([0, 5, 0, 3])
>>> np.polymul(x3, x4)
array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j])
>>> jnp.polymul(x3, x4)
Array([ 0. +0.j, 10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j],      dtype=complex64)
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j,  5.+10.j, 11.-10.j,  3. +6.j,  3. -6.j], dtype=complex64)