jax.numpy.resize#

jax.numpy.resize(a, new_shape)[原始碼]#

傳回具有指定形狀的新陣列。

numpy.resize() 的 JAX 實作。

參數:
  • a (ArrayLike) – 輸入陣列或純量。

  • new_shape (Shape) – 整數或整數元組。指定調整大小後陣列的形狀。

傳回:

具有指定形狀的調整大小後陣列。如果調整大小後的陣列大於原始陣列,則會在調整大小後的陣列中重複 a 的元素。

傳回類型:

陣列

參見

範例

>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> jnp.resize(x, (3, 3))
Array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)
>>> jnp.resize(x, (3, 4))
Array([[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 1, 2, 3]], dtype=int32)
>>> jnp.resize(4, (3, 2))
Array([[4, 4],
       [4, 4],
       [4, 4]], dtype=int32, weak_type=True)