jax.numpy.intersect1d#
- jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[原始碼]#
計算兩個 1D 陣列的集合交集。
JAX 實作的
numpy.intersect1d()
。由於
intersect1d
的輸出大小取決於資料,因此此函式通常與jit()
和其他 JAX 轉換不相容。JAX 版本新增了可選的size
引數,必須靜態指定該引數,才能在這些情況下使用jnp.intersect1d
。- 參數:
ar1 (ArrayLike) – 要進行交集的第一個值陣列。
ar2 (ArrayLike) – 要進行交集的第二個值陣列。
assume_unique (bool) – 如果為 True,則假設輸入陣列包含唯一值。這允許更有效率的實作,但如果
assume_unique
為 True 且輸入陣列包含重複值,則行為未定義。預設值:False。return_indices (bool) – 如果為 True,則傳回索引陣列,指定交集值首次出現在輸入陣列中的位置。
size (int | None | None) – 如果指定,則僅傳回前
size
個排序元素。如果元素數量少於size
指示的數量,則傳回值將以fill_value
填補,且傳回的索引將以超出範圍的索引填補。fill_value (ArrayLike | None | None) – 當指定
size
且元素數量少於指示的數量時,以fill_value
填補剩餘的條目。fill_value
預設為交集中的最小值。
- 傳回:
陣列
intersection
,或者如果return_indices=True
,則為陣列元組(intersection, ar1_indices, ar2_indices)
。傳回值為intersection
:一個 1D 陣列,包含同時出現在ar1
和ar2
中的每個值。ar1_indices
:(如果 return_indices=True 則傳回) 形狀為intersection.shape
的陣列,包含扁平化ar1
中intersection
值的索引。對於 1D 輸入,intersection
等效於ar1[ar1_indices]
。ar2_indices
:(如果 return_indices=True 則傳回) 形狀為intersection.shape
的陣列,包含扁平化ar2
中intersection
值的索引。對於 1D 輸入,intersection
等效於ar2[ar2_indices]
。
- 傳回類型:
另請參閱
jax.numpy.union1d()
:兩個 1D 陣列的集合聯集。jax.numpy.setxor1d()
:兩個 1D 陣列的集合 XOR。jax.numpy.setdiff1d()
:兩個 1D 陣列的集合差集。
範例
>>> ar1 = jnp.array([1, 2, 3, 4]) >>> ar2 = jnp.array([3, 4, 5, 6]) >>> jnp.intersect1d(ar1, ar2) Array([3, 4], dtype=int32)
計算帶有索引的交集
>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True) >>> intersection Array([3, 4], dtype=int32)
ar1_indices
提供ar1
中交集值的索引>>> ar1_indices Array([2, 3], dtype=int32) >>> jnp.all(intersection == ar1[ar1_indices]) Array(True, dtype=bool)
ar2_indices
提供ar2
中交集值的索引>>> ar2_indices Array([0, 1], dtype=int32) >>> jnp.all(intersection == ar2[ar2_indices]) Array(True, dtype=bool)