jax.lax.with_sharding_constraint#

jax.lax.with_sharding_constraint(x, shardings)[原始碼]#

在 jitted 計算中約束 Array 分片的機制

對於 GSPMD 分割器而言,這是嚴格的約束,而不是提示。如需如何使用此函數的範例,請參閱分散式陣列和自動平行化

參數:
  • x – 將約束其分片的 jax.Arrays PyTree

  • shardings – 分片規格的 PyTree。有效值與 jax.experimental.pjit()in_shardings 引數相同。

傳回:

具有指定分片約束的 jax.Arrays PyTree。

傳回類型:

x_with_shardings