jax.numpy.min#

jax.numpy.min(a, axis=None, out=None, keepdims=False, initial=None, where=None)[原始碼]#

傳回沿著給定軸的陣列元素的最小值。

JAX 實作的 numpy.min()

參數:
  • a (ArrayLike) – 輸入陣列。

  • axis (Axis | None) – int 或陣列,預設值=None。要計算最小值的軸。如果為 None,則沿所有軸計算最小值。

  • keepdims (bool) – bool,預設值=False。如果為 true,則縮減的軸會保留在結果中,大小為 1。

  • initial (ArrayLike | None | None) – int 或陣列,預設值=None。最小值的初始值。

  • where (ArrayLike | None | None) – int 或陣列,預設值=None。要在最小值中使用的元素。陣列應與輸入廣播相容。initial 在使用 where 時必須指定。

  • out (None | None) – JAX 未使用。

傳回值:

沿著給定軸的最小值陣列。

傳回類型:

Array

參見

範例

預設情況下,最小值是沿所有軸計算的。

>>> x = jnp.array([[2, 5, 1, 6],
...                [3, -7, -2, 4],
...                [8, -4, 1, -3]])
>>> jnp.min(x)
Array(-7, dtype=int32)

如果 axis=1,則沿軸 1 計算最小值。

>>> jnp.min(x, axis=1)
Array([ 1, -7, -4], dtype=int32)

如果 keepdims=True,則輸出的 ndim 將與輸入的相同。

>>> jnp.min(x, axis=1, keepdims=True)
Array([[ 1],
       [-7],
       [-4]], dtype=int32)

若要僅包含特定元素來計算最小值,您可以使用 wherewhere 可以與輸入具有相同的維度。

>>> where=jnp.array([[1, 0, 1, 0],
...                  [0, 0, 1, 1],
...                  [1, 1, 1, 0]], dtype=bool)
>>> jnp.min(x, axis=1, keepdims=True, initial=0, where=where)
Array([[ 0],
       [-2],
       [-4]], dtype=int32)

或必須與輸入廣播相容。

>>> where = jnp.array([[False],
...                    [False],
...                    [False]])
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)