jax.numpy.maximum#

jax.numpy.maximum(x, y, /)[原始碼]#

傳回輸入陣列的元素級最大值。

numpy.maximum 的 JAX 實作。

參數:
  • x (ArrayLike) – 輸入陣列或純量。

  • y (ArrayLike) – 輸入陣列或純量。 xy 應具有相同形狀或可廣播相容。

傳回值:

包含 xy 元素級最大值的陣列。

傳回型別:

Array

注意

對於每對元素,jnp.maximum 傳回
  • 如果兩個元素都是有限數字,則傳回較大者。

  • 如果一個元素是 nan,則傳回 nan

另請參閱

範例

輸入具有 x.shape == y.shape

>>> x = jnp.array([1, -5, 3, 2])
>>> y = jnp.array([-2, 4, 7, -6])
>>> jnp.maximum(x, y)
Array([1, 4, 7, 2], dtype=int32)

具有廣播相容性的輸入

>>> x1 = jnp.array([[-2, 5, 7, 4],
...                 [1, -6, 3, 8]])
>>> y1 = jnp.array([-5, 3, 6, 9])
>>> jnp.maximum(x1, y1)
Array([[-2,  5,  7,  9],
       [ 1,  3,  6,  9]], dtype=int32)

具有 nan 的輸入

>>> nan = jnp.nan
>>> x2 = jnp.array([nan, -3, 9])
>>> y2 = jnp.array([[4, -2, nan],
...                 [-3, -5, 10]])
>>> jnp.maximum(x2, y2)
Array([[nan, -2., nan],
      [nan, -3., 10.]], dtype=float32)