jax.numpy.squeeze#
- jax.numpy.squeeze(a, axis=None)[原始碼]#
從陣列中移除一個或多個長度為 1 的軸
JAX 實作的
numpy.sqeeze()
,透過jax.lax.squeeze()
實作。- 參數:
- 傳回:
a
的副本,已移除長度為 1 的軸。- 傳回型別:
注意事項
與
numpy.squeeze()
不同,jax.numpy.squeeze()
將傳回副本,而不是輸入陣列的視圖。但是,在 JIT 下,編譯器會在可能的情況下最佳化掉這些副本,因此這在實務上不會對效能產生影響。參見
jax.numpy.expand_dims()
:squeeze
的反向操作:新增長度為 1 的維度。jax.Array.squeeze()
:透過陣列方法的等效功能。jax.lax.squeeze()
:等效的 XLA API。jax.numpy.ravel()
:將陣列展平為 1D 形狀。jax.numpy.reshape()
:一般陣列重塑形狀。
範例
>>> x = jnp.array([[[0]], [[1]], [[2]]]) >>> x.shape (3, 1, 1)
擠壓所有長度為 1 的維度
>>> jnp.squeeze(x) Array([0, 1, 2], dtype=int32) >>> _.shape (3,)
等效於明確指定軸
>>> jnp.squeeze(x, axis=(1, 2)) Array([0, 1, 2], dtype=int32)
嘗試擠壓非單位軸會導致錯誤
>>> jnp.squeeze(x, axis=0) Traceback (most recent call last): ... ValueError: cannot select an axis to squeeze out which has size not equal to one, got shape=(3, 1, 1) and dimensions=(0,)
為了方便起見,此功能也可透過
jax.Array.squeeze()
方法使用>>> x.squeeze() Array([0, 1, 2], dtype=int32)