jax.lax.sort_key_val#

jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)[原始碼]#

沿 dimension 排序 keys 並將相同排列應用於 values

參數:
  • keys (Array)

  • values (ArrayLike)

  • dimension (int)

  • is_stable (bool)

返回類型:

tuple[Array, Array]