jax.numpy.atleast_2d#

jax.numpy.atleast_2d(*arys)[原始碼]#

將輸入轉換為至少 2 維的陣列。

JAX 實作的 numpy.atleast_2d()

參數:
  • arguments. (零個多個類陣列)

  • arys (ArrayLike)

返回:

對應於輸入值的陣列或陣列列表。形狀為 () 的陣列會轉換為形狀 (1, 1),形狀為 (N,) 的 1D 陣列會轉換為形狀 (1, N),而所有其他形狀的陣列則保持不變。

返回類型:

Array | list[Array]

範例

純量引數會轉換為 2D、大小為 1 的陣列

>>> x = jnp.float32(1.0)
>>> jnp.atleast_2d(x)
Array([[1.]], dtype=float32)

一維引數會在形狀前面加上一個單位維度

>>> y = jnp.arange(4)
>>> jnp.atleast_2d(y)
Array([[0, 1, 2, 3]], dtype=int32)

更高維度的輸入會保持不變傳回

>>> z = jnp.ones((2, 3))
>>> jnp.atleast_2d(z)
Array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

可以一次將多個引數傳遞給函式,在這種情況下,會傳回結果列表

>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]