jax.numpy.expand_dims#

jax.numpy.expand_dims(a, axis)[原始碼]#

在陣列中插入長度為 1 的維度

numpy.expand_dims() 的 JAX 實作,透過 jax.lax.expand_dims() 實作。

參數:
  • a (ArrayLike) – 輸入陣列

  • axis (int | Sequence[int]) – 指定要新增軸位置的整數或整數序列。

傳回:

帶有新增維度的 a 副本。

傳回型別:

Array

筆記

numpy.expand_dims() 不同,jax.numpy.expand_dims() 將傳回輸入陣列的副本而不是視圖。然而,在 JIT 下,編譯器會盡可能最佳化掉這些副本,因此實際上不會對效能造成影響。

另請參閱

範例

>>> x = jnp.array([1, 2, 3])
>>> x.shape
(3,)

展開前導維度

>>> jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> _.shape
(1, 3)

展開尾隨維度

>>> jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> _.shape
(3, 1)

展開多個維度

>>> jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)
>>> _.shape
(1, 1, 3, 1)

維度也可以透過使用 None 索引更簡潔地展開

>>> x[None]  # equivalent to jnp.expand_dims(x, 0)
Array([[1, 2, 3]], dtype=int32)
>>> x[:, None]  # equivalent to jnp.expand_dims(x, 1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>> x[None, None, :, None]  # equivalent to jnp.expand_dims(x, (0, 1, 3))
Array([[[[1],
         [2],
         [3]]]], dtype=int32)