jax.lax.with_sharding_constraint#
- jax.lax.with_sharding_constraint(x, shardings)[原始碼]#
在 jitted 計算中約束 Array 分片的機制
對於 GSPMD 分割器而言,這是嚴格的約束,而不是提示。如需如何使用此函數的範例,請參閱分散式陣列和自動平行化。
- 參數:
x – 將約束其分片的 jax.Arrays PyTree
shardings – 分片規格的 PyTree。有效值與
jax.experimental.pjit()
的in_shardings
引數相同。
- 傳回:
具有指定分片約束的 jax.Arrays PyTree。
- 傳回類型:
x_with_shardings