jax.lax.split#

jax.lax.split(operand, sizes, axis=0)[source]#

沿著 axis 分割陣列。

參數:
  • operand (ArrayLike) – 要分割的陣列

  • sizes (Sequence[int]) – 分割陣列的大小。大小總和必須等於 operandaxis 維度的大小。

  • axis (int) – 沿著此軸分割陣列。

回傳:

一個 len(sizes) 陣列的序列。如果 sizes[s1, s2, ...],此函數會沿著 axis 回傳大小為 s1s2 的區塊。

回傳類型:

Sequence[Array]