jax.numpy.expand_dims#
- jax.numpy.expand_dims(a, axis)[原始碼]#
在陣列中插入長度為 1 的維度
numpy.expand_dims()
的 JAX 實作,透過jax.lax.expand_dims()
實作。- 參數:
- 傳回:
帶有新增維度的
a
副本。- 傳回型別:
筆記
與
numpy.expand_dims()
不同,jax.numpy.expand_dims()
將傳回輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器會盡可能最佳化掉這些副本,因此實際上不會對效能造成影響。另請參閱
jax.numpy.squeeze()
:此操作的反向,即移除長度為 1 的維度。jax.lax.expand_dims()
:此功能的 XLA 版本。
範例
>>> x = jnp.array([1, 2, 3]) >>> x.shape (3,)
展開前導維度
>>> jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> _.shape (1, 3)
展開尾隨維度
>>> jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> _.shape (3, 1)
展開多個維度
>>> jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32) >>> _.shape (1, 1, 3, 1)
維度也可以透過使用
None
索引更簡潔地展開>>> x[None] # equivalent to jnp.expand_dims(x, 0) Array([[1, 2, 3]], dtype=int32) >>> x[:, None] # equivalent to jnp.expand_dims(x, 1) Array([[1], [2], [3]], dtype=int32) >>> x[None, None, :, None] # equivalent to jnp.expand_dims(x, (0, 1, 3)) Array([[[[1], [2], [3]]]], dtype=int32)