jax.numpy.broadcast_to#
- jax.numpy.broadcast_to(array, shape)[原始碼]#
將陣列廣播到指定的形狀。
JAX 實作的
numpy.broadcast_to()
。 JAX 使用 NumPy 風格的廣播規則,您可以在 NumPy 廣播中閱讀更多相關資訊。- 參數:
array (ArrayLike) – 要廣播的陣列。
shape (DimSize | Shape) – 陣列將廣播到的形狀。
- 傳回:
廣播到指定形狀的陣列副本。
- 傳回類型:
參見
jax.numpy.broadcast_arrays()
:將陣列廣播到共同的形狀。jax.numpy.broadcast_shapes()
:將輸入形狀廣播到共同的形狀。
範例
>>> x = jnp.int32(1) >>> jnp.broadcast_to(x, (1, 4)) Array([[1, 1, 1, 1]], dtype=int32)
>>> x = jnp.array([1, 2, 3]) >>> jnp.broadcast_to(x, (2, 3)) Array([[1, 2, 3], [1, 2, 3]], dtype=int32)
>>> x = jnp.array([[2], [4]]) >>> jnp.broadcast_to(x, (2, 4)) Array([[2, 2, 2, 2], [4, 4, 4, 4]], dtype=int32)