jax.numpy.sqrt#

jax.numpy.sqrt(x, /)[source]#

計算輸入陣列的元素級非負平方根。

numpy.sqrt 的 JAX 實作。

參數:

x (ArrayLike) – 輸入陣列或純量。

回傳:

包含 x 元素之非負平方根的陣列。

回傳型別:

Array

注意

  • 對於實值負輸入,jnp.sqrt 會產生 nan 輸出。

  • 對於複值負輸入,jnp.sqrt 會產生 complex 輸出。

另請參閱

範例

>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)