jax.numpy.nanargmax#

jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)[原始碼]#

傳回陣列最大值的索引,忽略 NaN 值。

numpy.nanargmax() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • axis (int | None | None) – 可選整數,指定要沿其尋找最大值的軸。如果未指定 axis,則會將 a 平坦化。

  • out (None | None) – JAX 未使用

  • keepdims (bool | None | None) – 如果為 True,則傳回與 a 具有相同維度的陣列。

傳回值:

一個陣列,包含沿指定軸的最大值的索引。

傳回類型:

Array

注意

在軸包含全 NaN 值的情況下,傳回的索引將為 -1。這與 numpy.nanargmax() 的行為不同,後者會引發錯誤。

另請參閱

範例

>>> x = jnp.array([1, 3, 5, 4, jnp.nan])

使用標準 argmax() 可能會導致意外的結果

>>> jnp.argmax(x)
Array(4, dtype=int32)

使用 nanargmax 傳回最大非 NaN 值的索引。

>>> jnp.nanargmax(x)
Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan],
...                [5, 4, jnp.nan]])
>>> jnp.nanargmax(x, axis=1)
Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True)
Array([[1],
       [0]], dtype=int32)