jax.numpy.squeeze#

jax.numpy.squeeze(a, axis=None)[原始碼]#

從陣列中移除一個或多個長度為 1 的軸

JAX 實作的 numpy.sqeeze(),透過 jax.lax.squeeze() 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • axis (int | Sequence[int] | None | None) – 指定要移除軸的整數或整數序列。如果任何指定的軸長度不是 1,則會引發錯誤。如果未指定,則擠壓 a 中所有長度為 1 的軸。

傳回:

a 的副本,已移除長度為 1 的軸。

傳回型別:

Array

注意事項

numpy.squeeze() 不同,jax.numpy.squeeze() 將傳回副本,而不是輸入陣列的視圖。但是,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會對效能產生影響。

參見

範例

>>> x = jnp.array([[[0]], [[1]], [[2]]])
>>> x.shape
(3, 1, 1)

擠壓所有長度為 1 的維度

>>> jnp.squeeze(x)
Array([0, 1, 2], dtype=int32)
>>> _.shape
(3,)

等效於明確指定軸

>>> jnp.squeeze(x, axis=(1, 2))
Array([0, 1, 2], dtype=int32)

嘗試擠壓非單位軸會導致錯誤

>>> jnp.squeeze(x, axis=0)  
Traceback (most recent call last):
  ...
ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)

為了方便起見,此功能也可透過 jax.Array.squeeze() 方法使用

>>> x.squeeze()
Array([0, 1, 2], dtype=int32)