jax.lax.conv_general_dilated#
- jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None, preferred_element_type=None)[原始碼]#
通用的 n 維度卷積運算子,具有可選的擴張。
包裝 XLA 的 Conv 運算子。
- 參數:
lhs (Array) – 秩為 n+2 維度的輸入陣列。
rhs (Array) – 秩為 n+2 維度的核心權重陣列。
window_strides (Sequence[int]) – n 個整數的序列,表示視窗間的步幅。
padding (str | Sequence[tuple[int, int]]) – 字串 ‘SAME’、‘SAME_LOWER’ 或 ‘VALID’,或 n 個 (low, high) 整數對的序列,給出在每個空間維度之前和之後應用的填充。‘SAME’ 和 ‘SAME_LOWER’ 新增填充以產生與輸入相同大小的輸出。填充在兩側之間平均或幾乎平均分配。如果填充是奇數,則額外的填充會新增到 ‘SAME’ 的末尾和 ‘SAME_LOWER’ 的開頭。
lhs_dilation (Sequence[int] | None | None) – None,或 n 個整數的序列,給出在 lhs 的每個空間維度中應用的擴張因子。 LHS 擴張也稱為轉置卷積。
rhs_dilation (Sequence[int] | None | None) – None,或 n 個整數的序列,給出在 rhs 的每個空間維度中應用的擴張因子。 RHS 擴張也稱為 atrous 卷積。
dimension_numbers (ConvGeneralDilatedDimensionNumbers | None) – None、
ConvDimensionNumbers
物件或 3 元組(lhs_spec, rhs_spec, out_spec)
,其中每個元素都是長度為 n+2 的字串。feature_group_count (int) – 整數,預設值為 1。請參閱 XLA HLO 文件。
batch_group_count (int) – 整數,預設值為 1。請參閱 XLA HLO 文件。
precision (lax.PrecisionLike | None) – 可選。可以是
None
,表示後端的預設精度,Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
),字串 (例如 ‘highest’ 或 ‘fastest’,請參閱jax.default_matmul_precision
內容管理器),或是兩個Precision
列舉或字串的元組,表示lhs
和rhs
的精度。preferred_element_type (DTypeLike | None | None) – 可選。可以是
None
,表示輸入類型的預設累積類型,或資料類型,表示將結果累積到該資料類型並傳回該資料類型的結果。
- 傳回值:
包含卷積結果的陣列。
- 傳回類型:
在
dimension_numbers
的字串情況下,每個字元依位置識別批次維度在
lhs
、rhs
和輸出中使用字元 ‘N’,特徵維度在 lhs 和輸出中使用字元 ‘C’,
輸入和輸出特徵維度在 rhs 中分別使用字元 ‘I’ 和 ‘O’,以及
使用任何不同的字元,在 lhs、rhs 和輸出之間建立空間維度對應關係。以下範例使用 ‘W’ 和 ‘H’。
例如,若要指示與具有兩個空間維度的
conv
函數一致的維度編號,可以使用('NCHW', 'OIHW', 'NCHW')
。 作為另一個範例,若要指示與 TensorFlow Conv2D 運算一致的維度編號,可以使用('NHWC', 'HWIO', 'NHWC')
。 當使用後一種形式的卷積維度規範時,視窗步幅會根據標籤在rhs_spec
字串中出現的順序與空間維度字元標籤相關聯,因此window_strides[0]
會與對應於 rhs_spec 中第一個出現的字元的維度相符,該字元不是'I'
或'O'
。如果
dimension_numbers
為None
,則預設值為('NCHW', 'OIHW', 'NCHW')
(適用於 2D 卷積)。