jax.debug.inspect_array_sharding#
- jax.debug.inspect_array_sharding(value, *, callback)[原始碼]#
啟用在 JIT 編譯函式內檢查陣列分片。
此函式在提供陣列的 Pytree 時,會使用它們的每個分片進行回呼,並在
pjit
編譯的計算中運作,從而能夠檢查所選的中間分片。當分片資訊可用時,呼叫
callback
的策略是盡可能提早。這表示如果呼叫inspect_array_callback
時沒有任何轉換,則回呼會立即發生,因為我們已準備好陣列及其分片。在jax.jit
內部,回呼將在降低時間發生,這表示您可以使用 AOT API (jit(f).lower(...)
) 觸發回呼。在pjit
內部時,回呼會在編譯時間發生,因為分片是由 XLA 決定的。您可以使用 JAX 的 AOT API (pjit(f).lower(...).compile()
) 觸發回呼。在所有情況下,都將透過執行函式來觸發回呼,因為執行函式需要先降低和編譯它。但是,一旦函式被編譯並快取,回呼將不再發生。此函式為實驗性質,其行為在未來可能會變更。
- 參數:
value – JAX 陣列的 Pytree。
callback (Callable[[Sharding], None]) – 一個可呼叫物件,它接受
Sharding
並不傳回值。
在以下範例中,我們印出
pjit
編譯的計算中中間值的分片>>> import jax >>> import jax.numpy as jnp >>> from jax.experimental.pjit import pjit >>> from jax.sharding import Mesh, PartitionSpec >>> >>> x = jnp.arange(8, dtype=jnp.float32) >>> def f_(x): ... x = jnp.sin(x) ... jax.debug.inspect_array_sharding(x, callback=print) ... return jnp.square(x) >>> f = pjit(f_, in_shardings=PartitionSpec('dev'), ... out_shardings=PartitionSpec('dev')) >>> with Mesh(jax.devices(), ('dev',)): ... f.lower(x).compile() ... NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))