jax.numpy.piecewise#

jax.numpy.piecewise(x, condlist, funclist, *args, **kw)[原始碼]#

在定義域上分段評估函數。

JAX 版本的 numpy.piecewise(),以 jax.lax.switch() 實作。

注意

不同於 numpy.piecewise()jax.numpy.piecewise() 要求 funclist 中的函數可被 JAX 追蹤,因為它是透過 jax.lax.switch() 實作的。

參數:
  • x (ArrayLike) – 輸入值的陣列。

  • condlist (Array | Sequence[ArrayLike]) – 布林陣列或布林陣列序列,對應於 funclist 中的函數。如果是一個陣列序列,則每個陣列的長度必須與 x 的長度相符

  • funclist (list[ArrayLike | Callable[..., Array]]) – 陣列或函數的列表;長度必須與 condlist 相同,或長度為 len(condlist) + 1,在後者的情況下,最後一個條目是當沒有條件為 True 時套用的預設值。或者,funclist 的條目可以是數值,在這種情況下,它們表示常數函數。

  • args – 傳遞給 funclist 中每個函數的其他引數。

  • kwargs – 傳遞給 funclist 中每個函數的其他引數。

傳回值:

一個陣列,它是根據指定條件在 x 上評估函數的結果。

傳回型別:

Array

另請參閱

範例

以下範例是一個函數,對於負值為零,對於正值為線性

>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

funclist 也可以包含簡單的純量值作為常數函數

>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

您可以透過在 funclist 中附加額外條件來指定預設值

>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2,  -1,  0,  0,  0,  1,  2, 3], dtype=int32)

condlist 也可以是一個簡單的純量條件陣列,在這種情況下,相關函數適用於整個範圍

>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10,   0,  10,  20,  30,  40], dtype=int32)