jax.lax.conv_general_dilated#

jax.lax.conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None, rhs_dilation=None, dimension_numbers=None, feature_group_count=1, batch_group_count=1, precision=None, preferred_element_type=None)[原始碼]#

通用的 n 維度卷積運算子,具有可選的擴張。

包裝 XLA 的 Conv 運算子。

參數:
  • lhs (Array) – 秩為 n+2 維度的輸入陣列。

  • rhs (Array) – 秩為 n+2 維度的核心權重陣列。

  • window_strides (Sequence[int]) – n 個整數的序列,表示視窗間的步幅。

  • padding (str | Sequence[tuple[int, int]]) – 字串 ‘SAME’‘SAME_LOWER’‘VALID’,或 n(low, high) 整數對的序列,給出在每個空間維度之前和之後應用的填充。‘SAME’‘SAME_LOWER’ 新增填充以產生與輸入相同大小的輸出。填充在兩側之間平均或幾乎平均分配。如果填充是奇數,則額外的填充會新增到 ‘SAME’ 的末尾和 ‘SAME_LOWER’ 的開頭。

  • lhs_dilation (Sequence[int] | None | None) – None,或 n 個整數的序列,給出在 lhs 的每個空間維度中應用的擴張因子。 LHS 擴張也稱為轉置卷積。

  • rhs_dilation (Sequence[int] | None | None) – None,或 n 個整數的序列,給出在 rhs 的每個空間維度中應用的擴張因子。 RHS 擴張也稱為 atrous 卷積。

  • dimension_numbers (ConvGeneralDilatedDimensionNumbers | None) – NoneConvDimensionNumbers 物件或 3 元組 (lhs_spec, rhs_spec, out_spec),其中每個元素都是長度為 n+2 的字串。

  • feature_group_count (int) – 整數,預設值為 1。請參閱 XLA HLO 文件。

  • batch_group_count (int) – 整數,預設值為 1。請參閱 XLA HLO 文件。

  • precision (lax.PrecisionLike | None) – 可選。可以是 None,表示後端的預設精度,Precision 列舉值 (Precision.DEFAULTPrecision.HIGHPrecision.HIGHEST),字串 (例如 ‘highest’ 或 ‘fastest’,請參閱 jax.default_matmul_precision 內容管理器),或是兩個 Precision 列舉或字串的元組,表示 lhsrhs 的精度。

  • preferred_element_type (DTypeLike | None | None) – 可選。可以是 None,表示輸入類型的預設累積類型,或資料類型,表示將結果累積到該資料類型並傳回該資料類型的結果。

傳回值:

包含卷積結果的陣列。

傳回類型:

Array

dimension_numbers 的字串情況下,每個字元依位置識別

  • 批次維度在 lhsrhs 和輸出中使用字元 ‘N’,

  • 特徵維度在 lhs 和輸出中使用字元 ‘C’,

  • 輸入和輸出特徵維度在 rhs 中分別使用字元 ‘I’ 和 ‘O’,以及

  • 使用任何不同的字元,在 lhs、rhs 和輸出之間建立空間維度對應關係。以下範例使用 ‘W’ 和 ‘H’。

例如,若要指示與具有兩個空間維度的 conv 函數一致的維度編號,可以使用 ('NCHW', 'OIHW', 'NCHW')。 作為另一個範例,若要指示與 TensorFlow Conv2D 運算一致的維度編號,可以使用 ('NHWC', 'HWIO', 'NHWC')。 當使用後一種形式的卷積維度規範時,視窗步幅會根據標籤在 rhs_spec 字串中出現的順序與空間維度字元標籤相關聯,因此 window_strides[0] 會與對應於 rhs_spec 中第一個出現的字元的維度相符,該字元不是 'I''O'

如果 dimension_numbersNone,則預設值為 ('NCHW', 'OIHW', 'NCHW') (適用於 2D 卷積)。