jax.lax.sort_key_val# jax.lax.sort_key_val(keys, values, dimension=-1, is_stable=True)[原始碼]# 沿 dimension 排序 keys 並將相同排列應用於 values。 參數: keys (Array) values (ArrayLike) dimension (int) is_stable (bool) 返回類型: tuple[Array, Array]