jax.numpy.intersect1d#

jax.numpy.intersect1d(ar1, ar2, assume_unique=False, return_indices=False, *, size=None, fill_value=None)[原始碼]#

計算兩個 1D 陣列的集合交集。

JAX 實作的 numpy.intersect1d()

由於 intersect1d 的輸出大小取決於資料,因此此函式通常與 jit() 和其他 JAX 轉換不相容。JAX 版本新增了可選的 size 引數,必須靜態指定該引數,才能在這些情況下使用 jnp.intersect1d

參數:
  • ar1 (ArrayLike) – 要進行交集的第一個值陣列。

  • ar2 (ArrayLike) – 要進行交集的第二個值陣列。

  • assume_unique (bool) – 如果為 True,則假設輸入陣列包含唯一值。這允許更有效率的實作,但如果 assume_unique 為 True 且輸入陣列包含重複值,則行為未定義。預設值:False。

  • return_indices (bool) – 如果為 True,則傳回索引陣列,指定交集值首次出現在輸入陣列中的位置。

  • size (int | None | None) – 如果指定,則僅傳回前 size 個排序元素。如果元素數量少於 size 指示的數量,則傳回值將以 fill_value 填補,且傳回的索引將以超出範圍的索引填補。

  • fill_value (ArrayLike | None | None) – 當指定 size 且元素數量少於指示的數量時,以 fill_value 填補剩餘的條目。fill_value 預設為交集中的最小值。

傳回:

陣列 intersection,或者如果 return_indices=True,則為陣列元組 (intersection, ar1_indices, ar2_indices)。傳回值為

  • intersection:一個 1D 陣列,包含同時出現在 ar1ar2 中的每個值。

  • ar1_indices(如果 return_indices=True 則傳回) 形狀為 intersection.shape 的陣列,包含扁平化 ar1intersection 值的索引。對於 1D 輸入,intersection 等效於 ar1[ar1_indices]

  • ar2_indices(如果 return_indices=True 則傳回) 形狀為 intersection.shape 的陣列,包含扁平化 ar2intersection 值的索引。對於 1D 輸入,intersection 等效於 ar2[ar2_indices]

傳回類型:

Array | tuple[Array, Array, Array]

另請參閱

範例

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2)
Array([3, 4], dtype=int32)

計算帶有索引的交集

>>> intersection, ar1_indices, ar2_indices = jnp.intersect1d(ar1, ar2, return_indices=True)
>>> intersection
Array([3, 4], dtype=int32)

ar1_indices 提供 ar1 中交集值的索引

>>> ar1_indices
Array([2, 3], dtype=int32)
>>> jnp.all(intersection == ar1[ar1_indices])
Array(True, dtype=bool)

ar2_indices 提供 ar2 中交集值的索引

>>> ar2_indices
Array([0, 1], dtype=int32)
>>> jnp.all(intersection == ar2[ar2_indices])
Array(True, dtype=bool)