jax.numpy.pad#

jax.numpy.pad(array, pad_width, mode='constant', **kwargs)[原始碼]#

為陣列添加填充。

JAX 實作的 numpy.pad()

參數:
  • array (ArrayLike) – 要填充的陣列。

  • pad_width (PadValueLike[int | Array | np.ndarray]) –

    指定陣列每個維度的填充寬度。可以分別為陣列之前之後指定填充寬度。選項如下

    • int(int,):在每個陣列維度之前之後填充相同數量的數值。

    • (before, after):在每個陣列之前填充 before 個元素,之後填充 after 個元素

    • ((before_1, after_1), (before_2, after_2), ... (before_N, after_N)):為每個陣列維度指定不同的 beforeafter 值。

  • mode (str | Callable[..., Any]) –

    字串或可調用物件。支援的填充模式為

    • 'constant' (預設):以常數值填充,預設為零。

    • 'empty':以空值填充 (即零)

    • 'edge':以陣列的邊緣值填充。

    • 'wrap':通過包裝陣列進行填充。

    • 'linear_ramp':以線性斜坡填充到指定的 end_values

    • 'maximum':以最大值填充。

    • 'mean':以平均值填充。

    • 'median':以中位數值填充。

    • 'minimum':以最小值填充。

    • 'reflect':通過反射填充。

    • 'symmetric':通過對稱反射填充。

    • <callable>:可調用函數。請參閱下面的註釋。

  • constant_values – 針對 mode = 'constant' 參考。指定要填充的常數值。

  • stat_length – 針對 mode in ['maximum', 'mean', 'median', 'minimum'] 參考。整數或元組,指定計算統計量時要使用的邊緣值的數量。

  • end_values – 針對 mode = 'linear_ramp' 參考。指定斜坡填充值要達到的結束值。

  • reflect_type – 針對 mode in ['reflect', 'symmetric'] 參考。指定是否使用偶數或奇數反射。

返回:

array 的填充副本。

返回類型:

Array

注意事項

mode 是可調用物件時,它應具有以下簽名

def pad_func(row: Array, pad_width: tuple[int, int],
             iaxis: int, kwargs: dict) -> Array:
  ...

此處 row 是沿軸 iaxis 填充陣列的 1D 切片,填充值為零。pad_width 是一個元組,指定 (before, after) 填充大小,而 kwargs 是傳遞給 jax.numpy.pad() 函數的任何其他關鍵字引數。

請注意,雖然在 NumPy 中,函數應就地修改 row,但在 JAX 中,函數應返回修改後的 row。在 JAX 中,自訂填充函數將使用 jax.vmap() 轉換映射到填充軸上。

參見

範例

用零填充 1 維陣列

>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)

用指定值填充 1 維陣列

>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)

用平均陣列值填充 1 維陣列

>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)

用反射值填充 1 維陣列

>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)

在每個維度中使用不同填充來填充 2 維陣列

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)

使用自訂填充函數填充 1 維陣列

>>> def custom_pad(row, pad_width, iaxis, kwargs):
...   # row represents a 1D slice of the zero-padded array.
...   before, after = pad_width
...   before_value = kwargs.get('before_value', 0)
...   after_value = kwargs.get('after_value', 0)
...   row = row.at[:before].set(before_value)
...   return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)