jax.numpy.extract#
- jax.numpy.extract(condition, arr, *, size=None, fill_value=0)[原始碼]#
傳回滿足條件的陣列元素。
numpy.extract()
的 JAX 實作。- 參數:
- 傳回:
提取條目的 1D 陣列。如果指定
size
,則結果的形狀將為(size,)
,並用fill_value
向右填充。如果未指定size
,則輸出形狀將取決於condition
中的 True 條目數。- 傳回類型:
筆記
此函數不要求
condition
和arr
之間有嚴格的形狀一致性。如果condition.size > arr.size
,則condition
將被截斷;如果arr.size > condition.size
,則arr
將被截斷。另請參閱
jax.numpy.compress()
:extract
的多維版本。範例
從 1D 陣列中提取值
>>> x = jnp.array([1, 2, 3, 4, 5, 6]) >>> mask = (x % 2 == 0) >>> jnp.extract(mask, x) Array([2, 4, 6], dtype=int32)
在最簡單的情況下,這等效於布林索引
>>> x[mask] Array([2, 4, 6], dtype=int32)
為了與 JAX 轉換一起使用,您可以傳遞
size
引數來為輸出指定靜態形狀,以及選用的fill_value
,預設為零>>> jnp.extract(mask, x, size=len(x), fill_value=0) Array([2, 4, 6, 0, 0, 0], dtype=int32)
請注意,與布林索引不同,
extract
不要求陣列和條件的大小之間嚴格一致,並且實際上會將兩者截斷為最小大小>>> short_mask = jnp.array([False, True]) >>> jnp.extract(short_mask, x) Array([2], dtype=int32) >>> long_mask = jnp.array([True, False, True, False, False, False, False, False]) >>> jnp.extract(long_mask, x) Array([1, 3], dtype=int32)