jax.numpy.atleast_2d#
- jax.numpy.atleast_2d(*arys)[原始碼]#
將輸入轉換為至少 2 維的陣列。
JAX 實作的
numpy.atleast_2d()
。- 參數:
arguments. (零個或多個類陣列)
arys (ArrayLike)
- 返回:
對應於輸入值的陣列或陣列列表。形狀為
()
的陣列會轉換為形狀(1, 1)
,形狀為(N,)
的 1D 陣列會轉換為形狀(1, N)
,而所有其他形狀的陣列則保持不變。- 返回類型:
範例
純量引數會轉換為 2D、大小為 1 的陣列
>>> x = jnp.float32(1.0) >>> jnp.atleast_2d(x) Array([[1.]], dtype=float32)
一維引數會在形狀前面加上一個單位維度
>>> y = jnp.arange(4) >>> jnp.atleast_2d(y) Array([[0, 1, 2, 3]], dtype=int32)
更高維度的輸入會保持不變傳回
>>> z = jnp.ones((2, 3)) >>> jnp.atleast_2d(z) Array([[1., 1., 1.], [1., 1., 1.]], dtype=float32)
可以一次將多個引數傳遞給函式,在這種情況下,會傳回結果列表
>>> jnp.atleast_2d(x, y) [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]