jax.numpy.pad#
- jax.numpy.pad(array, pad_width, mode='constant', **kwargs)[原始碼]#
為陣列添加填充。
JAX 實作的
numpy.pad()
。- 參數:
array (ArrayLike) – 要填充的陣列。
pad_width (PadValueLike[int | Array | np.ndarray]) –
指定陣列每個維度的填充寬度。可以分別為陣列之前和之後指定填充寬度。選項如下
int
或(int,)
:在每個陣列維度之前和之後填充相同數量的數值。(before, after)
:在每個陣列之前填充before
個元素,之後填充after
個元素((before_1, after_1), (before_2, after_2), ... (before_N, after_N))
:為每個陣列維度指定不同的before
和after
值。
mode (str | Callable[..., Any]) –
字串或可調用物件。支援的填充模式為
'constant'
(預設):以常數值填充,預設為零。'empty'
:以空值填充 (即零)'edge'
:以陣列的邊緣值填充。'wrap'
:通過包裝陣列進行填充。'linear_ramp'
:以線性斜坡填充到指定的end_values
。'maximum'
:以最大值填充。'mean'
:以平均值填充。'median'
:以中位數值填充。'minimum'
:以最小值填充。'reflect'
:通過反射填充。'symmetric'
:通過對稱反射填充。<callable>
:可調用函數。請參閱下面的註釋。
constant_values – 針對
mode = 'constant'
參考。指定要填充的常數值。stat_length – 針對
mode in ['maximum', 'mean', 'median', 'minimum']
參考。整數或元組,指定計算統計量時要使用的邊緣值的數量。end_values – 針對
mode = 'linear_ramp'
參考。指定斜坡填充值要達到的結束值。reflect_type – 針對
mode in ['reflect', 'symmetric']
參考。指定是否使用偶數或奇數反射。
- 返回:
array
的填充副本。- 返回類型:
注意事項
當
mode
是可調用物件時,它應具有以下簽名def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...
此處
row
是沿軸iaxis
填充陣列的 1D 切片,填充值為零。pad_width
是一個元組,指定(before, after)
填充大小,而kwargs
是傳遞給jax.numpy.pad()
函數的任何其他關鍵字引數。請注意,雖然在 NumPy 中,函數應就地修改
row
,但在 JAX 中,函數應返回修改後的row
。在 JAX 中,自訂填充函數將使用jax.vmap()
轉換映射到填充軸上。參見
jax.numpy.resize()
:調整陣列大小jax.numpy.tile()
:通過平鋪較小的陣列來建立更大的陣列。jax.numpy.repeat()
:通過重複較小陣列的值來建立更大的陣列。
範例
用零填充 1 維陣列
>>> x = jnp.array([10, 20, 30, 40]) >>> jnp.pad(x, 2) Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) >>> jnp.pad(x, (2, 4)) Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32)
用指定值填充 1 維陣列
>>> jnp.pad(x, 2, constant_values=99) Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
用平均陣列值填充 1 維陣列
>>> jnp.pad(x, 2, mode='mean') Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
用反射值填充 1 維陣列
>>> jnp.pad(x, 2, mode='reflect') Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
在每個維度中使用不同填充來填充 2 維陣列
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.pad(x, ((1, 2), (3, 0))) Array([[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 2, 3], [0, 0, 0, 4, 5, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]], dtype=int32)
使用自訂填充函數填充 1 維陣列
>>> def custom_pad(row, pad_width, iaxis, kwargs): ... # row represents a 1D slice of the zero-padded array. ... before, after = pad_width ... before_value = kwargs.get('before_value', 0) ... after_value = kwargs.get('after_value', 0) ... row = row.at[:before].set(before_value) ... return row.at[len(row) - after:].set(after_value) >>> x = jnp.array([2, 3, 4]) >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32)