jax.numpy.unravel_index#

jax.numpy.unravel_index(indices, shape)[原始碼]#

將扁平索引轉換為多維索引。

JAX 實作的 numpy.unravel_index()。JAX 版本在處理超出邊界索引方面有所不同:與 NumPy 不同,JAX 支援負索引,且超出邊界索引會被裁剪為最接近的有效值。

參數:
  • indices (ArrayLike) – 扁平索引的整數陣列

  • shape (Shape) – 要索引的多維陣列的形狀

傳回:

解開索引的元組

傳回型別:

tuple[Array, …]

另請參閱

jax.numpy.ravel_multi_index():此函數的反函數。

範例

從一維陣列值和索引開始

>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
>>> indices = jnp.array([1, 3, 5])
>>> print(x[indices])
[3. 5. 7.]

現在,如果 x 被重塑,則可以使用 unravel_indices 將扁平索引轉換為存取相同條目的索引元組

>>> shape = (2, 3)
>>> x_2D = x.reshape(shape)
>>> indices_2D = jnp.unravel_index(indices, shape)
>>> indices_2D
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
>>> print(x_2D[indices_2D])
[3. 5. 7.]

反函數 ravel_multi_index 可用於取得原始索引

>>> jnp.ravel_multi_index(indices_2D, shape)
Array([1, 3, 5], dtype=int32)