jax.numpy.where#
- jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)[原始碼]#
根據條件從兩個陣列中選取元素。
numpy.where()
的 JAX 實作。注意
當僅提供
condition
時,jnp.where(condition)
等同於jnp.nonzero(condition)
。對於這種情況,請參閱jax.numpy.nonzero()
的文件。以下文件字串重點在於指定x
和y
的情況。jnp.where
的三項版本會降低為jax.lax.select()
。- 參數:
condition – 布林陣列。當指定
x
和y
時,必須與它們廣播相容。x – 類陣列。應與
condition
和y
廣播相容,且與y
類型轉換相容。y – 類陣列。應與
condition
和x
廣播相容,且與x
類型轉換相容。size – 整數,僅在
x
和y
為None
時參考。有關詳細資訊,請參閱jax.numpy.nonzero()
。fill_value – 僅在
x
和y
為None
時參考。有關詳細資訊,請參閱jax.numpy.nonzero()
。
- 回傳值:
dtype 為
jnp.result_type(x, y)
的陣列,其值從condition
為 True 時的x
以及 condition 為False
時的y
中提取。如果x
和y
為None
,則此函數的行為會有所不同;請參閱jax.numpy.nonzero()
以取得回傳類型說明。
注意事項
當
jax.numpy.where()
的x
或y
輸入可能具有 NaN 值時,需要特別注意。具體而言,當使用jax.grad()
(反向模式微分) 取得梯度時,x
或y
中的 NaN 將會傳播到梯度中,而與condition
的值無關。有關此行為和解決方法的更多資訊,請參閱 JAX FAQ。範例
當未提供
x
和y
時,where
的行為等同於jax.numpy.nonzero()
>>> x = jnp.arange(10) >>> jnp.where(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),) >>> jnp.nonzero(x > 4) (Array([5, 6, 7, 8, 9], dtype=int32),)
當提供
x
和y
時,where
會根據指定的條件在它們之間進行選擇>>> jnp.where(x > 4, x, 0) Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)