jax.numpy.setdiff1d#
- jax.numpy.setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None)[原始碼]#
計算兩個 1D 陣列的集合差集。
JAX 實作的
numpy.setdiff1d()
。由於
setdiff1d
的輸出大小取決於資料,因此此函數通常與jit()
和其他 JAX 轉換不相容。JAX 版本新增了可選的size
引數,必須靜態指定此引數,才能在這些情況下使用jnp.setdiff1d
。- 參數:
ar1 (ArrayLike) – 要計算差集的第一個元素陣列。
ar2 (ArrayLike) – 要計算差集的第二個元素陣列。
assume_unique (bool) – 如果為 True,則假設輸入陣列包含唯一值。這可以實現更有效率的實作,但如果
assume_unique
為 True 且輸入陣列包含重複值,則行為未定義。預設值:False。size (int | None | None) – 如果指定,則僅傳回前
size
個排序元素。如果元素數量少於size
指示的數量,則傳回值將以fill_value
填補。fill_value (ArrayLike | None | None) – 當指定
size
且元素數量少於指示的數量時,以fill_value
填補剩餘的條目。fill_value
預設為最小值。
- 傳回:
即
ar1
中未包含在ar2
中的元素。- 傳回類型:
包含輸入陣列中元素集合差集的陣列
參見
jax.numpy.intersect1d()
:兩個 1D 陣列的集合交集。jax.numpy.setxor1d()
:兩個 1D 陣列的集合 XOR。jax.numpy.union1d()
:兩個 1D 陣列的集合聯集。
範例
計算兩個陣列的集合差集
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.setdiff1d(ar1, ar2) Array([1, 2], dtype=int32)
由於輸出形狀是動態的,因此這將在
jit()
和其他轉換下失敗>>> jax.jit(jnp.setdiff1d)(ar1, ar2) Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.
為了確保靜態已知的輸出形狀,您可以傳遞靜態
size
引數>>> jit_setdiff1d = jax.jit(jnp.setdiff1d, static_argnames=['size']) >>> jit_setdiff1d(ar1, ar2, size=2) Array([1, 2], dtype=int32)
如果
size
太小,則差集會被截斷>>> jit_setdiff1d(ar1, ar2, size=1) Array([1], dtype=int32)
如果
size
太大,則輸出會以fill_value
填補>>> jit_setdiff1d(ar1, ar2, size=4, fill_value=0) Array([1, 2, 0, 0], dtype=int32)