jax.scipy.special.log_ndtr#

jax.scipy.special.log_ndtr = <jax._src.custom_derivatives.custom_jvp object>[原始碼]#

對數常態分佈函數。

JAX 版本的 scipy.special.log_ndtr

有關常態分佈函數的詳細資訊,請參閱 ndtr

此函數計算 \(\log(\mathrm{ndtr}(x))\),方法為直接呼叫 \(\log(\mathrm{ndtr}(x))\) 或使用漸近級數。具體來說:

  • 對於 x > upper_segment,使用基於 \(\log(1-x) \approx -x, x \ll 1\) 的近似值 -ndtr(-x)

  • 對於 lower_segment < x <= upper_segment,使用現有的 ndtr 技術並取對數。

  • 對於 x <= lower_segment,我們使用 erf 的級數近似值直接計算對數 CDF。

lower_segment 的設定基於輸入的精度

\[\begin{split}\begin{align} \mathit{lower\_segment} =& \ \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \ \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]

x < lower_segment 時,ndtr 漸近級數近似值為

\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N {-1}^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align}\end{split}\]

其中 \((2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)\)雙階乘運算子。

參數:
  • x (ArrayLike) – float32float64 型別的陣列。

  • series_order (int) – 正 Python 整數。評估漸近展開式的最大深度。這是上述的 N

傳回:

具有 dtype=x.dtype 的陣列。

引發:
  • TypeError – 如果不處理 x.dtype

  • TypeError – 如果 series_order 不是 Python integer

  • ValueError – 如果 series_order 不在 [0, 30] 範圍內。

傳回型別:

Array