jax.random.categorical#

jax.random.categorical(key, logits, axis=-1, shape=None)[原始碼]#

從類別分佈中採樣隨機值。

參數:
  • key (ArrayLike) – 用作隨機金鑰的 PRNG 金鑰。

  • logits (RealArray) – 要从中採樣的類別分佈的未正規化對數機率,因此 softmax(logits, axis) 給出相應的機率。

  • axis (int) – logits 屬於相同類別分佈的軸。

  • shape (Shape | None | None) – 選用,表示結果形狀的非負整數元組。必須與 np.delete(logits.shape, axis) 廣播相容。預設值 (None) 產生的結果形狀等於 np.delete(logits.shape, axis)

返回:

如果 shape 不是 None,則為具有 int dtype 和由 shape 給定的形狀的隨機陣列,否則為 np.delete(logits.shape, axis)

返回型別:

Array