jax.experimental.pallas.program_id# jax.experimental.pallas.program_id(axis)[source]# 返回沿著網格給定軸的內核執行位置。 例如,在內核執行中使用對應於網格坐標 (1, 2) 的 2D grid,program_id(axis=0) 返回 1,而 program_id(axis=1) 返回 2。 返回的值是一個形狀為 () 和 dtype 為 int32 的陣列。 參數: axis (int) – 要沿著網格計數程式的軸。 返回類型: jax.Array