jax.numpy.extract#

jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[原始碼]#

傳回滿足條件的陣列元素。

numpy.extract() 的 JAX 實作。

參數:
  • condition (ArrayLike) – 條件陣列。將轉換為布林值並展平為 1D。

  • arr (ArrayLike) – 要提取的值陣列。將展平為 1D。

  • size (int | None | None) – 輸出的選用靜態大小。必須指定,extract 才能與 JAX 轉換(如 jit()vmap())相容。

  • fill_value (ArrayLike) – 如果指定 size,則用此值填補填充的條目(預設值:0)。

傳回:

提取條目的 1D 陣列。如果指定 size,則結果的形狀將為 (size,),並用 fill_value 向右填充。如果未指定 size,則輸出形狀將取決於 condition 中的 True 條目數。

傳回類型:

Array

筆記

此函數不要求 conditionarr 之間有嚴格的形狀一致性。如果 condition.size > arr.size,則 condition 將被截斷;如果 arr.size > condition.size,則 arr 將被截斷。

另請參閱

jax.numpy.compress()extract 的多維版本。

範例

從 1D 陣列中提取值

>>> x = jnp.array([1, 2, 3, 4, 5, 6])
>>> mask = (x % 2 == 0)
>>> jnp.extract(mask, x)
Array([2, 4, 6], dtype=int32)

在最簡單的情況下,這等效於布林索引

>>> x[mask]
Array([2, 4, 6], dtype=int32)

為了與 JAX 轉換一起使用,您可以傳遞 size 引數來為輸出指定靜態形狀,以及選用的 fill_value,預設為零

>>> jnp.extract(mask, x, size=len(x), fill_value=0)
Array([2, 4, 6, 0, 0, 0], dtype=int32)

請注意,與布林索引不同,extract 不要求陣列和條件的大小之間嚴格一致,並且實際上會將兩者截斷為最小大小

>>> short_mask = jnp.array([False, True])
>>> jnp.extract(short_mask, x)
Array([2], dtype=int32)
>>> long_mask = jnp.array([True, False, True, False, False, False, False, False])
>>> jnp.extract(long_mask, x)
Array([1, 3], dtype=int32)