jax.numpy.exp#
- jax.numpy.exp(x, /)[原始碼]#
計算輸入的逐元素指數。
JAX 版本的
numpy.exp
。- 參數:
x (ArrayLike) – 輸入陣列或純量
- 傳回:
一個陣列,包含
x
中每個元素的指數,提升為非精確 dtype。- 傳回型別:
另請參閱
jax.numpy.log()
:計算輸入的逐元素對數。jax.numpy.expm1()
:計算輸入中每個元素的 \(e^x-1\)。jax.numpy.exp2()
:計算輸入中每個元素的底數為 2 的指數。
範例
jnp.exp
遵循指數的屬性,例如 \(e^{(a+b)} = e^a * e^b\)。>>> x1 = jnp.array([2, 4, 3, 1]) >>> x2 = jnp.array([1, 3, 2, 3]) >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x1+x2)) [ 20.09 1096.63 148.41 54.6 ] >>> with jnp.printoptions(precision=2, suppress=True): ... print(jnp.exp(x1)*jnp.exp(x2)) [ 20.09 1096.63 148.41 54.6 ]
此屬性也適用於複數輸入
>>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j)) Array(True, dtype=bool)