jax.numpy.isscalar#

jax.numpy.isscalar(element)[原始碼]#

如果輸入為純量,則返回 True。

numpy.isscalar() 的 JAX 實作。JAX 的實作與 NumPy 的不同之處在於,它將零維陣列視為純量;詳情請參閱下方的注意

參數:

element (Any) – 要檢查的輸入物件;任何型別都是有效的輸入。

返回值:

如果 element 是純量值或零維的類陣列物件,則為 True,否則為 False。

返回型別:

bool

注意

JAX 和 NumPy 在純量值的表示方式上有所不同。NumPy 具有特殊的純量物件(例如 np.int32(0)),這些物件與零維陣列(例如 np.array(0))不同,並且 numpy.isscalar() 對於前者返回 True,對於後者返回 False

JAX 沒有定義特殊的純量物件,而是將純量表示為零維陣列。因此,jax.numpy.isscalar() 對於純量物件(例如 0.0np.float32(0.0))和零維的類陣列物件(例如 jnp.array(0.0)np.array(0.0))都返回 True

isscalar 中不同慣例的原因之一是為了維持 JIT 不變性:即當函數經過 JIT 編譯時,函數的結果不應改變的屬性。由於純量輸入在 JIT 邊界會被轉換為零維 JAX 陣列,因此 numpy.isscalar() 的語意會導致結果在 JIT 下發生變化

>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)

通過將零維陣列視為純量,jax.numpy.isscalar() 避免了這個問題

>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)

範例

在 JAX 中,純量和零維的類陣列物件都被視為純量

>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1))  # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1))  # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0))  # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1))  # NumPy scalar type
True

具有一個或多個維度的陣列不被視為純量

>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False

將此與 numpy.isscalar() 比較,後者對於純量型別的物件返回 True,而對於所有陣列(即使是零維的陣列)都返回 False

>>> np.isscalar(np.int32(1))  # scalar object
True
>>> np.isscalar(np.array(1))  # zero-dimensional array
False

在 JAX 中,與 NumPy 一樣,非類陣列的物件不被視為純量

>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False