jax.numpy.unique_inverse#

jax.numpy.unique_inverse(x, /, *, size=None, fill_value=None)[source]#

從 x 返回唯一值,以及索引、反向索引和計數。

JAX 版本的 numpy.unique_inverse();這等同於呼叫 jax.numpy.unique(),並將 return_inverseequal_nan 設定為 True。

由於 unique_inverse 的輸出大小取決於資料,因此該函數通常與 jit() 和其他 JAX 轉換不相容。JAX 版本添加了可選的 size 參數,必須靜態指定該參數才能在這種情況下使用 jnp.unique

參數:
  • x (ArrayLike) – 將從中提取唯一值的 N 維陣列。

  • size (int | None | None) – 如果指定,則僅返回前 size 個排序後的唯一元素。如果唯一元素的數量少於 size 指示的數量,則返回值將使用 fill_value 填充。

  • fill_value (ArrayLike | None | None) – 當指定 size 並且元素數量少於指示的數量時,使用 fill_value 填充剩餘條目。預設為最小唯一值。

返回:

  • values:

    形狀為 (n_unique,) 的陣列,包含來自 x 的唯一值。

  • inverse_indices:

    形狀為 x.shape 的陣列。包含 valuesx 中每個值的索引。對於 1D 輸入,values[inverse_indices] 等效於 x

返回類型:

一個元組 (values, indices, inverse_indices, counts),具有以下屬性

另請參閱

範例

在這裡,我們計算 1D 陣列中的唯一值

>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_inverse(x)

結果是一個具有兩個命名屬性的 NamedTuplevalues 屬性包含來自陣列的唯一值

>>> result.values
Array([1, 3, 4], dtype=int32)

indices 屬性包含輸入陣列中唯一 values 的索引

inverse_indices 屬性包含 values 中輸入的索引

>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)

有關 sizefill_value 參數的範例,請參閱 jax.numpy.unique()