jax.numpy.unique#

jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None)[原始碼]#

從陣列中傳回唯一值。

JAX 實作的 numpy.unique()

由於 unique 的輸出大小取決於資料,因此此函數通常與 jit() 和其他 JAX 轉換不相容。JAX 版本新增了選用的 size 引數,必須靜態指定此引數,才能在這些情況下使用 jnp.unique

參數:
  • ar (ArrayLike) – 將從中提取唯一值的 N 維陣列。

  • return_index (bool) – 若為 True,則同時傳回 ar 中每個值出現位置的索引

  • return_inverse (bool) – 若為 True,則同時傳回可用於從唯一值重建 ar 的索引。

  • return_counts (bool) – 若為 True,則同時傳回每個唯一值的出現次數。

  • axis (int | None | None) – 若有指定,則沿著指定的軸計算唯一值。若為 None (預設值),則在計算唯一值之前展平 ar

  • equal_nan (bool) – 若為 True,則在判斷唯一性時將 NaN 值視為相等。

  • size (int | None | None) – 若有指定,則僅傳回前 size 個排序後的唯一元素。如果唯一元素的數量少於 size 指示的數量,則傳回值將以 fill_value 填補。

  • fill_value (ArrayLike | None | None) – 當指定 size 且元素數量少於指示的數量時,以 fill_value 填滿剩餘的項目。fill_value 預設為最小的唯一值。

傳回值:

一個陣列或陣列元組,取決於 return_indexreturn_inversereturn_counts 的值。傳回值為

  • unique_values:

    axis 為 None,則為長度為 n_unique 的 1D 陣列。若有指定 axis,則形狀為 (*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])

  • unique_index:

    (僅在 return_index 為 True 時傳回) 形狀為 (n_unique,) 的陣列。包含 ar 中每個唯一值首次出現位置的索引。對於 1D 輸入,ar[unique_index] 相當於 unique_values

  • unique_inverse:

    (僅在 return_inverse 為 True 時傳回)axis 為 None,則為形狀為 (ar.size,) 的陣列,若有指定 axis,則為形狀為 (ar.shape[axis],) 的陣列。包含 ar 中每個值的 unique_values 內的索引。對於 1D 輸入,unique_values[unique_inverse] 相當於 ar

  • unique_counts:

    (僅在 return_counts 為 True 時傳回) 形狀為 (n_unique,) 的陣列。包含 ar 中每個唯一值的出現次數。

另請參閱

範例

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> jnp.unique(x)
Array([1, 3, 4], dtype=int32)

JIT 編譯 & size 引數

如果您在 jit() 或其他轉換下嘗試此操作,您會收到錯誤,因為輸出形狀是動態的

>>> jax.jit(jnp.unique)(x)  
Traceback (most recent call last):
   ...
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5].
The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.

問題在於轉換後的函數輸出必須具有靜態形狀。為了使其運作,您可以傳遞靜態 size 參數

>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size'])
>>> jit_unique(x, size=3)
Array([1, 3, 4], dtype=int32)

如果您的靜態大小小於唯一值的真實數量,則會截斷這些值。

>>> jit_unique(x, size=2)
Array([1, 3], dtype=int32)

如果靜態大小大於唯一值的真實數量,則會以 fill_value 填補這些值,預設值為最小的唯一值

>>> jit_unique(x, size=5)
Array([1, 3, 4, 1, 1], dtype=int32)
>>> jit_unique(x, size=5, fill_value=0)
Array([1, 3, 4, 0, 0], dtype=int32)

多維唯一值

如果您將多維陣列傳遞至 unique,則預設會將其展平

>>> M = jnp.array([[1, 2],
...                [2, 3],
...                [1, 2]])
>>> jnp.unique(M)
Array([1, 2, 3], dtype=int32)

如果您傳遞 axis 關鍵字,您可以沿該軸找到陣列的唯一切片

>>> jnp.unique(M, axis=0)
Array([[1, 2],
       [2, 3]], dtype=int32)

傳回索引

如果您設定 return_index=True,則 unique 會傳回每個唯一值首次出現位置的索引

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, indices = jnp.unique(x, return_index=True)
>>> print(values)
[1 3 4]
>>> print(indices)
[2 0 1]
>>> jnp.all(values == x[indices])
Array(True, dtype=bool)

在多個維度中,可以使用沿指定軸評估的 jax.numpy.take() 提取唯一值

>>> values, indices = jnp.unique(M, axis=0, return_index=True)
>>> jnp.all(values == jnp.take(M, indices, axis=0))
Array(True, dtype=bool)

傳回反向索引

如果您設定 return_inverse=True,則 unique 會傳回輸入陣列中每個項目的唯一值內的索引

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, inverse = jnp.unique(x, return_inverse=True)
>>> print(values)
[1 3 4]
>>> print(inverse)
[1 2 0 1 0]
>>> jnp.all(values[inverse] == x)
Array(True, dtype=bool)

在多個維度中,可以使用 jax.numpy.take() 重建輸入

>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)

傳回計數

如果您設定 return_counts=True,則 unique 會傳回輸入中每個唯一值的出現次數

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> values, counts = jnp.unique(x, return_counts=True)
>>> print(values)
[1 3 4]
>>> print(counts)
[2 2 1]

對於多維陣列,這也會傳回 1D 計數陣列,指示沿指定軸的出現次數

>>> values, counts = jnp.unique(M, axis=0, return_counts=True)
>>> print(values)
[[1 2]
 [2 3]]
>>> print(counts)
[2 1]