jax.random.bernoulli#
- jax.random.bernoulli(key, p=np.float32(0.5), shape=None)[原始碼]#
取樣具有給定形狀和平均值的 Bernoulli 隨機值。
這些值根據以下機率質量函數分佈
\[f(k; p) = p^k(1 - p)^{1 - k}\]其中 \(k \in \{0, 1\}\) 且 \(0 \le p \le 1\)。
- 參數:
key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。
p (RealArray) – 選填,隨機變數平均值的浮點數或浮點數陣列。必須與
shape
進行廣播相容。預設值為 0.5。shape (Shape | None | None) – 選填,代表結果形狀的非負整數元組。必須與
p.shape
進行廣播相容。預設值 (None) 產生的結果形狀與p.shape
相同。
- 返回:
如果
shape
不是 None,則為具有布林 dtype 和由shape
給定形狀的隨機陣列,否則為p.shape
。- 返回類型: