jax.lax.split# jax.lax.split(operand, sizes, axis=0)[source]# 沿著 axis 分割陣列。 參數: operand (ArrayLike) – 要分割的陣列 sizes (Sequence[int]) – 分割陣列的大小。大小總和必須等於 operand 的 axis 維度的大小。 axis (int) – 沿著此軸分割陣列。 回傳: 一個 len(sizes) 陣列的序列。如果 sizes 是 [s1, s2, ...],此函數會沿著 axis 回傳大小為 s1、s2 的區塊。 回傳類型: Sequence[Array]