jax.numpy.minimum#
- jax.numpy.minimum(x, y, /)[原始碼]#
傳回輸入陣列的逐元素最小值。
JAX 版本的
numpy.minimum
。- 參數:
x (ArrayLike) – 輸入陣列或純量。
y (ArrayLike) – 輸入陣列或純量。
x
和y
應具有相同的形狀或可廣播相容。
- 傳回值:
一個包含
x
和y
逐元素最小值的陣列。- 傳回型別:
注意
- 對於每對元素,
jnp.minimum
傳回 如果兩個元素都是有限數字,則傳回較小者。
如果其中一個元素是
nan
,則傳回nan
。
另請參閱
jax.numpy.maximum()
:傳回輸入陣列的逐元素最大值。jax.numpy.fmin()
:傳回輸入陣列的逐元素最小值,忽略 NaN。jax.numpy.amin()
:傳回沿著給定軸的陣列元素最小值。jax.numpy.nanmin()
:傳回沿著給定軸的陣列元素最小值,忽略 NaN。
範例
具有
x.shape == y.shape
的輸入>>> x = jnp.array([2, 3, 5, 1]) >>> y = jnp.array([-3, 6, -4, 7]) >>> jnp.minimum(x, y) Array([-3, 3, -4, 1], dtype=int32)
具有廣播相容性的輸入
>>> x1 = jnp.array([[1, 5, 2], ... [-3, 4, 7]]) >>> y1 = jnp.array([-2, 3, 6]) >>> jnp.minimum(x1, y1) Array([[-2, 3, 2], [-3, 3, 6]], dtype=int32)
具有
nan
的輸入>>> nan = jnp.nan >>> x2 = jnp.array([[2.5, nan, -2], ... [nan, 5, 6], ... [-4, 3, 7]]) >>> y2 = jnp.array([1, nan, 5]) >>> jnp.minimum(x2, y2) Array([[ 1., nan, -2.], [nan, nan, 5.], [-4., nan, 5.]], dtype=float32)