jax.numpy.ediff1d#
- jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[原始碼]#
計算扁平化陣列中元素的差異。
JAX 實作的
numpy.ediff1d()
。- 參數:
ary (ArrayLike) – 輸入陣列或純量。
to_end (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定要附加到結果陣列的數字。
to_begin (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定要前置到結果陣列的數字。
- 回傳:
一個包含輸入陣列元素之間差異的陣列。
- 回傳型別:
注意
與 NumPy 的 ediff1d 實作不同,如果將
to_end
或to_begin
轉換為ary
的型別會損失精度,jax.numpy.ediff1d()
不會發出錯誤。參見
jax.numpy.diff()
:計算沿給定軸的陣列元素之間的 n 階差分。jax.numpy.cumsum()
:計算沿給定軸的陣列元素的累積總和。jax.numpy.gradient()
:計算 N 維陣列的梯度。
範例
>>> a = jnp.array([2, 3, 5, 9, 1, 4]) >>> jnp.ediff1d(a) Array([ 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10) Array([-10, 1, 2, 4, -8, 3], dtype=int32) >>> jnp.ediff1d(a, to_end=jnp.array([20, 30])) Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32) >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])) Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32)
對於
ndim
> 1 的陣列,差異在扁平化輸入陣列後計算。>>> a1 = jnp.array([[2, -1, 4, 7], ... [3, 5, -6, 9]]) >>> jnp.ediff1d(a1) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9]) >>> jnp.ediff1d(a2) Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32)