jax.numpy.delete#
- jax.numpy.delete(arr, obj, axis=None, *, assume_unique_indices=False)[來源]#
從陣列中刪除條目或多個條目。
numpy.delete()
的 JAX 實作。- 參數:
- 回傳:
arr
的副本,其中已刪除指定的索引。- 回傳型別:
注意
delete()
通常要求索引規格為靜態。如果索引是保證包含唯一條目的整數陣列,您可以指定assume_unique_indices=True
,以執行不需要靜態索引的操作。另請參閱
jax.numpy.insert()
:將條目插入陣列。
範例
從一維陣列中刪除條目
>>> a = jnp.array([4, 5, 6, 7, 8, 9]) >>> jnp.delete(a, 2) Array([4, 5, 7, 8, 9], dtype=int32) >>> jnp.delete(a, slice(1, 4)) # delete a[1:4] Array([4, 8, 9], dtype=int32) >>> jnp.delete(a, slice(None, None, 2)) # delete a[::2] Array([5, 7, 9], dtype=int32)
從二維陣列中沿指定軸刪除條目
>>> a2 = jnp.array([[4, 5, 6], ... [7, 8, 9]]) >>> jnp.delete(a2, 1, axis=1) Array([[4, 6], [7, 9]], dtype=int32)
透過索引序列刪除多個條目
>>> indices = jnp.array([0, 1, 3]) >>> jnp.delete(a, indices) Array([6, 8, 9], dtype=int32)
這在
jit()
和其他轉換下會失敗,因為在可能出現重複索引的情況下,無法得知輸出形狀>>> jax.jit(jnp.delete)(a, indices) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3].
如果您可以確保索引是唯一的,請傳遞
assume_unique_indices
以允許在 JIT 下執行此操作>>> jit_delete = jax.jit(jnp.delete, static_argnames=['assume_unique_indices']) >>> jit_delete(a, indices, assume_unique_indices=True) Array([6, 8, 9], dtype=int32)