jax.numpy.sqrt#
- jax.numpy.sqrt(x, /)[source]#
計算輸入陣列的元素級非負平方根。
numpy.sqrt
的 JAX 實作。- 參數:
x (ArrayLike) – 輸入陣列或純量。
- 回傳:
包含
x
元素之非負平方根的陣列。- 回傳型別:
注意
對於實值負輸入,
jnp.sqrt
會產生nan
輸出。對於複值負輸入,
jnp.sqrt
會產生complex
輸出。
另請參閱
jax.numpy.square()
:計算輸入的元素級平方。jax.numpy.power()
:計算x2
的元素級底數x1
指數。
範例
>>> x = jnp.array([-8-6j, 1j, 4]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.sqrt(x) Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) >>> jnp.sqrt(-1) Array(nan, dtype=float32, weak_type=True)