jax.lax.top_k#
- jax.lax.top_k(operand, k)[原始碼]#
傳回
operand
最後一個軸的前k
個值及其索引。- 參數:
operand (ArrayLike) – 非複數類型的 N 維陣列。
k (int) – 整數,指定頂部條目的數量。
- 傳回:
一個元組
(values, indices)
,其中values
是一個陣列,包含沿最後一個軸的前 k 個值。indices
是一個陣列,包含對應於值的索引。
- 傳回類型:
範例
在陣列中找到最大的三個值及其索引
>>> x = jnp.array([9., 3., 6., 4., 10.]) >>> values, indices = jax.lax.top_k(x, 3) >>> values Array([10., 9., 6.], dtype=float32) >>> indices Array([4, 0, 2], dtype=int32)