jax.numpy.issubdtype#
- jax.numpy.issubdtype(arg1, arg2)[source]#
如果 arg1 在型別層級中等於或低於 arg2,則傳回 True。
JAX 實作的
numpy.issubdtype()
。JAX 實作的主要差異在於它正確地處理 dtype 擴充功能,例如
bfloat16
。- 參數:
arg1 (DTypeLike) – 類似 dtype 的物件。在一般用法中,這會是 dtype 指定符,例如
"float32"
(即字串)、np.dtype('int32')
(即numpy.dtype
的實例)、jnp.complex64
(即 JAX 純量建構子) 或np.uint8
(即 NumPy 純量型別)。arg2 (DTypeLike) – 類似 dtype 的物件。在一般用法中,這會是通用純量型別,例如
jnp.integer
、jnp.floating
或jnp.complexfloating
。
- 傳回:
如果 arg1 代表的 dtype 在型別層級中等於或低於 arg2,則為 True。
- 傳回型別:
另請參閱
jax.numpy.isdtype()
:與陣列 API 標準對齊的類似函式。
範例
>>> jnp.issubdtype('uint32', jnp.unsignedinteger) True >>> jnp.issubdtype(np.int32, jnp.integer) True >>> jnp.issubdtype(jnp.bfloat16, jnp.floating) True >>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating) True >>> jnp.issubdtype('complex64', jnp.integer) False
請注意,雖然這與
numpy.issubdtype()
非常相似,但在 JAX 的自訂浮點型別的情況下,這些結果會有所不同>>> np.issubdtype('bfloat16', np.floating) False >>> jnp.issubdtype('bfloat16', jnp.floating) True