jax.numpy.array_equal#

jax.numpy.array_equal(a1, a2, equal_nan=False)[原始碼]#

檢查兩個陣列是否逐元素相等。

JAX 實作的 numpy.array_equal()

參數:
  • a1 (ArrayLike) – 第一個要比較的輸入陣列。

  • a2 (ArrayLike) – 第二個要比較的輸入陣列。

  • equal_nan (bool) – 布林值。如果 True,則 a1 中的 NaN 將被視為等於 a2 中的 NaN。預設值為 False

傳回:

布林純量陣列,指示輸入陣列是否逐元素相等。

傳回類型:

Array

範例

>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
Array(True, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, 3]), jnp.array([1, 2, 4]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]))
Array(False, dtype=bool)
>>> jnp.array_equal(jnp.array([1, 2, float('nan')]),
...                 jnp.array([1, 2, float('nan')]), equal_nan=True)
Array(True, dtype=bool)