jax.numpy.atleast_3d#
- jax.numpy.atleast_3d(*arys)[source]#
將輸入轉換為至少 3 維的陣列。
JAX 實作的
numpy.atleast_3d()
。- 參數:
arguments. (零個 或 多個類陣列)
arys (類陣列)
- 回傳:
一個陣列或陣列列表,對應到輸入值。形狀為
()
的陣列會被轉換為形狀(1, 1, 1)
,形狀為(N,)
的一維陣列會被轉換為形狀(1, N, 1)
,形狀為(M, N)
的二維陣列會被轉換為形狀(M, N, 1)
,而所有其他形狀的陣列則保持不變。- 回傳類型:
範例
純量引數會被轉換為 3D、大小為 1 的陣列
>>> x = jnp.float32(1.0) >>> jnp.atleast_3d(x) Array([[[1.]]], dtype=float32)
一維陣列會在前面和後面加上一個單位維度
>>> y = jnp.arange(4) >>> jnp.atleast_3d(y).shape (1, 4, 1)
二維陣列會在後面加上一個單位維度
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_3d(z).shape (2, 3, 1)
可以一次將多個引數傳遞給函式,在這種情況下,會回傳一個結果列表
>>> x3, y3 = jnp.atleast_3d(x, y) >>> print(x3) [[[1.]]] >>> print(y3) [[[0] [1] [2] [3]]]