jax.numpy.diff#

jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[原始碼]#

計算沿給定軸的陣列元素之間的 n 階差分。

JAX 實作的 numpy.diff()

一階差分的計算方式為 a[i+1] - a[i],而 n 階差分則遞迴計算 n 次。

參數:
  • a (ArrayLike) – 輸入陣列。必須具有 a.ndim >= 1

  • n (int) – int,選用,預設值=1。差分的階數。指定計算差分的次數。如果 n=0,則不計算差分,並將輸入原樣傳回。

  • axis (int) – int,選用,預設值=-1。指定計算差分的軸。預設情況下,差分沿 axis -1 計算。

  • prepend (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定在計算差分之前要沿 axis 預先加入的值。

  • append (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定在計算差分之前要沿 axis 附加的值。

傳回:

一個陣列,包含 a 元素之間的 n 階差分。

傳回型別:

Array

另請參閱

範例

jnp.diff 計算沿 axis 的一階差分,預設情況下。

>>> a = jnp.array([[1, 5, 2, 9],
...                [3, 8, 7, 4]])
>>> jnp.diff(a)
Array([[ 4, -3,  7],
       [ 5, -1, -3]], dtype=int32)

n = 2 時,計算沿 axis 的二階差分。

>>> jnp.diff(a, n=2)
Array([[-7, 10],
       [-6, -2]], dtype=int32)

prepend = 2 時,在計算差分之前,會將其預先加入到沿 axisa

>>> jnp.diff(a, prepend=2)
Array([[-1,  4, -3,  7],
       [ 1,  5, -1, -3]], dtype=int32)

append = jnp.array([[3],[1]]) 時,在計算差分之前,會將其附加到沿 axisa

>>> jnp.diff(a, append=jnp.array([[3],[1]]))
Array([[ 4, -3,  7, -6],
       [ 5, -1, -3, -3]], dtype=int32)