jax.numpy.ravel_multi_index#

jax.numpy.ravel_multi_index(multi_index, dims, mode='raise', order='C')[原始碼]#

將多維索引轉換為扁平索引。

JAX 實作的 numpy.ravel_multi_index()

參數:
  • multi_index (Sequence[ArrayLike]) – 包含每個維度索引的整數陣列序列。

  • dims (Sequence[int]) – 整數大小序列;必須有 len(dims) == len(multi_index)

  • mode (str) –

    如何處理超出邊界的索引。選項包括

    • "raise" (預設):引發 ValueError。此模式與 jit() 或其他 JAX 轉換不相容。

    • "clip":將超出邊界的索引裁剪到有效範圍。

    • "wrap":將超出邊界的索引包裝到有效範圍。

  • order (str) – "C" (預設) 或 "F",指定是否假設 C 風格的 row-major 順序或 Fortran 風格的 column-major 順序。

返回:

扁平化索引的陣列

返回類型:

Array

參見

jax.numpy.unravel_index():此函式的反函數。

範例

定義一個二維陣列和一個偶數值的索引序列

>>> x = jnp.array([[2., 3., 4.],
...                [5., 6., 7.]])
>>> indices = jnp.where(x % 2 == 0)
>>> indices
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
>>> x[indices]
Array([2., 4., 6.], dtype=float32)

計算扁平化索引

>>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
>>> indices_flat
Array([0, 2, 4], dtype=int32)

這些扁平化索引可以用於從扁平化的 x 陣列中提取相同的值

>>> x_flat = x.ravel()
>>> x_flat
Array([2., 3., 4., 5., 6., 7.], dtype=float32)
>>> x_flat[indices_flat]
Array([2., 4., 6.], dtype=float32)

原始索引可以使用 unravel_index() 恢復

>>> jnp.unravel_index(indices_flat, x.shape)
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))