jax.numpy.atleast_3d#

jax.numpy.atleast_3d(*arys)[source]#

將輸入轉換為至少 3 維的陣列。

JAX 實作的 numpy.atleast_3d()

參數:
  • arguments. (零個 多個類陣列)

  • arys (類陣列)

回傳:

一個陣列或陣列列表,對應到輸入值。形狀為 () 的陣列會被轉換為形狀 (1, 1, 1),形狀為 (N,) 的一維陣列會被轉換為形狀 (1, N, 1),形狀為 (M, N) 的二維陣列會被轉換為形狀 (M, N, 1),而所有其他形狀的陣列則保持不變。

回傳類型:

Array | list[Array]

範例

純量引數會被轉換為 3D、大小為 1 的陣列

>>> x = jnp.float32(1.0)
>>> jnp.atleast_3d(x)
Array([[[1.]]], dtype=float32)

一維陣列會在前面和後面加上一個單位維度

>>> y = jnp.arange(4)
>>> jnp.atleast_3d(y).shape
(1, 4, 1)

二維陣列會在後面加上一個單位維度

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_3d(z).shape
(2, 3, 1)

可以一次將多個引數傳遞給函式,在這種情況下,會回傳一個結果列表

>>> x3, y3 = jnp.atleast_3d(x, y)
>>> print(x3)
[[[1.]]]
>>> print(y3)
[[[0]
  [1]
  [2]
  [3]]]