jax.numpy.polydiv#
- jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[原始碼]#
傳回多項式除法的商和餘數。
JAX 版本的
numpy.polydiv()
。- 參數:
u (ArrayLike) – 被除數多項式係數陣列。
v (ArrayLike) – 除數多項式係數陣列。
trim_leading_zeros (bool) – 預設值為
False
。如果為True
,則移除傳回值中的前導零,以符合 numpy 的結果。但會使函數無法在編譯後的程式碼中使用。由於浮點算術誤差累積的差異,將值視為零的截止值可能導致 NumPy 和 JAX 之間,甚至不同 JAX 後端之間產生不一致的結果。當trim_leading_zeros=True
時,可能會導致不一致的輸出形狀。
- 傳回值:
商和餘數陣列的元組。輸出的 dtype 始終提升為非精確型別。
- 傳回類型:
注意
jax.numpy.polydiv()
僅接受陣列作為輸入,不同於numpy.polydiv()
,後者也接受純量輸入。另請參閱
jax.numpy.polyadd()
:計算兩個多項式的和。jax.numpy.polysub()
:計算兩個多項式的差。jax.numpy.polymul()
:計算兩個多項式的乘積。
範例
>>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) (array([1.25 , 1.4375]), array([7.5625])) >>> jnp.polydiv(x1, x2) (Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
如果
trim_leading_zeros=True
,則結果與np.polydiv
的結果相符。>>> jnp.polydiv(x1, x2, trim_leading_zeros=True) (Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))