jax.nn.one_hot#
- jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)[原始碼]#
對給定的索引進行 one-hot 編碼。
輸入
x
中的每個索引都會編碼為長度為num_classes
的零向量,其中索引index
處的元素設定為一>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
範圍 [0, num_classes) 之外的索引將編碼為零
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)