jax.numpy.unique#
- jax.numpy.unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None, *, equal_nan=True, size=None, fill_value=None)[原始碼]#
從陣列中傳回唯一值。
JAX 實作的
numpy.unique()
。由於
unique
的輸出大小取決於資料,因此此函數通常與jit()
和其他 JAX 轉換不相容。JAX 版本新增了選用的size
引數,必須靜態指定此引數,才能在這些情況下使用jnp.unique
。- 參數:
ar (ArrayLike) – 將從中提取唯一值的 N 維陣列。
return_index (bool) – 若為 True,則同時傳回
ar
中每個值出現位置的索引return_inverse (bool) – 若為 True,則同時傳回可用於從唯一值重建
ar
的索引。return_counts (bool) – 若為 True,則同時傳回每個唯一值的出現次數。
axis (int | None | None) – 若有指定,則沿著指定的軸計算唯一值。若為 None (預設值),則在計算唯一值之前展平
ar
。equal_nan (bool) – 若為 True,則在判斷唯一性時將 NaN 值視為相等。
size (int | None | None) – 若有指定,則僅傳回前
size
個排序後的唯一元素。如果唯一元素的數量少於size
指示的數量,則傳回值將以fill_value
填補。fill_value (ArrayLike | None | None) – 當指定
size
且元素數量少於指示的數量時,以fill_value
填滿剩餘的項目。fill_value
預設為最小的唯一值。
- 傳回值:
一個陣列或陣列元組,取決於
return_index
、return_inverse
和return_counts
的值。傳回值為unique_values
:若
axis
為 None,則為長度為n_unique
的 1D 陣列。若有指定axis
,則形狀為(*ar.shape[:axis], n_unique, *ar.shape[axis + 1:])
。
unique_index
:(僅在 return_index 為 True 時傳回) 形狀為
(n_unique,)
的陣列。包含ar
中每個唯一值首次出現位置的索引。對於 1D 輸入,ar[unique_index]
相當於unique_values
。
unique_inverse
:(僅在 return_inverse 為 True 時傳回) 若
axis
為 None,則為形狀為(ar.size,)
的陣列,若有指定axis
,則為形狀為(ar.shape[axis],)
的陣列。包含ar
中每個值的unique_values
內的索引。對於 1D 輸入,unique_values[unique_inverse]
相當於ar
。
unique_counts
:(僅在 return_counts 為 True 時傳回) 形狀為
(n_unique,)
的陣列。包含ar
中每個唯一值的出現次數。
另請參閱
jax.numpy.unique_counts()
:unique(arr, return_counts=True)
的捷徑。jax.numpy.unique_inverse()
:unique(arr, return_inverse=True)
的捷徑。jax.numpy.unique_all()
:包含所有傳回值的unique
捷徑。jax.numpy.unique_values()
:類似unique
,但不含選用的傳回值。
範例
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> jnp.unique(x) Array([1, 3, 4], dtype=int32)
JIT 編譯 & size 引數
如果您在
jit()
或其他轉換下嘗試此操作,您會收到錯誤,因為輸出形狀是動態的>>> jax.jit(jnp.unique)(x) Traceback (most recent call last): ... jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[5]. The error arose for the first argument of jnp.unique(). To make jnp.unique() compatible with JIT and other transforms, you can specify a concrete value for the size argument, which will determine the output size.
問題在於轉換後的函數輸出必須具有靜態形狀。為了使其運作,您可以傳遞靜態
size
參數>>> jit_unique = jax.jit(jnp.unique, static_argnames=['size']) >>> jit_unique(x, size=3) Array([1, 3, 4], dtype=int32)
如果您的靜態大小小於唯一值的真實數量,則會截斷這些值。
>>> jit_unique(x, size=2) Array([1, 3], dtype=int32)
如果靜態大小大於唯一值的真實數量,則會以
fill_value
填補這些值,預設值為最小的唯一值>>> jit_unique(x, size=5) Array([1, 3, 4, 1, 1], dtype=int32) >>> jit_unique(x, size=5, fill_value=0) Array([1, 3, 4, 0, 0], dtype=int32)
多維唯一值
如果您將多維陣列傳遞至
unique
,則預設會將其展平>>> M = jnp.array([[1, 2], ... [2, 3], ... [1, 2]]) >>> jnp.unique(M) Array([1, 2, 3], dtype=int32)
如果您傳遞
axis
關鍵字,您可以沿該軸找到陣列的唯一切片>>> jnp.unique(M, axis=0) Array([[1, 2], [2, 3]], dtype=int32)
傳回索引
如果您設定
return_index=True
,則unique
會傳回每個唯一值首次出現位置的索引>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, indices = jnp.unique(x, return_index=True) >>> print(values) [1 3 4] >>> print(indices) [2 0 1] >>> jnp.all(values == x[indices]) Array(True, dtype=bool)
在多個維度中,可以使用沿指定軸評估的
jax.numpy.take()
提取唯一值>>> values, indices = jnp.unique(M, axis=0, return_index=True) >>> jnp.all(values == jnp.take(M, indices, axis=0)) Array(True, dtype=bool)
傳回反向索引
如果您設定
return_inverse=True
,則unique
會傳回輸入陣列中每個項目的唯一值內的索引>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, inverse = jnp.unique(x, return_inverse=True) >>> print(values) [1 3 4] >>> print(inverse) [1 2 0 1 0] >>> jnp.all(values[inverse] == x) Array(True, dtype=bool)
在多個維度中,可以使用
jax.numpy.take()
重建輸入>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True) >>> jnp.all(jnp.take(values, inverse, axis=0) == M) Array(True, dtype=bool)
傳回計數
如果您設定
return_counts=True
,則unique
會傳回輸入中每個唯一值的出現次數>>> x = jnp.array([3, 4, 1, 3, 1]) >>> values, counts = jnp.unique(x, return_counts=True) >>> print(values) [1 3 4] >>> print(counts) [2 2 1]
對於多維陣列,這也會傳回 1D 計數陣列,指示沿指定軸的出現次數
>>> values, counts = jnp.unique(M, axis=0, return_counts=True) >>> print(values) [[1 2] [2 3]] >>> print(counts) [2 1]