jax.lax.top_k#

jax.lax.top_k(operand, k)[原始碼]#

傳回 operand 最後一個軸的前 k 個值及其索引。

參數:
  • operand (ArrayLike) – 非複數類型的 N 維陣列。

  • k (int) – 整數,指定頂部條目的數量。

傳回:

一個元組 (values, indices),其中

  • values 是一個陣列,包含沿最後一個軸的前 k 個值。

  • indices 是一個陣列,包含對應於值的索引。

傳回類型:

tuple[Array, Array]

範例

在陣列中找到最大的三個值及其索引

>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)