jax.experimental.pallas.MemoryRef#

class jax.experimental.pallas.MemoryRef(shape, dtype, memory_space)[原始碼]#

類似於 jax.ShapeDtypeStruct,但具有記憶體空間。

參數:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype)

  • memory_space (Any)

__init__(shape, dtype, memory_space)#
參數:
  • shape (tuple[int, ...])

  • dtype (jnp.dtype)

  • memory_space (Any)

回傳類型:

None

方法

__init__(shape, dtype, memory_space)

get_array_aval()

get_ref_aval()

屬性

shape

dtype

memory_space