jax.lax.conv_with_general_padding#
- jax.lax.conv_with_general_padding(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, precision=None, preferred_element_type=None)[原始碼]#
圍繞 conv_general_dilated 的便利包裝函式。
- 參數:
lhs (Array) – 一個秩為 n+2 維度的輸入陣列。
rhs (Array) – 一個秩為 n+2 維度的核心權重陣列。
window_strides (Sequence[int]) – 一個包含 n 個整數的序列,表示視窗間的步幅。
padding (str | Sequence[tuple[int, int]]) – 字串 ‘SAME’、字串 ‘VALID’,或一個包含 n 個 (low, high) 整數對的序列,指定在每個空間維度之前和之後套用的填充。
lhs_dilation (Sequence[int] | None) – None,或一個包含 n 個整數的序列,指定要在 lhs 的每個空間維度中套用的擴張因子。 LHS 擴張也稱為轉置卷積。
rhs_dilation (Sequence[int] | None) – None,或一個包含 n 個整數的序列,指定要在 rhs 的每個空間維度中套用的擴張因子。 RHS 擴張也稱為 atrous 卷積。
precision (lax.PrecisionLike | None) – 選擇性。可以是
None
,表示後端的預設精度,Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
),或一組兩個Precision
列舉,表示lhs`
和rhs
的精度。preferred_element_type (DTypeLike | None | None) – 選擇性。可以是
None
,表示輸入類型的預設累積類型,或一個資料類型,表示將結果累積到該資料類型並傳回該資料類型的結果。
- 傳回值:
包含卷積結果的陣列。
- 傳回類型: