jax.numpy.power#
- jax.numpy.power(x1, x2, /)[原始碼]#
計算逐元素的基底
x1
的x2
次方。JAX 版本的
numpy.power
。- 參數:
x1 (ArrayLike) – 純量或陣列。指定基底。
x2 (ArrayLike) – 純量或陣列。指定指數。
x1
和x2
應具有相同形狀或可廣播相容。
- 傳回:
一個陣列,包含基底
x1
的x2
次方,其 dtype 與輸入相同。- 傳回類型:
注意
當
x2
是具體的整數純量時,jnp.power
會降低為jax.lax.integer_pow()
。當
x2
是追蹤的純量或陣列時,jnp.power
會降低為jax.lax.pow()
。jnp.power
對於整數類型提升為負整數次方會引發TypeError
。jnp.power
對於負值提升為非整數次方會傳回nan
。
另請參閱
jax.lax.pow()
:計算逐元素次方,\(x^y\)。jax.lax.integer_pow()
:計算逐元素次方 \(x^y\),其中 \(y\) 是固定的整數。jax.numpy.float_power()
:通過提升為非精確 dtype,計算第一個陣列提升為第二個陣列的次方,逐元素計算。jax.numpy.pow()
:計算第一個陣列提升為第二個陣列的次方,逐元素計算。
範例
具有純量整數的輸入
>>> jnp.power(4, 3) Array(64, dtype=int32, weak_type=True)
具有相同形狀的輸入
>>> x1 = jnp.array([2, 4, 5]) >>> x2 = jnp.array([3, 0.5, 2]) >>> jnp.power(x1, x2) Array([ 8., 2., 25.], dtype=float32)
具有廣播相容性的輸入
>>> x3 = jnp.array([-2, 3, 1]) >>> x4 = jnp.array([[4, 1, 6], ... [1.3, 3, 5]]) >>> jnp.power(x3, x4) Array([[16., 3., 1.], [nan, 27., 1.]], dtype=float32)