jax.numpy.ediff1d#

jax.numpy.ediff1d(ary, to_end=None, to_begin=None)[原始碼]#

計算扁平化陣列中元素的差異。

JAX 實作的 numpy.ediff1d()

參數:
  • ary (ArrayLike) – 輸入陣列或純量。

  • to_end (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定要附加到結果陣列的數字。

  • to_begin (ArrayLike | None) – 純量或陣列,選用,預設值=None。指定要前置到結果陣列的數字。

回傳:

一個包含輸入陣列元素之間差異的陣列。

回傳型別:

Array

注意

與 NumPy 的 ediff1d 實作不同,如果將 to_endto_begin 轉換為 ary 的型別會損失精度,jax.numpy.ediff1d() 不會發出錯誤。

參見

範例

>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

對於 ndim > 1 的陣列,差異在扁平化輸入陣列後計算。

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)