jax.typing
模組#
JAX 型別標註模組是 JAX 特定的靜態型別註解所在之處。此子模組仍在開發中;若要查看此處匯出的型別背後的提案,請參閱 https://jax.dev.org.tw/en/latest/jep/12049-type-annotations.html。
目前可用的型別為
jax.Array
:任何 JAX 陣列或追蹤器(即 JAX 轉換中陣列的表示形式)的註解。jax.typing.ArrayLike
:可用於安全地隱式轉換為 JAX 陣列的任何值的註解;這包括jax.Array
、numpy.ndarray
,以及 Python 內建數值型別(例如int
、float
等)和 numpy 純量值(例如numpy.int32
、numpy.float64
等)jax.typing.DTypeLike
:可用於轉換為 JAX 相容 dtype 的任何值的註解;這包括字串(例如 ‘float32’、‘int32’)、純量型別(例如 float、np.float32)、dtypes(例如 np.dtype(‘float32’))或具有 dtype 屬性的物件(例如 jnp.float32、jnp.int32)。
我們可能會在未來版本中在此處新增其他型別。
JAX 型別標註最佳實務#
在公開 API 函數中註解 JAX 陣列時,我們建議對陣列輸入使用 ArrayLike
,對陣列輸出使用 Array
。
例如,您的函數可能看起來像這樣
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
def my_function(x: ArrayLike) -> Array:
# Runtime type validation, Python 3.10 or newer:
if not isinstance(x, ArrayLike):
raise TypeError(f"Expected arraylike input; got {x}")
# Runtime type validation, any Python version:
if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
raise TypeError(f"Expected arraylike input; got {x}")
# Convert input to jax.Array:
x_arr = jnp.asarray(x)
# ... do some computation; JAX functions will return Array types:
result = x_arr.sum(0) / x_arr.shape[0]
# return an Array
return result
大多數 JAX 的公開 API 都遵循此模式。特別注意,我們建議 JAX 函數不要接受序列(例如 list
或 tuple
)來代替陣列,因為這可能會在 JAX 轉換(如 jit()
)中造成額外負擔,並且在批次轉換(如 vmap()
或 jax.pmap()
)中產生非預期的行為。有關此方面的更多資訊,請參閱 非陣列輸入 NumPy 與 JAX
成員列表#
JAX 類陣列物件的型別註解。 |
|