jax.numpy.fmax#
- jax.numpy.fmax(x1, x2)[原始碼]#
傳回輸入陣列的元素級最大值。
numpy.fmax()
的 JAX 實作。- 參數:
x1 (ArrayLike) – 輸入陣列或純量
x2 (ArrayLike) – 輸入陣列或純量。x1 和 x1 必須具有相同的形狀或可廣播相容。
- 傳回:
一個陣列,包含 x1 和 x2 的元素級最大值。
- 傳回類型:
注意
- 對於每對元素,
jnp.fmax
傳回 如果兩個元素都是有限數字,則傳回較大者。
如果一個元素是
nan
,則傳回有限數字。如果兩個元素都是
nan
,則傳回nan
。如果一個元素是
inf
,而另一個是有限或nan
,則傳回inf
。如果一個元素是
-inf
,而另一個是nan
,則傳回-inf
。
範例
>>> jnp.fmax(3, 7) Array(7, dtype=int32, weak_type=True) >>> jnp.fmax(5, jnp.array([1, 7, 9, 4])) Array([5, 7, 9, 5], dtype=int32)
>>> x1 = jnp.array([1, 3, 7, 8]) >>> x2 = jnp.array([-1, 4, 6, 9]) >>> jnp.fmax(x1, x2) Array([1, 4, 7, 9], dtype=int32)
>>> x3 = jnp.array([[2, 3, 5, 10], ... [11, 9, 7, 5]]) >>> jnp.fmax(x1, x3) Array([[ 2, 3, 7, 10], [11, 9, 7, 8]], dtype=int32)
>>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan]) >>> x5 = jnp.array([[3, 5, 7, nan], ... [nan, 9, nan, -1]]) >>> jnp.fmax(x4, x5) Array([[ inf, 6., 7., nan], [ inf, 9., -inf, -1.]], dtype=float32)