jax.numpy.power#

jax.numpy.power(x1, x2, /)[原始碼]#

計算逐元素的基底 x1x2 次方。

JAX 版本的 numpy.power

參數:
  • x1 (ArrayLike) – 純量或陣列。指定基底。

  • x2 (ArrayLike) – 純量或陣列。指定指數。x1x2 應具有相同形狀或可廣播相容。

傳回:

一個陣列,包含基底 x1x2 次方,其 dtype 與輸入相同。

傳回類型:

Array

注意

  • x2 是具體的整數純量時,jnp.power 會降低為 jax.lax.integer_pow()

  • x2 是追蹤的純量或陣列時,jnp.power 會降低為 jax.lax.pow()

  • jnp.power 對於整數類型提升為負整數次方會引發 TypeError

  • jnp.power 對於負值提升為非整數次方會傳回 nan

另請參閱

  • jax.lax.pow():計算逐元素次方,\(x^y\)

  • jax.lax.integer_pow():計算逐元素次方 \(x^y\),其中 \(y\) 是固定的整數。

  • jax.numpy.float_power():通過提升為非精確 dtype,計算第一個陣列提升為第二個陣列的次方,逐元素計算。

  • jax.numpy.pow():計算第一個陣列提升為第二個陣列的次方,逐元素計算。

範例

具有純量整數的輸入

>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)

具有相同形狀的輸入

>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)

具有廣播相容性的輸入

>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)