jax.lax.conv#

jax.lax.conv(lhs, rhs, window_strides, padding, precision=None, preferred_element_type=None)[原始碼]#

圍繞 conv_general_dilated 的便利包裝器。

參數:
  • lhs (Array) – 一個秩為 n+2 維度的輸入陣列。

  • rhs (Array) – 一個秩為 n+2 維度的核心權重陣列。

  • window_strides (Sequence[int]) – 一個 n 個整數的序列,表示視窗間的步幅。

  • padding (str) – 字串 ‘SAME’ 或字串 ‘VALID’

  • precision (lax.PrecisionLike | None) – 選擇性。可以是 None,表示後端的預設精度;Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST);或是兩個 Precision 列舉值的元組,表示 lhs`rhs 的精度。

  • preferred_element_type (DTypeLike | None | None) – 選擇性。可以是 None,表示輸入類型的預設累積類型;或是資料類型,表示將結果累積至該資料類型並傳回該資料類型的結果。

傳回值:

包含卷積結果的陣列。

傳回類型:

Array