jax.numpy.sign#

jax.numpy.sign(x, /)[source]#

返回輸入符號的元素級指示。

JAX 實現的 numpy.sign

實數輸入 x 的符號為

\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]

對於複數值輸入,jnp.sign 返回表示相位的單位向量。對於一般情況,x 的符號由下式給出

\[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]
參數:

x (ArrayLike) – 輸入陣列或純量。

返回:

一個與 x 具有相同形狀和 dtype 的陣列,包含符號指示。

返回類型:

陣列

另請參閱

範例

對於實數值輸入

>>> x = jnp.array([0., -3., 7.])
>>> jnp.sign(x)
Array([ 0., -1.,  1.], dtype=float32)

對於複數輸入

>>> x1 = jnp.array([1, 3+4j, 5j])
>>> jnp.sign(x1)
Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)