jax.numpy.isscalar#
- jax.numpy.isscalar(element)[原始碼]#
如果輸入為純量,則返回 True。
numpy.isscalar()
的 JAX 實作。JAX 的實作與 NumPy 的不同之處在於,它將零維陣列視為純量;詳情請參閱下方的注意。- 參數:
element (Any) – 要檢查的輸入物件;任何型別都是有效的輸入。
- 返回值:
如果
element
是純量值或零維的類陣列物件,則為 True,否則為 False。- 返回型別:
注意
JAX 和 NumPy 在純量值的表示方式上有所不同。NumPy 具有特殊的純量物件(例如
np.int32(0)
),這些物件與零維陣列(例如np.array(0)
)不同,並且numpy.isscalar()
對於前者返回True
,對於後者返回False
。JAX 沒有定義特殊的純量物件,而是將純量表示為零維陣列。因此,
jax.numpy.isscalar()
對於純量物件(例如0.0
或np.float32(0.0)
)和零維的類陣列物件(例如jnp.array(0.0)
、np.array(0.0)
)都返回True
。isscalar
中不同慣例的原因之一是為了維持 JIT 不變性:即當函數經過 JIT 編譯時,函數的結果不應改變的屬性。由於純量輸入在 JIT 邊界會被轉換為零維 JAX 陣列,因此numpy.isscalar()
的語意會導致結果在 JIT 下發生變化>>> np.isscalar(1.0) True >>> jax.jit(np.isscalar)(1.0) Array(False, dtype=bool)
通過將零維陣列視為純量,
jax.numpy.isscalar()
避免了這個問題>>> jnp.isscalar(1.0) True >>> jax.jit(jnp.isscalar)(1.0) Array(True, dtype=bool)
範例
在 JAX 中,純量和零維的類陣列物件都被視為純量
>>> jnp.isscalar(1.0) True >>> jnp.isscalar(1 + 1j) True >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array True >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor True >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array True >>> jnp.isscalar(np.int32(1)) # NumPy scalar type True
具有一個或多個維度的陣列不被視為純量
>>> jnp.isscalar(jnp.array([1])) False >>> jnp.isscalar(np.array([1])) False
將此與
numpy.isscalar()
比較,後者對於純量型別的物件返回True
,而對於所有陣列(即使是零維的陣列)都返回False
>>> np.isscalar(np.int32(1)) # scalar object True >>> np.isscalar(np.array(1)) # zero-dimensional array False
在 JAX 中,與 NumPy 一樣,非類陣列的物件不被視為純量
>>> jnp.isscalar(None) False >>> jnp.isscalar([1]) False >>> jnp.isscalar(tuple()) False >>> jnp.isscalar(slice(10)) False