jax.lax.axis_index#
- jax.lax.axis_index(axis_name)[原始碼]#
傳回沿著映射軸
axis_name
的索引。- 參數:
axis_name – 用於命名映射軸的可雜湊 Python 物件。
- 傳回:
代表索引的整數。
例如,在有 8 個 XLA 裝置可用的情況下
>>> from functools import partial >>> @partial(jax.pmap, axis_name='i') ... def f(_): ... return lax.axis_index('i') ... >>> f(np.zeros(4)) Array([0, 1, 2, 3], dtype=int32) >>> f(np.zeros(8)) Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) >>> @partial(jax.pmap, axis_name='i') ... @partial(jax.pmap, axis_name='j') ... def f(_): ... return lax.axis_index('i'), lax.axis_index('j') ... >>> x, y = f(np.zeros((4, 2))) >>> print(x) [[0 0] [1 1] [2 2] [3 3]] >>> print(y) [[0 1] [0 1] [0 1] [0 1]]