jax.numpy.delete#

jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)[來源]#

從陣列中刪除條目或多個條目。

numpy.delete() 的 JAX 實作。

參數:
  • arr (ArrayLike) – 將從中刪除條目的陣列。

  • obj (ArrayLike | slice) – 要刪除的索引、多個索引或切片。

  • axis (int | None | None) – 將沿其刪除條目的軸。

  • assume_unique_indices (bool) – 在使用類似陣列的整數 (而非布林值) 索引的情況下,假設索引是唯一的,並以與 JIT 和其他 JAX 轉換相容的方式執行刪除。

回傳:

arr 的副本,其中已刪除指定的索引。

回傳型別:

Array

注意

delete() 通常要求索引規格為靜態。如果索引是保證包含唯一條目的整數陣列,您可以指定 assume_unique_indices=True,以執行不需要靜態索引的操作。

另請參閱

範例

從一維陣列中刪除條目

>>> a = jnp.array([4, 5, 6, 7, 8, 9])
>>> jnp.delete(a, 2)
Array([4, 5, 7, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(1, 4))  # delete a[1:4]
Array([4, 8, 9], dtype=int32)
>>> jnp.delete(a, slice(None, None, 2))  # delete a[::2]
Array([5, 7, 9], dtype=int32)

從二維陣列中沿指定軸刪除條目

>>> a2 = jnp.array([[4, 5, 6],
...                 [7, 8, 9]])
>>> jnp.delete(a2, 1, axis=1)
Array([[4, 6],
       [7, 9]], dtype=int32)

透過索引序列刪除多個條目

>>> indices = jnp.array([0, 1, 3])
>>> jnp.delete(a, indices)
Array([6, 8, 9], dtype=int32)

這在 jit() 和其他轉換下會失敗,因為在可能出現重複索引的情況下,無法得知輸出形狀

>>> jax.jit(jnp.delete)(a, indices)  
Traceback (most recent call last):
  ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].

如果您可以確保索引是唯一的,請傳遞 assume_unique_indices 以允許在 JIT 下執行此操作

>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices'])
>>> jit_delete(a, indices, assume_unique_indices=True)
Array([6, 8, 9], dtype=int32)