jax.numpy.sinc#
- jax.numpy.sinc(x, /)[原始碼]#
計算標準化的 sinc 函數。
numpy.sinc()
的 JAX 實作。標準化的 sinc 函數由下式給出
\[\mathrm{sinc}(x) = \frac{\sin({\pi x})}{\pi x}\]其中
sinc(0)
返回極限值1
。sinc 函數是平滑且無限可微的。- 參數::
x (ArrayLike) – 輸入陣列;將提升為非精確類型。
- 返回::
一個與
x
形狀相同的陣列,包含結果。- 返回類型::
範例
>>> x = jnp.array([-1, -0.5, 0, 0.5, 1]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sinc(x) Array([-0. , 0.637, 1. , 0.637, -0. ], dtype=float32)
將此與計算函數的樸素方法進行比較,該方法在零點未定義
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sin(jnp.pi * x) / (jnp.pi * x) Array([-0. , 0.637, nan, 0.637, -0. ], dtype=float32)
JAX 為 sinc 定義了自訂梯度規則,即使對於更高階導數,也允許在零點準確評估梯度
>>> f = jnp.sinc >>> for i in range(1, 6): ... f = jax.grad(f) ... print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}") ... (d/dx)^1 f(0.0) = 0.00 (d/dx)^2 f(0.0) = -3.29 (d/dx)^3 f(0.0) = 0.00 (d/dx)^4 f(0.0) = 19.48 (d/dx)^5 f(0.0) = 0.00