jax.debug.inspect_array_sharding#

jax.debug.inspect_array_sharding(value, *, callback)[原始碼]#

啟用在 JIT 編譯函式內檢查陣列分片。

此函式在提供陣列的 Pytree 時,會使用它們的每個分片進行回呼,並在 pjit 編譯的計算中運作,從而能夠檢查所選的中間分片。

當分片資訊可用時,呼叫 callback 的策略是盡可能提早。這表示如果呼叫 inspect_array_callback 時沒有任何轉換,則回呼會立即發生,因為我們已準備好陣列及其分片。在 jax.jit 內部,回呼將在降低時間發生,這表示您可以使用 AOT API (jit(f).lower(...)) 觸發回呼。在 pjit 內部時,回呼會在編譯時間發生,因為分片是由 XLA 決定的。您可以使用 JAX 的 AOT API (pjit(f).lower(...).compile()) 觸發回呼。在所有情況下,都將透過執行函式來觸發回呼,因為執行函式需要先降低和編譯它。但是,一旦函式被編譯並快取,回呼將不再發生。

此函式為實驗性質,其行為在未來可能會變更。

參數:
  • value – JAX 陣列的 Pytree。

  • callback (Callable[[Sharding], None]) – 一個可呼叫物件,它接受 Sharding 並不傳回值。

在以下範例中,我們印出 pjit 編譯的計算中中間值的分片

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh, PartitionSpec
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x):
...   x = jnp.sin(x)
...   jax.debug.inspect_array_sharding(x, callback=print)
...   return jnp.square(x)
>>> f = pjit(f_, in_shardings=PartitionSpec('dev'),
...          out_shardings=PartitionSpec('dev'))
>>> with Mesh(jax.devices(), ('dev',)):
...   f.lower(x).compile()  
...
NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))