jax.numpy.minimum#

jax.numpy.minimum(x, y, /)[原始碼]#

傳回輸入陣列的逐元素最小值。

JAX 版本的 numpy.minimum

參數:
  • x (ArrayLike) – 輸入陣列或純量。

  • y (ArrayLike) – 輸入陣列或純量。 xy 應具有相同的形狀或可廣播相容。

傳回值:

一個包含 xy 逐元素最小值的陣列。

傳回型別:

Array

注意

對於每對元素,jnp.minimum 傳回
  • 如果兩個元素都是有限數字,則傳回較小者。

  • 如果其中一個元素是 nan,則傳回 nan

另請參閱

範例

具有 x.shape == y.shape 的輸入

>>> x = jnp.array([2, 3, 5, 1])
>>> y = jnp.array([-3, 6, -4, 7])
>>> jnp.minimum(x, y)
Array([-3,  3, -4,  1], dtype=int32)

具有廣播相容性的輸入

>>> x1 = jnp.array([[1, 5, 2],
...                 [-3, 4, 7]])
>>> y1 = jnp.array([-2, 3, 6])
>>> jnp.minimum(x1, y1)
Array([[-2,  3,  2],
       [-3,  3,  6]], dtype=int32)

具有 nan 的輸入

>>> nan = jnp.nan
>>> x2 = jnp.array([[2.5, nan, -2],
...                 [nan, 5, 6],
...                 [-4, 3, 7]])
>>> y2 = jnp.array([1, nan, 5])
>>> jnp.minimum(x2, y2)
Array([[ 1., nan, -2.],
       [nan, nan,  5.],
       [-4., nan,  5.]], dtype=float32)