jax.numpy.polymul#
- jax.numpy.polymul(a1, a2, *, trim_leading_zeros=False)[原始碼]#
傳回兩個多項式的乘積。
JAX 實作的
numpy.polymul()
。- 參數::
a1 (ArrayLike) – 多項式係數的 1D 陣列。
a2 (ArrayLike) – 多項式係數的 1D 陣列。
trim_leading_zeros (bool) – 預設為
False
。如果為True
,則移除傳回值中的前導零,以符合 numpy 的結果。但會阻止函式在編譯碼中使用。由於浮點算術錯誤累積的差異,值被視為零的截止值可能會導致 NumPy 和 JAX 之間,甚至不同 JAX 後端之間的不一致結果。trim_leading_zeros=True
時,結果可能會導致不一致的輸出形狀。
- 傳回::
兩個多項式乘積的係數陣列。輸出的 dtype 始終提升為非精確型別。
- 傳回型別::
注意
jax.numpy.polymul()
僅接受陣列作為輸入,不像numpy.polymul()
也接受純量輸入。另請參閱
jax.numpy.polyadd()
:計算兩個多項式的和。jax.numpy.polysub()
:計算兩個多項式的差。jax.numpy.polydiv()
:計算多項式除法的商和餘數。
範例
>>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) array([10, 5, 6, 3, 0]) >>> jnp.polymul(x1, x2) Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
如果
trim_leading_zeros=True
,結果會與np.polymul
的結果一致。>>> jnp.polymul(x1, x2, trim_leading_zeros=True) Array([10., 5., 6., 3., 0.], dtype=float32)
對於 dtype 為
complex
的輸入陣列>>> x3 = np.array([2., 1+2j, 1-2j]) >>> x4 = np.array([0, 5, 0, 3]) >>> np.polymul(x3, x4) array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j]) >>> jnp.polymul(x3, x4) Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64) >>> jnp.polymul(x3, x4, trim_leading_zeros=True) Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)