jax.random.categorical#
- jax.random.categorical(key, logits, axis=-1, shape=None)[原始碼]#
從類別分佈中採樣隨機值。
- 參數:
key (ArrayLike) – 用作隨機金鑰的 PRNG 金鑰。
logits (RealArray) – 要从中採樣的類別分佈的未正規化對數機率,因此 softmax(logits, axis) 給出相應的機率。
axis (int) – logits 屬於相同類別分佈的軸。
shape (Shape | None | None) – 選用,表示結果形狀的非負整數元組。必須與
np.delete(logits.shape, axis)
廣播相容。預設值 (None) 產生的結果形狀等於np.delete(logits.shape, axis)
。
- 返回:
如果
shape
不是 None,則為具有 int dtype 和由shape
給定的形狀的隨機陣列,否則為np.delete(logits.shape, axis)
。- 返回型別: