jax.experimental.pallas.num_programs# jax.experimental.pallas.num_programs(axis)[原始碼]# 返回沿給定軸的網格大小。 參數: axis (int) 返回型別: int | jax.Array