jax.lax.conv#
- jax.lax.conv(lhs, rhs, window_strides, padding, precision=None, preferred_element_type=None)[原始碼]#
圍繞 conv_general_dilated 的便利包裝器。
- 參數:
lhs (Array) – 一個秩為 n+2 維度的輸入陣列。
rhs (Array) – 一個秩為 n+2 維度的核心權重陣列。
window_strides (Sequence[int]) – 一個 n 個整數的序列,表示視窗間的步幅。
padding (str) – 字串 ‘SAME’ 或字串 ‘VALID’。
precision (lax.PrecisionLike | None) – 選擇性。可以是
None
,表示後端的預設精度;Precision
列舉值 (Precision.DEFAULT
、Precision.HIGH
或Precision.HIGHEST
);或是兩個Precision
列舉值的元組,表示lhs`
和rhs
的精度。preferred_element_type (DTypeLike | None | None) – 選擇性。可以是
None
,表示輸入類型的預設累積類型;或是資料類型,表示將結果累積至該資料類型並傳回該資料類型的結果。
- 傳回值:
包含卷積結果的陣列。
- 傳回類型: