jax.numpy.fmax#

jax.numpy.fmax(x1, x2)[原始碼]#

傳回輸入陣列的元素級最大值。

numpy.fmax() 的 JAX 實作。

參數:
  • x1 (ArrayLike) – 輸入陣列或純量

  • x2 (ArrayLike) – 輸入陣列或純量。x1 和 x1 必須具有相同的形狀或可廣播相容。

傳回:

一個陣列,包含 x1 和 x2 的元素級最大值。

傳回類型:

陣列

注意

對於每對元素,jnp.fmax 傳回
  • 如果兩個元素都是有限數字,則傳回較大者。

  • 如果一個元素是 nan,則傳回有限數字。

  • 如果兩個元素都是 nan,則傳回 nan

  • 如果一個元素是 inf,而另一個是有限或 nan,則傳回 inf

  • 如果一個元素是 -inf,而另一個是 nan,則傳回 -inf

範例

>>> jnp.fmax(3, 7)
Array(7, dtype=int32, weak_type=True)
>>> jnp.fmax(5, jnp.array([1, 7, 9, 4]))
Array([5, 7, 9, 5], dtype=int32)
>>> x1 = jnp.array([1, 3, 7, 8])
>>> x2 = jnp.array([-1, 4, 6, 9])
>>> jnp.fmax(x1, x2)
Array([1, 4, 7, 9], dtype=int32)
>>> x3 = jnp.array([[2, 3, 5, 10],
...                 [11, 9, 7, 5]])
>>> jnp.fmax(x1, x3)
Array([[ 2,  3,  7, 10],
       [11,  9,  7,  8]], dtype=int32)
>>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan])
>>> x5 = jnp.array([[3, 5, 7, nan],
...                 [nan, 9, nan, -1]])
>>> jnp.fmax(x4, x5)
Array([[ inf,   6.,   7.,  nan],
       [ inf,   9., -inf,  -1.]], dtype=float32)