jax.numpy.diff#
- jax.numpy.diff(a, n=1, axis=-1, prepend=None, append=None)[原始碼]#
計算沿給定軸的陣列元素之間的 n 階差分。
JAX 實作的
numpy.diff()
。一階差分的計算方式為
a[i+1] - a[i]
,而 n 階差分則遞迴計算n
次。- 參數:
a (ArrayLike) – 輸入陣列。必須具有
a.ndim >= 1
。n (int) – int,選用,預設值=1。差分的階數。指定計算差分的次數。如果 n=0,則不計算差分,並將輸入原樣傳回。
axis (int) – int,選用,預設值=-1。指定計算差分的軸。預設情況下,差分沿
axis -1
計算。prepend (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定在計算差分之前要沿
axis
預先加入的值。append (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定在計算差分之前要沿
axis
附加的值。
- 傳回:
一個陣列,包含
a
元素之間的 n 階差分。- 傳回型別:
另請參閱
jax.numpy.ediff1d()
:計算陣列中連續元素之間的差值。jax.numpy.cumsum()
:計算陣列元素沿給定軸的累積總和。jax.numpy.gradient()
:計算 N 維陣列的梯度。
範例
jnp.diff
計算沿axis
的一階差分,預設情況下。>>> a = jnp.array([[1, 5, 2, 9], ... [3, 8, 7, 4]]) >>> jnp.diff(a) Array([[ 4, -3, 7], [ 5, -1, -3]], dtype=int32)
當
n = 2
時,計算沿axis
的二階差分。>>> jnp.diff(a, n=2) Array([[-7, 10], [-6, -2]], dtype=int32)
當
prepend = 2
時,在計算差分之前,會將其預先加入到沿axis
的a
。>>> jnp.diff(a, prepend=2) Array([[-1, 4, -3, 7], [ 1, 5, -1, -3]], dtype=int32)
當
append = jnp.array([[3],[1]])
時,在計算差分之前,會將其附加到沿axis
的a
。>>> jnp.diff(a, append=jnp.array([[3],[1]])) Array([[ 4, -3, 7, -6], [ 5, -1, -3, -3]], dtype=int32)