jax.numpy.piecewise#
- jax.numpy.piecewise(x, condlist, funclist, *args, **kw)[原始碼]#
在定義域上分段評估函數。
JAX 版本的
numpy.piecewise()
,以jax.lax.switch()
實作。注意
不同於
numpy.piecewise()
,jax.numpy.piecewise()
要求funclist
中的函數可被 JAX 追蹤,因為它是透過jax.lax.switch()
實作的。- 參數:
x (ArrayLike) – 輸入值的陣列。
condlist (Array | Sequence[ArrayLike]) – 布林陣列或布林陣列序列,對應於
funclist
中的函數。如果是一個陣列序列,則每個陣列的長度必須與x
的長度相符funclist (list[ArrayLike | Callable[..., Array]]) – 陣列或函數的列表;長度必須與
condlist
相同,或長度為len(condlist) + 1
,在後者的情況下,最後一個條目是當沒有條件為 True 時套用的預設值。或者,funclist
的條目可以是數值,在這種情況下,它們表示常數函數。args – 傳遞給
funclist
中每個函數的其他引數。kwargs – 傳遞給
funclist
中每個函數的其他引數。
- 傳回值:
一個陣列,它是根據指定條件在
x
上評估函數的結果。- 傳回型別:
另請參閱
jax.lax.switch()
:根據索引在 N 個函數之間選擇。jax.lax.cond()
:根據布林條件在兩個函數之間選擇。jax.numpy.where()
:根據布林遮罩在兩個結果之間選擇。jax.lax.select()
:根據布林遮罩在兩個結果之間選擇。jax.lax.select_n()
:根據布林遮罩在 N 個結果之間選擇。
範例
以下範例是一個函數,對於負值為零,對於正值為線性
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0] >>> funclist = [lambda x: 0 * x, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
funclist
也可以包含簡單的純量值作為常數函數>>> condlist = [x < 0, x >= 0] >>> funclist = [0, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
您可以透過在
funclist
中附加額外條件來指定預設值>>> condlist = [x < -1, x > 1] >>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0] >>> jnp.piecewise(x, condlist, funclist) Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
condlist
也可以是一個簡單的純量條件陣列,在這種情況下,相關函數適用於整個範圍>>> condlist = jnp.array([False, True, False]) >>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100] >>> jnp.piecewise(x, condlist, funclist) Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)