jax.numpy.compress#
- jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)[原始碼]#
使用布林條件壓縮沿著給定軸的陣列。
numpy.compress()
的 JAX 實作。- 參數:
condition (ArrayLike) – 條件的 1 維陣列。將會轉換為布林值。
a (ArrayLike) – 值的 N 維陣列。
axis (int | None | None) – 要沿著壓縮的軸。如果為 None (預設值),則
a
將會被展平,且軸將設定為 0。size (int | None | None) – 輸出的可選靜態大小。為了讓
compress
與 JAX 轉換 (如jit()
或vmap()
) 相容,必須指定此參數。fill_value (ArrayLike) – 如果指定
size
,則使用此值填滿填充的條目 (預設值:0)。out (None | None) – JAX 尚未實作。
- 傳回:
沿著指定軸壓縮的
a.ndim
維陣列。- 傳回類型:
參見
jax.numpy.extract()
:compress
的 1 維版本。jax.Array.compress()
:作為陣列方法的等效功能。
筆記
此函式不要求
condition
和a
之間嚴格的形狀一致性。如果condition.size > a.shape[axis]
,則condition
將被截斷;如果a.shape[axis] > condition.size
,則a
將被截斷。範例
沿著 2 維陣列的列壓縮
>>> a = jnp.array([[1, 2, 3, 4], ... [5, 6, 7, 8], ... [9, 10, 11, 12]]) >>> condition = jnp.array([True, False, True]) >>> jnp.compress(condition, a, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
為了方便起見,您可以等效地使用 JAX 陣列的
compress()
方法>>> a.compress(condition, axis=0) Array([[ 1, 2, 3, 4], [ 9, 10, 11, 12]], dtype=int32)
請注意,條件不需要符合指定軸的形狀;這裡我們使用長度為 3 的條件壓縮列。超出條件大小的值將被忽略
>>> jnp.compress(condition, a, axis=1) Array([[ 1, 3], [ 5, 7], [ 9, 11]], dtype=int32)
可選的
size
引數讓您指定靜態輸出大小,以便輸出是靜態形狀的,因此此函式可以與jit()
和vmap()
等轉換一起使用>>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0) >>> mask = (a % 3 == 0) >>> jax.vmap(f)(mask, a) Array([[ 3, 0, 0, 0], [ 6, 0, 0, 0], [ 9, 12, 0, 0]], dtype=int32)