jax.nn.one_hot#

jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[原始碼]#

對給定的索引進行 one-hot 編碼。

輸入 x 中的每個索引都會編碼為長度為 num_classes 的零向量,其中索引 index 處的元素設定為一

>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

範圍 [0, num_classes) 之外的索引將編碼為零

>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
參數:
  • x (Any) – 索引的張量。

  • num_classes (int) – one-hot 維度中的類別數量。

  • dtype (Any) – 選項,傳回值的浮點數 dtype (預設為 jnp.float_)。

  • axis (int | AxisName) – 應計算函式的軸或軸。

傳回型別:

Array