jax.numpy.compress#

jax.numpy.compress(condition, a, axis=None, *, size=None, fill_value=0, out=None)[原始碼]#

使用布林條件壓縮沿著給定軸的陣列。

numpy.compress() 的 JAX 實作。

參數:
  • condition (ArrayLike) – 條件的 1 維陣列。將會轉換為布林值。

  • a (ArrayLike) – 值的 N 維陣列。

  • axis (int | None | None) – 要沿著壓縮的軸。如果為 None (預設值),則 a 將會被展平,且軸將設定為 0。

  • size (int | None | None) – 輸出的可選靜態大小。為了讓 compress 與 JAX 轉換 (如 jit()vmap()) 相容,必須指定此參數。

  • fill_value (ArrayLike) – 如果指定 size,則使用此值填滿填充的條目 (預設值:0)。

  • out (None | None) – JAX 尚未實作。

傳回:

沿著指定軸壓縮的 a.ndim 維陣列。

傳回類型:

Array

參見

筆記

此函式不要求 conditiona 之間嚴格的形狀一致性。如果 condition.size > a.shape[axis],則 condition 將被截斷;如果 a.shape[axis] > condition.size,則 a 將被截斷。

範例

沿著 2 維陣列的列壓縮

>>> a = jnp.array([[1,  2,  3,  4],
...                [5,  6,  7,  8],
...                [9,  10, 11, 12]])
>>> condition = jnp.array([True, False, True])
>>> jnp.compress(condition, a, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

為了方便起見,您可以等效地使用 JAX 陣列的 compress() 方法

>>> a.compress(condition, axis=0)
Array([[ 1,  2,  3,  4],
       [ 9, 10, 11, 12]], dtype=int32)

請注意,條件不需要符合指定軸的形狀;這裡我們使用長度為 3 的條件壓縮列。超出條件大小的值將被忽略

>>> jnp.compress(condition, a, axis=1)
Array([[ 1,  3],
       [ 5,  7],
       [ 9, 11]], dtype=int32)

可選的 size 引數讓您指定靜態輸出大小,以便輸出是靜態形狀的,因此此函式可以與 jit()vmap() 等轉換一起使用

>>> f = lambda c, a: jnp.extract(c, a, size=len(a), fill_value=0)
>>> mask = (a % 3 == 0)
>>> jax.vmap(f)(mask, a)
Array([[ 3,  0,  0,  0],
       [ 6,  0,  0,  0],
       [ 9, 12,  0,  0]], dtype=int32)