jax.numpy.square#
- jax.numpy.square(x, /)[原始碼]#
計算輸入陣列的元素平方。
JAX 版本的
numpy.square
。- 參數:
x (ArrayLike) – 輸入陣列或純量。
- 返回:
一個包含
x
元素平方的陣列。- 返回類型:
注意
jnp.square
等同於計算jnp.power(x, 2)
。參見
jax.numpy.sqrt()
:計算輸入陣列的元素非負平方根。jax.numpy.power()
:計算元素基底x1
的x2
次方指數。jax.lax.integer_pow()
:計算元素次方 \(x^y\),其中 \(y\) 是固定的整數。jax.numpy.float_power()
:通過提升到非精確 dtype,計算第一個陣列的第二個陣列次方(元素方式)。
範例
>>> x = jnp.array([3, -2, 5.3, 1]) >>> jnp.square(x) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) >>> jnp.power(x, 2) Array([ 9. , 4. , 28.090002, 1. ], dtype=float32)
對於整數輸入
>>> x1 = jnp.array([2, 4, 5, 6]) >>> jnp.square(x1) Array([ 4, 16, 25, 36], dtype=int32)
對於複數值輸入
>>> x2 = jnp.array([1-3j, -1j, 2]) >>> jnp.square(x2) Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64)