jax.numpy.digitize#

jax.numpy.digitize(x, bins, right=False, *, method=None)[source]#

將陣列轉換為 bin 索引。

JAX 實作的 numpy.digitize()

參數:
  • x (ArrayLike) – 要數位化的數值陣列。

  • bins (ArrayLike) – bin 邊緣的 1D 陣列。必須單調遞增或遞減。

  • right (bool) – 若為 true,則區間包含右 bin 邊緣。若為 false (預設),則區間包含左 bin 邊緣。

  • method (str | None) – 傳遞至 searchsorted() 的選用方法引數。請參閱該函式以取得可用選項。

傳回:

x 形狀相同的整數陣列,表示數值所在的 bin 編號。

傳回型別:

陣列

另請參閱

範例

>>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
>>> bins = jnp.array([1, 2, 3])
>>> jnp.digitize(x, bins)
Array([1, 2, 2, 1, 3, 3], dtype=int32)
>>> jnp.digitize(x, bins, right=True)
Array([0, 1, 2, 1, 2, 3], dtype=int32)

digitize 也支援反向排序的 bin

>>> bins = jnp.array([3, 2, 1])
>>> jnp.digitize(x, bins)
Array([2, 1, 1, 2, 0, 0], dtype=int32)