jax.random.bernoulli#

jax.random.bernoulli(key, p=np.float32(0.5), shape=None)[原始碼]#

取樣具有給定形狀和平均值的 Bernoulli 隨機值。

這些值根據以下機率質量函數分佈

\[f(k; p) = p^k(1 - p)^{1 - k}\]

其中 \(k \in \{0, 1\}\)\(0 \le p \le 1\)

參數:
  • key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。

  • p (RealArray) – 選填,隨機變數平均值的浮點數或浮點數陣列。必須與 shape 進行廣播相容。預設值為 0.5。

  • shape (Shape | None | None) – 選填,代表結果形狀的非負整數元組。必須與 p.shape 進行廣播相容。預設值 (None) 產生的結果形狀與 p.shape 相同。

返回:

如果 shape 不是 None,則為具有布林 dtype 和由 shape 給定形狀的隨機陣列,否則為 p.shape

返回類型:

Array