jax.numpy.issubdtype#

jax.numpy.issubdtype(arg1, arg2)[source]#

如果 arg1 在型別層級中等於或低於 arg2,則傳回 True。

JAX 實作的 numpy.issubdtype()

JAX 實作的主要差異在於它正確地處理 dtype 擴充功能,例如 bfloat16

參數:
  • arg1 (DTypeLike) – 類似 dtype 的物件。在一般用法中,這會是 dtype 指定符,例如 "float32" (即字串)、np.dtype('int32') (即 numpy.dtype 的實例)、jnp.complex64 (即 JAX 純量建構子) 或 np.uint8 (即 NumPy 純量型別)。

  • arg2 (DTypeLike) – 類似 dtype 的物件。在一般用法中,這會是通用純量型別,例如 jnp.integerjnp.floatingjnp.complexfloating

傳回:

如果 arg1 代表的 dtype 在型別層級中等於或低於 arg2,則為 True。

傳回型別:

bool

另請參閱

範例

>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
True
>>> jnp.issubdtype(np.int32, jnp.integer)
True
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
True
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
True
>>> jnp.issubdtype('complex64', jnp.integer)
False

請注意,雖然這與 numpy.issubdtype() 非常相似,但在 JAX 的自訂浮點型別的情況下,這些結果會有所不同

>>> np.issubdtype('bfloat16', np.floating)
False
>>> jnp.issubdtype('bfloat16', jnp.floating)
True