jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef

class jax.experimental.pallas.mosaic_gpu.WGMMAAccumulatorRef(shape: 'tuple[int, int]', dtype: 'jnp.dtype' = <class 'jax.numpy.float32'>, _init: 'Any' = <jax._src.state.types.Uninitialized object>)
Parameters:
  • shape (tuple[int, int])

  • dtype (jnp.dtype)

  • _init (Any)

__init__(shape, dtype=<class 'jax.numpy.float32'>, _init=<jax._src.state.types.Uninitialized object>)
Parameters:
  • shape (tuple[int, int])

  • dtype (jnp.dtype)

  • _init (Any)

Return type:

None

Methods

__init__(shape[, dtype, _init])

get_ref_aval()

init(array)

Attributes

shape