jax.numpy.union1d#

jax.numpy.union1d(ar1, ar2, *, size=None, fill_value=None)[原始碼]#

計算兩個 1D 陣列的聯集。

JAX 版本的 numpy.union1d() 實作。

由於 union1d 的輸出大小取決於資料,因此此函式通常與 jit() 和其他 JAX 轉換不相容。JAX 版本新增了選用的 size 引數,必須靜態指定此引數,才能在這些情況下使用 jnp.union1d

參數:
  • ar1 (ArrayLike) – 要聯集的第一個元素陣列。

  • ar2 (ArrayLike) – 要聯集的第二個元素陣列

  • size (int | None | None) – 如果指定,則僅傳回前 size 個排序元素。如果元素數量少於 size 指示的數量,則傳回值將以 fill_value 填補。

  • fill_value (ArrayLike | None | None) – 當指定 size 且元素數量少於指示的數量時,以 fill_value 填補剩餘的項目。fill_value 預設為最小值。

傳回:

包含輸入陣列中元素聯集的陣列。

傳回類型:

Array

另請參閱

範例

計算兩個陣列的聯集

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.union1d(ar1, ar2)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

由於輸出形狀是動態的,因此這會在 jit() 和其他轉換下失敗

>>> jax.jit(jnp.union1d)(ar1, ar2)  
Traceback (most recent call last):
   ...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4].
The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1.

為了確保靜態已知的輸出形狀,您可以傳遞靜態 size 引數

>>> jit_union1d = jax.jit(jnp.union1d, static_argnames=['size'])
>>> jit_union1d(ar1, ar2, size=6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

如果 size 太小,則聯集會被截斷

>>> jit_union1d(ar1, ar2, size=4)
Array([1, 2, 3, 4], dtype=int32)

如果 size 太大,則輸出會以 fill_value 填補

>>> jit_union1d(ar1, ar2, size=8, fill_value=0)
Array([1, 2, 3, 4, 5, 6, 0, 0], dtype=int32)