jax.lax.pad#

jax.lax.pad(operand, padding_value, padding_config)[原始碼]#

將低、高和/或內部填充應用於陣列。

包裝 XLA 的 Pad 運算子。

參數:
  • operand (ArrayLike) – 要填充的陣列。

  • padding_value (ArrayLike) – 要作為填充插入的值。必須與 operand 具有相同的 dtype。

  • padding_config (Sequence[tuple[int, int, int]]) – 一系列 (low, high, interior) 整數元組,指定要在每個維度中插入的低、高和內部(擴張)填充量。

返回:

根據 padding_config,在每個維度中插入填充值 padding_valueoperand 陣列。

返回類型:

Array

範例

>>> from jax import lax
>>> import jax.numpy as jnp

用零填充 1 維陣列。我們將指定前面兩個零,後面三個零

>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)

內部零填充 1 維陣列;即在每個值之間插入單個零

>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)

用值 -1 在前面和結尾填充 2 維陣列,每個維度的填充大小為 2

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)