jax.random.generalized_normal#

jax.random.generalized_normal(key, p, shape=(), dtype=<class 'float'>)[source]#

從廣義常態分佈中取樣。

值會根據以下機率密度函數回傳

\[f(x;p) \propto e^{-|x|^p}\]

\(-\infty < x < \infty\) 域上,其中 \(p > 0\) 是形狀參數。

參數:
  • key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。

  • p (float) – 代表形狀參數的浮點數。

  • shape (Shape) – 選填,結果的批次維度。預設值為 ()。

  • dtype (DTypeLikeFloat) – 選填,回傳值的浮點數資料型別 (如果 jax_enable_x64 為 true,則預設為 float64,否則為 float32)。

回傳:

具有指定形狀和資料型別的隨機陣列。

回傳型別:

Array