jax.numpy.maximum#
- jax.numpy.maximum(x, y, /)[原始碼]#
傳回輸入陣列的元素級最大值。
numpy.maximum
的 JAX 實作。- 參數:
x (ArrayLike) – 輸入陣列或純量。
y (ArrayLike) – 輸入陣列或純量。
x
和y
應具有相同形狀或可廣播相容。
- 傳回值:
包含
x
和y
元素級最大值的陣列。- 傳回型別:
注意
- 對於每對元素,
jnp.maximum
傳回 如果兩個元素都是有限數字,則傳回較大者。
如果一個元素是
nan
,則傳回nan
。
另請參閱
jax.numpy.minimum()
:傳回輸入陣列的元素級最小值。jax.numpy.fmax()
:傳回輸入陣列的元素級最大值,忽略 NaN。jax.numpy.amax()
:傳回沿著給定軸的陣列元素最大值。jax.numpy.nanmax()
:傳回沿著給定軸的陣列元素最大值,忽略 NaN。
範例
輸入具有
x.shape == y.shape
>>> x = jnp.array([1, -5, 3, 2]) >>> y = jnp.array([-2, 4, 7, -6]) >>> jnp.maximum(x, y) Array([1, 4, 7, 2], dtype=int32)
具有廣播相容性的輸入
>>> x1 = jnp.array([[-2, 5, 7, 4], ... [1, -6, 3, 8]]) >>> y1 = jnp.array([-5, 3, 6, 9]) >>> jnp.maximum(x1, y1) Array([[-2, 5, 7, 9], [ 1, 3, 6, 9]], dtype=int32)
具有
nan
的輸入>>> nan = jnp.nan >>> x2 = jnp.array([nan, -3, 9]) >>> y2 = jnp.array([[4, -2, nan], ... [-3, -5, 10]]) >>> jnp.maximum(x2, y2) Array([[nan, -2., nan], [nan, -3., 10.]], dtype=float32)