jax.numpy.broadcast_shapes#

jax.numpy.broadcast_shapes(*shapes)[原始碼]#

將輸入形狀廣播到共同的輸出形狀。

JAX 實作的 numpy.broadcast_shapes()。JAX 使用 NumPy 風格的廣播規則,您可以在 NumPy 廣播 閱讀更多相關資訊。

參數:

shapes – 0 個或多個形狀,指定為整數序列

返回:

廣播後的形狀,以整數元組表示。

另請參閱

範例

一些相容的形狀

>>> jnp.broadcast_shapes((1,), (4,))
(4,)
>>> jnp.broadcast_shapes((3, 1), (4,))
(3, 4)
>>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1))
(5, 3, 4)

不相容的形狀

>>> jnp.broadcast_shapes((3, 1), (4, 1))  
Traceback (most recent call last):
ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]