jax.numpy.sign#
- jax.numpy.sign(x, /)[source]#
返回輸入符號的元素級指示。
JAX 實現的
numpy.sign
。實數輸入
x
的符號為\[\begin{split}\mathrm{sign}(x) = \begin{cases} 1, & x > 0\\ 0, & x = 0\\ -1, & x < 0 \end{cases}\end{split}\]對於複數值輸入,
jnp.sign
返回表示相位的單位向量。對於一般情況,x
的符號由下式給出\[\begin{split}\mathrm{sign}(x) = \begin{cases} \frac{x}{abs(x)}, & x \ne 0\\ 0, & x = 0 \end{cases}\end{split}\]- 參數:
x (ArrayLike) – 輸入陣列或純量。
- 返回:
一個與
x
具有相同形狀和 dtype 的陣列,包含符號指示。- 返回類型:
另請參閱
jax.numpy.positive()
:返回輸入的元素級正值。jax.numpy.negative()
:返回輸入的元素級負值。
範例
對於實數值輸入
>>> x = jnp.array([0., -3., 7.]) >>> jnp.sign(x) Array([ 0., -1., 1.], dtype=float32)
對於複數輸入
>>> x1 = jnp.array([1, 3+4j, 5j]) >>> jnp.sign(x1) Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64)