jax.numpy.polydiv#

jax.numpy.polydiv(u, v, *, trim_leading_zeros=False)[原始碼]#

傳回多項式除法的商和餘數。

JAX 版本的 numpy.polydiv()

參數:
  • u (ArrayLike) – 被除數多項式係數陣列。

  • v (ArrayLike) – 除數多項式係數陣列。

  • trim_leading_zeros (bool) – 預設值為 False。如果為 True,則移除傳回值中的前導零,以符合 numpy 的結果。但會使函數無法在編譯後的程式碼中使用。由於浮點算術誤差累積的差異,將值視為零的截止值可能導致 NumPy 和 JAX 之間,甚至不同 JAX 後端之間產生不一致的結果。當 trim_leading_zeros=True 時,可能會導致不一致的輸出形狀。

傳回值:

商和餘數陣列的元組。輸出的 dtype 始終提升為非精確型別。

傳回類型:

tuple[Array, Array]

注意

jax.numpy.polydiv() 僅接受陣列作為輸入,不同於 numpy.polydiv(),後者也接受純量輸入。

另請參閱

範例

>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25  , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25  , 1.4375], dtype=float32), Array([0.    , 0.    , 7.5625], dtype=float32))

如果 trim_leading_zeros=True,則結果與 np.polydiv 的結果相符。

>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25  , 1.4375], dtype=float32), Array([7.5625], dtype=float32))