jax.numpy.broadcast_arrays#
- jax.numpy.broadcast_arrays(*args)[原始碼]#
將陣列廣播到共同形狀。
JAX 實作的
numpy.broadcast_arrays()
。 JAX 使用 NumPy 風格的廣播規則,您可以在 NumPy 廣播 閱讀更多相關資訊。參見
jax.numpy.broadcast_shapes()
:將輸入形狀廣播到共同形狀。jax.numpy.broadcast_to()
:將陣列廣播到指定的形狀。
範例
>>> x = jnp.arange(3) >>> y = jnp.int32(1) >>> jnp.broadcast_arrays(x, y) [Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]]) >>> y = jnp.array([[10], ... [20]]) >>> x2, y2 = jnp.broadcast_arrays(x, y) >>> x2 Array([[1, 2, 3], [1, 2, 3]], dtype=int32) >>> y2 Array([[10, 10, 10], [20, 20, 20]], dtype=int32)