jax.numpy.argmax#
- jax.numpy.argmax(a, axis=None, out=None, keepdims=None)[原始碼]#
傳回陣列中最大值的索引。
JAX 實作的
numpy.argmax()
。- 參數:
- 傳回:
一個陣列,其中包含沿指定軸的最大值的索引。
- 傳回類型:
另請參閱
jax.numpy.argmin()
:傳回最小值的索引。jax.numpy.nanargmax()
:計算argmax
,同時忽略 NaN 值。
範例
>>> x = jnp.array([1, 3, 5, 4, 2]) >>> jnp.argmax(x) Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, 2], ... [5, 4, 1]]) >>> jnp.argmax(x, axis=1) Array([1, 0], dtype=int32)
>>> jnp.argmax(x, axis=1, keepdims=True) Array([[1], [0]], dtype=int32)