jax.numpy.nanargmax#
- jax.numpy.nanargmax(a, axis=None, out=None, keepdims=None)[原始碼]#
傳回陣列最大值的索引,忽略 NaN 值。
numpy.nanargmax()
的 JAX 實作。- 參數:
- 傳回值:
一個陣列,包含沿指定軸的最大值的索引。
- 傳回類型:
注意
在軸包含全 NaN 值的情況下,傳回的索引將為 -1。這與
numpy.nanargmax()
的行為不同,後者會引發錯誤。另請參閱
jax.numpy.argmax()
:傳回最大值的索引。jax.numpy.nanargmin()
:計算argmin
,同時忽略 NaN 值。
範例
>>> x = jnp.array([1, 3, 5, 4, jnp.nan])
使用標準
argmax()
可能會導致意外的結果>>> jnp.argmax(x) Array(4, dtype=int32)
使用
nanargmax
傳回最大非 NaN 值的索引。>>> jnp.nanargmax(x) Array(2, dtype=int32)
>>> x = jnp.array([[1, 3, jnp.nan], ... [5, 4, jnp.nan]]) >>> jnp.nanargmax(x, axis=1) Array([1, 0], dtype=int32)
>>> jnp.nanargmax(x, axis=1, keepdims=True) Array([[1], [0]], dtype=int32)