jax.experimental.pallas.mosaic_gpu.TransposeTransform#

class jax.experimental.pallas.mosaic_gpu.TransposeTransform(permutation)[原始碼]#

轉置平鋪的 memref。

參數:

permutation (tuple[int, ...])

__init__(permutation)#
參數:

permutation (tuple[int, ...])

回傳類型:

None

方法

__init__(permutation)

batch(leading_rank)

傳回一個轉換,該轉換接受具有額外 leading_rank 維度的 ref。

to_gpu_transform()

undo(ref)

屬性

permutation