jax.typing 模組#

JAX 型別標註模組是 JAX 特定的靜態型別註解所在之處。此子模組仍在開發中;若要查看此處匯出的型別背後的提案,請參閱 https://jax.dev.org.tw/en/latest/jep/12049-type-annotations.html

目前可用的型別為

  • jax.Array:任何 JAX 陣列或追蹤器(即 JAX 轉換中陣列的表示形式)的註解。

  • jax.typing.ArrayLike:可用於安全地隱式轉換為 JAX 陣列的任何值的註解;這包括 jax.Arraynumpy.ndarray,以及 Python 內建數值型別(例如 intfloat 等)和 numpy 純量值(例如 numpy.int32numpy.float64 等)

  • jax.typing.DTypeLike:可用於轉換為 JAX 相容 dtype 的任何值的註解;這包括字串(例如 ‘float32’‘int32’)、純量型別(例如 floatnp.float32)、dtypes(例如 np.dtype(‘float32’))或具有 dtype 屬性的物件(例如 jnp.float32jnp.int32)。

我們可能會在未來版本中在此處新增其他型別。

JAX 型別標註最佳實務#

在公開 API 函數中註解 JAX 陣列時,我們建議對陣列輸入使用 ArrayLike,對陣列輸出使用 Array

例如,您的函數可能看起來像這樣

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result

大多數 JAX 的公開 API 都遵循此模式。特別注意,我們建議 JAX 函數不要接受序列(例如 listtuple)來代替陣列,因為這可能會在 JAX 轉換(如 jit())中造成額外負擔,並且在批次轉換(如 vmap()jax.pmap())中產生非預期的行為。有關此方面的更多資訊,請參閱 非陣列輸入 NumPy 與 JAX

成員列表#

ArrayLike

JAX 類陣列物件的型別註解。

DTypeLike

別名為 str | type[Any] | dtype | SupportsDType