jax.numpy.polyval#

jax.numpy.polyval(p, x, *, unroll=16)[source]#

在特定值評估多項式。

numpy.polyval() 的 JAX 實作。

對於長度為 M 的 1D 多項式係數 p,此函數傳回值

\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]
參數:
  • p (ArrayLike) – 形狀為 (M,) 的多項式係數陣列。

  • x (ArrayLike) – 數字或數字陣列。

  • unroll (int) – 用於控制 lax.scan 中展開步驟數量的數字。 必須靜態指定。

傳回:

形狀與 x 相同的陣列。

傳回類型:

Array

注意

unroll 參數是 JAX 特有的。 它不影響正確性,但對於評估高階多項式的效能有重大影響。 該參數控制 jnp.polyval 實作內部的 lax.scan 中展開步驟的數量。 考慮將 unroll=128 (甚至更高) 以提高加速器上的執行時間效能,但會增加編譯時間。

另請參閱

範例

>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)

如果 x 是 2D 陣列,則 polyval 傳回與 x 形狀相同的 2D 陣列

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)