jax.numpy.broadcast_arrays#

jax.numpy.broadcast_arrays(*args)[原始碼]#

將陣列廣播到共同形狀。

JAX 實作的 numpy.broadcast_arrays()。 JAX 使用 NumPy 風格的廣播規則,您可以在 NumPy 廣播 閱讀更多相關資訊。

參數:

args (ArrayLike) – 零或多個要廣播的類陣列物件。

返回:

包含輸入廣播副本的陣列列表。

返回類型:

list[Array]

參見

範例

>>> x = jnp.arange(3)
>>> y = jnp.int32(1)
>>> jnp.broadcast_arrays(x, y)
[Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)]
>>> x = jnp.array([[1, 2, 3]])
>>> y = jnp.array([[10],
...                [20]])
>>> x2, y2 = jnp.broadcast_arrays(x, y)
>>> x2
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)
>>> y2
Array([[10, 10, 10],
       [20, 20, 20]], dtype=int32)