jax.lax.expand_dims#

jax.lax.expand_dims(array, dimensions)[原始碼]#

將任意數量的尺寸為 1 的維度插入陣列中。

參數:
  • array (ArrayLike)

  • dimensions (Sequence[int])

回傳類型:

Array