jax.numpy.isin#

jax.numpy.isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')[原始碼]#

判斷 element 中的元素是否出現在 test_elements 中。

JAX 實作的 numpy.isin()

參數:
  • element (類陣列) – 將檢查成員資格的元素輸入陣列。

  • test_elements (類陣列) – N 維測試值陣列,用於檢查每個元素是否存在。

  • invert (布林值) – 如果為 True,則返回 ~isin(element, test_elements)。預設值為 False。

  • assume_unique (布林值) – 如果為 true,則假設輸入陣列是唯一的,這可以提高計算效率。如果輸入陣列不是唯一的,並且 assume_unique 設定為 True,則結果是未定義的。

  • method – 字串,指定用於計算結果的方法。支援的選項包括 'compare_all'、'binary_search'、'sort' 和 'auto' (預設)。

返回:

形狀為 element.shape 的布林陣列,指定每個元素是否出現在 test_elements 中。

返回類型:

陣列

範例

>>> elements = jnp.array([1, 2, 3, 4])
>>> test_elements = jnp.array([[1, 5, 6, 3, 7, 1]])
>>> jnp.isin(elements, test_elements)
Array([ True, False,  True, False], dtype=bool)