jax.lax.argmax#

jax.lax.argmax(operand, axis, index_dtype)[原始碼]#

計算沿著 `axis` 的最大元素的索引。

參數:
  • operand (ArrayLike)

  • axis (int)

  • index_dtype (DTypeLike)

回傳型別:

Array