jax.numpy.resize#
- jax.numpy.resize(a, new_shape)[原始碼]#
傳回具有指定形狀的新陣列。
numpy.resize()
的 JAX 實作。- 參數:
a (ArrayLike) – 輸入陣列或純量。
new_shape (Shape) – 整數或整數元組。指定調整大小後陣列的形狀。
- 傳回:
具有指定形狀的調整大小後陣列。如果調整大小後的陣列大於原始陣列,則會在調整大小後的陣列中重複
a
的元素。- 傳回類型:
參見
jax.numpy.reshape()
:傳回陣列的重新塑形副本。jax.numpy.repeat()
:從重複元素建構陣列。
範例
>>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) >>> jnp.resize(x, (3, 3)) Array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int32) >>> jnp.resize(x, (3, 4)) Array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 1, 2, 3]], dtype=int32) >>> jnp.resize(4, (3, 2)) Array([[4, 4], [4, 4], [4, 4]], dtype=int32, weak_type=True)