jax.numpy.broadcast_shapes#
- jax.numpy.broadcast_shapes(*shapes)[原始碼]#
將輸入形狀廣播到共同的輸出形狀。
JAX 實作的
numpy.broadcast_shapes()
。JAX 使用 NumPy 風格的廣播規則,您可以在 NumPy 廣播 閱讀更多相關資訊。- 參數:
shapes – 0 個或多個形狀,指定為整數序列
- 返回:
廣播後的形狀,以整數元組表示。
另請參閱
jax.numpy.broadcast_arrays()
:將陣列廣播到共同的形狀。jax.numpy.broadcast_to()
:將陣列廣播到指定的形狀。
範例
一些相容的形狀
>>> jnp.broadcast_shapes((1,), (4,)) (4,) >>> jnp.broadcast_shapes((3, 1), (4,)) (3, 4) >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) (5, 3, 4)
不相容的形狀
>>> jnp.broadcast_shapes((3, 1), (4, 1)) Traceback (most recent call last): ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)]