jax.lax.conv_with_general_padding#

jax.lax.conv_with_general_padding(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, precision=None, preferred_element_type=None)[原始碼]#

圍繞 conv_general_dilated 的便利包裝函式。

參數:
  • lhs (Array) – 一個秩為 n+2 維度的輸入陣列。

  • rhs (Array) – 一個秩為 n+2 維度的核心權重陣列。

  • window_strides (Sequence[int]) – 一個包含 n 個整數的序列,表示視窗間的步幅。

  • padding (str | Sequence[tuple[int, int]]) – 字串 ‘SAME’、字串 ‘VALID’,或一個包含 n(low, high) 整數對的序列,指定在每個空間維度之前和之後套用的填充。

  • lhs_dilation (Sequence[int] | None) – None,或一個包含 n 個整數的序列,指定要在 lhs 的每個空間維度中套用的擴張因子。 LHS 擴張也稱為轉置卷積。

  • rhs_dilation (Sequence[int] | None) – None,或一個包含 n 個整數的序列,指定要在 rhs 的每個空間維度中套用的擴張因子。 RHS 擴張也稱為 atrous 卷積。

  • precision (lax.PrecisionLike | None) – 選擇性。可以是 None,表示後端的預設精度,Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST),或一組兩個 Precision 列舉,表示 lhs`rhs 的精度。

  • preferred_element_type (DTypeLike | None | None) – 選擇性。可以是 None,表示輸入類型的預設累積類型,或一個資料類型,表示將結果累積到該資料類型並傳回該資料類型的結果。

傳回值:

包含卷積結果的陣列。

傳回類型:

Array