jax.numpy.ravel_multi_index#
- jax.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C')[原始碼]#
將多維索引轉換為扁平索引。
JAX 實作的
numpy.ravel_multi_index()
- 參數:
multi_index (Sequence[ArrayLike]) – 包含每個維度索引的整數陣列序列。
dims (Sequence[int]) – 整數大小序列;必須有
len(dims) == len(multi_index)
mode (str) –
如何處理超出邊界的索引。選項包括
"raise"
(預設):引發 ValueError。此模式與jit()
或其他 JAX 轉換不相容。"clip"
:將超出邊界的索引裁剪到有效範圍。"wrap"
:將超出邊界的索引包裝到有效範圍。
order (str) –
"C"
(預設) 或"F"
,指定是否假設 C 風格的 row-major 順序或 Fortran 風格的 column-major 順序。
- 返回:
扁平化索引的陣列
- 返回類型:
參見
jax.numpy.unravel_index()
:此函式的反函數。範例
定義一個二維陣列和一個偶數值的索引序列
>>> x = jnp.array([[2., 3., 4.], ... [5., 6., 7.]]) >>> indices = jnp.where(x % 2 == 0) >>> indices (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32)) >>> x[indices] Array([2., 4., 6.], dtype=float32)
計算扁平化索引
>>> indices_flat = jnp.ravel_multi_index(indices, x.shape) >>> indices_flat Array([0, 2, 4], dtype=int32)
這些扁平化索引可以用於從扁平化的
x
陣列中提取相同的值>>> x_flat = x.ravel() >>> x_flat Array([2., 3., 4., 5., 6., 7.], dtype=float32) >>> x_flat[indices_flat] Array([2., 4., 6.], dtype=float32)
原始索引可以使用
unravel_index()
恢復>>> jnp.unravel_index(indices_flat, x.shape) (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))