jax.lax.pad#
- jax.lax.pad(operand, padding_value, padding_config)[原始碼]#
將低、高和/或內部填充應用於陣列。
包裝 XLA 的 Pad 運算子。
- 參數:
- 返回:
根據
padding_config
,在每個維度中插入填充值padding_value
的operand
陣列。- 返回類型:
範例
>>> from jax import lax >>> import jax.numpy as jnp
用零填充 1 維陣列。我們將指定前面兩個零,後面三個零
>>> x = jnp.array([1, 2, 3, 4]) >>> lax.pad(x, 0, [(2, 3, 0)]) Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
用內部零填充 1 維陣列;即在每個值之間插入單個零
>>> lax.pad(x, 0, [(0, 0, 1)]) Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
用值
-1
在前面和結尾填充 2 維陣列,每個維度的填充大小為 2>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)]) Array([[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 1, 2, 3, -1, -1], [-1, -1, 4, 5, 6, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)