jax.numpy.polyval#
- jax.numpy.polyval(p, x, *, unroll=16)[source]#
在特定值評估多項式。
numpy.polyval()
的 JAX 實作。對於長度為
M
的 1D 多項式係數p
,此函數傳回值\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]- 參數:
p (ArrayLike) – 形狀為
(M,)
的多項式係數陣列。x (ArrayLike) – 數字或數字陣列。
unroll (int) – 用於控制
lax.scan
中展開步驟數量的數字。 必須靜態指定。
- 傳回:
形狀與
x
相同的陣列。- 傳回類型:
注意
unroll
參數是 JAX 特有的。 它不影響正確性,但對於評估高階多項式的效能有重大影響。 該參數控制jnp.polyval
實作內部的lax.scan
中展開步驟的數量。 考慮將unroll=128
(甚至更高) 以提高加速器上的執行時間效能,但會增加編譯時間。另請參閱
jax.numpy.polyfit()
:最小平方多項式擬合。jax.numpy.poly()
:尋找具有給定根的多項式係數。jax.numpy.roots()
:計算給定係數的多項式根。
範例
>>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32)
如果
x
是 2D 陣列,則polyval
傳回與x
形狀相同的 2D 陣列>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.polyval(p, x) Array([[ 19., 8., 76.], [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32)