jax.numpy.argmin#

jax.numpy.argmin(a, axis=None, out=None, keepdims=None)[原始碼]#

傳回陣列中最小值的索引。

JAX 版本的 numpy.argmin() 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • axis (int | None | None) – 可選的整數,指定要沿其尋找最小值的軸。如果未指定 axis,則會將 a 展平。

  • out (None | None) – JAX 未使用

  • keepdims (bool | None | None) – 如果為 True,則傳回與 a 維度數量相同的陣列。

傳回值:

一個陣列,包含沿指定軸的最小值索引。

傳回類型:

Array

另請參閱

範例

>>> x = jnp.array([1, 3, 5, 4, 2])
>>> jnp.argmin(x)
Array(0, dtype=int32)
>>> x = jnp.array([[1, 3, 2],
...                [5, 4, 1]])
>>> jnp.argmin(x, axis=1)
Array([0, 2], dtype=int32)
>>> jnp.argmin(x, axis=1, keepdims=True)
Array([[0],
       [2]], dtype=int32)