jax.numpy.roots#

jax.numpy.roots(p, *, strip_zeros=True)[原始碼]#

傳回給定係數 p 的多項式根。

numpy.roots() 的 JAX 實作。

參數:
  • p (ArrayLike) – 多項式係數陣列,秩為 1。

  • strip_zeros (bool) – bool,預設值=True。如果為 True,則會去除係數中的前導零,類似於 numpy.roots()。如果設定為 False,則不會去除前導零,且未定義的根將在函式輸出中以 NaN 值表示。strip_zeros 必須設定為 False,函式才能與 jax.jit() 和其他 JAX 轉換相容。

傳回:

包含多項式根的陣列。

傳回類型:

Array

注意

與此函式的 np.roots 不同,jnp.roots 會傳回複數陣列中的根,無論根的值為何。

另請參閱

範例

>>> coeffs = jnp.array([0, 1, 2])

預設行為與 numpy 相同,並去除前導零

>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

使用 strip_zeros=False,額外的根會設定為 NaN

>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)