jax.numpy.where#

jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)[原始碼]#

根據條件從兩個陣列中選取元素。

numpy.where() 的 JAX 實作。

注意

當僅提供 condition 時,jnp.where(condition) 等同於 jnp.nonzero(condition)。對於這種情況,請參閱 jax.numpy.nonzero() 的文件。以下文件字串重點在於指定 xy 的情況。

jnp.where 的三項版本會降低為 jax.lax.select()

參數:
  • condition – 布林陣列。當指定 xy 時,必須與它們廣播相容。

  • x – 類陣列。應與 conditiony 廣播相容,且與 y 類型轉換相容。

  • y – 類陣列。應與 conditionx 廣播相容,且與 x 類型轉換相容。

  • size – 整數,僅在 xyNone 時參考。有關詳細資訊,請參閱 jax.numpy.nonzero()

  • fill_value – 僅在 xyNone 時參考。有關詳細資訊,請參閱 jax.numpy.nonzero()

回傳值:

dtype 為 jnp.result_type(x, y) 的陣列,其值從 condition 為 True 時的 x 以及 condition 為 False 時的 y 中提取。如果 xyNone,則此函數的行為會有所不同;請參閱 jax.numpy.nonzero() 以取得回傳類型說明。

注意事項

jax.numpy.where()xy 輸入可能具有 NaN 值時,需要特別注意。具體而言,當使用 jax.grad() (反向模式微分) 取得梯度時,xy 中的 NaN 將會傳播到梯度中,而與 condition 的值無關。有關此行為和解決方法的更多資訊,請參閱 JAX FAQ

範例

當未提供 xy 時,where 的行為等同於 jax.numpy.nonzero()

>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)

當提供 xy 時,where 會根據指定的條件在它們之間進行選擇

>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)