JAX 類型註解路線圖#

  • 作者:jakevdp

  • 日期:2022 年 8 月

背景#

Python 3.0 引入了可選的函式註解 (PEP 3107),後來在 Python 3.5 發布前後被編纂用於靜態類型檢查 (PEP 484)。在某種程度上,類型註解和靜態類型檢查已成為許多 Python 開發工作流程不可或缺的一部分,為此,我們在整個 JAX API 中的許多地方新增了註解。 JAX 中類型註解的目前狀態有點零散,而新增更多註解的努力受到更基本設計問題的阻礙。本文檔試圖總結這些問題,並為 JAX 中類型註解的目標和非目標產生路線圖。

為何我們需要這樣的路線圖?更好/更全面的類型註解是使用者經常提出的要求,包括內部和外部使用者。此外,我們經常收到外部使用者的提取請求 (例如,PR #9917PR #10322) 尋求改進 JAX 的類型註解:對於審查程式碼的 JAX 團隊成員來說,這些貢獻是否有益並不總是清楚的,特別是當它們引入複雜的協議來解決 JAX 使用 Python 進行完整註解時固有的挑戰。本文檔詳細說明 JAX 對於套件內類型註解的目標和建議。

為何需要類型註解?#

Python 專案可能希望為其程式碼庫加上註解的原因有很多;我們將在本文檔中將其總結為層級 1、層級 2 和層級 3。

層級 1:作為文件用途的註解#

最初在 PEP 3107 中引入時,類型註解的部分動機是能夠將其用作函式參數類型和傳回類型的簡潔內嵌文件。長期以來,JAX 一直以這種方式利用註解;一個範例是建立別名為 Any 的類型名稱的常見模式。一個範例可以在 lax/slicing.py 中找到 [來源]

Array = Any
Shape = core.Shape

def slice(operand: Array, start_indices: Sequence[int],
          limit_indices: Sequence[int],
          strides: Optional[Sequence[int]] = None) -> Array:
  ...

為了靜態類型檢查的目的,這種使用 Array = Any 進行陣列類型註解的方式對參數值沒有任何限制 (Any 等同於完全沒有註解),但它確實可以作為開發人員有用的程式碼內文件形式。

為了產生文件的緣故,別名的名稱會遺失 (jax.lax.sliceHTML 文件 將運算元報告為類型 Any),因此文件的好處不會超出原始程式碼 (儘管我們可以啟用一些 sphinx-autodoc 選項來改進這一點:請參閱 autodoc_type_aliases)。

這種層級的類型註解的好處是,使用 Any 註解值永遠不會出錯,因此它將以文件的形式為開發人員和使用者提供具體的好處,而不會增加滿足任何特定靜態類型檢查器更嚴格需求的複雜性。

層級 2:用於智慧自動完成的註解#

許多現代 IDE 利用類型註解作為 智慧程式碼完成 系統的輸入。其中一個範例是 VSCode 的 Pylance 擴充功能,它使用 Microsoft 的 pyright 靜態類型檢查器作為 VSCode IntelliSense 完成的資訊來源。

這種類型檢查的使用需要比上面使用的簡單別名更進一步;例如,知道 slice 函式傳回名為 ArrayAny 別名,並不會為程式碼完成引擎新增任何有用的資訊。但是,如果我們使用 DeviceArray 傳回類型註解函式,自動完成功能將知道如何填入結果的命名空間,因此能夠在開發過程中建議更相關的自動完成。

JAX 已開始在少數地方新增此層級的類型註解;一個範例是 jax.random 套件中的 jnp.ndarray 傳回類型 [來源]

def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
  ...

在這種情況下,jnp.ndarray 是一個抽象基底類別,它預先宣告 JAX 陣列的屬性和方法 (請參閱來源),因此 VSCode 中的 Pylance 可以針對此函式的結果提供完整的自動完成集。以下螢幕截圖顯示結果

VSCode Intellisense Screenshot

自動完成欄位中列出的是抽象 ndarray 類別宣告的所有方法和屬性。我們將在下面進一步討論為什麼有必要建立這個抽象類別,而不是直接使用 DeviceArray 進行註解。

層級 3:用於靜態類型檢查的註解#

如今,當人們考慮 Python 程式碼中類型註解的目的時,靜態類型檢查通常是人們首先想到的。雖然 Python 不會對類型進行任何執行階段檢查,但存在幾種成熟的靜態類型檢查工具,可以將其作為 CI 測試套件的一部分來執行。對於 JAX 來說,最重要的工具如下

  • python/mypy 或多或少是開放 Python 世界中的標準。 JAX 目前在 Github Actions CI 檢查中對原始碼檔案的子集執行 mypy。

  • google/pytype 是 Google 的靜態類型檢查器,Google 內部依賴 JAX 的專案經常使用它。

  • microsoft/pyright 作為 VSCode 中用於先前提及的 Pylance 完成的靜態類型檢查器非常重要。

完整靜態類型檢查是所有類型註解應用程式中最嚴格的,因為每當您的類型註解不完全正確時,它都會顯示錯誤。一方面,這很好,因為您的靜態類型分析可能會捕獲錯誤的類型註解 (例如,jnp.ndarray 抽象類別中遺失 DeviceArray 方法的情況)。

另一方面,這種嚴格性可能會使類型檢查過程在經常依賴鴨子類型而不是嚴格類型安全 API 的套件中非常脆弱。您目前會在整個 JAX 程式碼庫中發現數百個地方散佈著類似 #type: ignore (用於 mypy) 或 #pytype: disable (用於 pytype) 的程式碼註解。這些通常表示出現類型問題的情況;它們可能是 JAX 類型註解中的不準確之處,或靜態類型檢查器正確遵循程式碼中控制流程的能力不準確。有時,它們是由於 pytype 或 mypy 行為中真實且微妙的錯誤造成的。在極少數情況下,它們可能是由於 JAX 使用的 Python 模式難以甚至不可能用 Python 的靜態類型註解語法來表達。

JAX 的類型註解挑戰#

JAX 目前的類型註解是不同樣式的混合,旨在滿足上述所有三個層級的類型註解。部分原因在於,JAX 的原始程式碼對 Python 的類型註解系統提出了許多獨特的挑戰。我們將在此處概述它們。

挑戰 1:pytype、mypy 和開發人員摩擦#

JAX 目前面臨的一個挑戰是,套件開發必須滿足兩個不同靜態類型檢查系統的約束,即 pytype (內部 CI 和內部 Google 專案使用) 和 mypy (外部 CI 和外部依賴項使用)。儘管這兩個類型檢查器在其行為方面具有廣泛的重疊,但每個檢查器都呈現出自己獨特的邊角案例,JAX 程式碼庫中無數的 #type: ignore#pytype: disable 語句證明了這一點。

這會在開發中產生摩擦:內部貢獻者可能會迭代直到測試通過,但發現匯出時他們經 pytype 批准的程式碼會違反 mypy。對於外部貢獻者來說,情況通常相反:最近的一個例子是 #9596,它在未能通過 Google 內部 pytype 檢查後不得不回滾。每次我們將類型註解從層級 1 (Any 無處不在) 移至層級 2 或 3 (更嚴格的註解) 時,都會為這種令人沮喪的開發人員體驗帶來更多可能性。

挑戰 2:陣列鴨子類型#

註解 JAX 程式碼的一個特殊挑戰是其大量使用鴨子類型。標記為 Array 的函式的輸入通常可以是多種不同類型之一:JAX DeviceArray、NumPy np.ndarray、NumPy 純量、Python 純量、Python 序列、具有 __array__ 屬性的物件、具有 __jax_array__ 屬性的物件,或任何類型的 jax.Tracer。因此,像 def func(x: DeviceArray) 這樣的簡單註解是不夠的,並且會導致許多有效用途的誤報。這表示 JAX 函式的類型註解不會簡短或瑣碎,但我們必須有效地開發一組 JAX 特定的類型擴充功能,類似於 numpy.typing 套件中的擴充功能。

挑戰 3:轉換和裝飾器#

JAX 的 Python API 非常依賴函式轉換 (jit()vmap()grad() 等),而這種 API 類型對靜態類型分析提出了特殊的挑戰。裝飾器的彈性註解一直是 mypy 套件中的一個 長期存在的問題,直到最近才透過引入 ParamSpec 解決,這在 PEP 612 中討論,並在 Python 3.10 中新增。由於 JAX 遵循 NEP 29,因此在 2024 年年中之後的某個時間之前,它無法依賴 Python 3.10 功能。在此期間,協議可以用作此問題的部分解決方案 (JAX 在 #9950 中為 jit 和其他方法新增了此協議),並且 ParamSpec 可以透過 typing_extensions 套件使用 (原型在 #9999 中),儘管這目前揭示了 mypy 中的基本錯誤 (請參閱 python/mypy#12593)。總之,目前尚不清楚 JAX 函式轉換的 API 是否可以在 Python 類型註解工具的目前限制內進行適當的註解。

挑戰 4:陣列註解缺乏精細度#

此處的另一個挑戰是 Python 中所有面向陣列的 API 都有的共同挑戰,並且多年來一直是 JAX 討論的一部分 (請參閱 #943)。類型註解與物件的 Python 類別或類型有關,而在基於陣列的語言中,類別的屬性通常更重要。在 NumPy、JAX 和類似套件的情況下,我們通常希望註解特定的陣列形狀和資料類型。

例如,jnp.linspace 函式的參數必須是純量值,但在 JAX 中,純量由零維陣列表示。因此,為了使註解不會引發誤報,我們必須允許這些參數為任意陣列。jax.random.choice 的第二個參數是另一個範例,當 shape=() 時,其 dtype=int 必須為 。Python 有一個計畫透過可變類型泛型啟用具有此精細度的類型註解 (請參閱 PEP 646,預定用於 Python 3.11),但與 ParamSpec 一樣,對此功能的支援將需要一段時間才能穩定。

同時,有一些第三方專案可能會有所幫助,特別是 google/jaxtyping,但這使用非標準註解,可能不適合註解核心 JAX 程式庫本身。總而言之,陣列類型精細度挑戰不如其他挑戰那麼重要,因為主要影響是類似陣列的註解將不如它們本來可以的那樣具體。

挑戰 5:從 NumPy 繼承的不精確 API#

JAX 面向使用者的 API 的很大一部分是從 jax.numpy 子模組中的 NumPy 繼承而來。NumPy 的 API 是在靜態類型檢查成為 Python 語言一部分的幾年前開發的,並且遵循 Python 的歷史建議,即使用 鴨子類型/EAFP 程式碼編寫風格,其中不鼓勵在執行階段進行嚴格的類型檢查。作為一個具體的範例,請考慮 numpy.tile() 函式,其定義如下

def tile(A, reps):
  try:
    tup = tuple(reps)
  except TypeError:
    tup = (reps,)
  d = len(tup)
  ...

此處的意圖reps 將包含 intint 值序列,但實作允許 tup 為任何可迭代物件。當為這種鴨子類型程式碼新增註解時,我們可以採取以下兩種途徑之一

  1. 我們可以選擇註解函式 API 的意圖,這裡可能類似於 reps: Union[int, Sequence[int]]

  2. 相反,我們可以選擇註解函式的實作,這裡可能看起來類似於 reps: Union[ConvertibleToInt, Iterable[ConvertibleToInt]],其中 ConvertibleToInt 是一個特殊的協議,涵蓋我們的函式將輸入轉換為整數的確切機制 (即透過 __int__、透過 __index__、透過 __array__ 等)。另請注意,從嚴格意義上講,Iterable 在這裡是不夠的,因為 Python 中有一些物件以鴨子類型充當可迭代物件,但不滿足針對 Iterable 的靜態類型檢查 (即,透過 __getitem__ 而不是 __iter__ 可迭代的物件)。

註解意圖 #1 的優點是註解在傳達 API 合約方面對使用者更有用;而對於開發人員來說,彈性為必要時的重構留下了空間。缺點 (特別是對於像 JAX 這樣逐步輸入的 API) 是很可能存在使用者程式碼可以正確執行,但會被類型檢查器標記為不正確。現有鴨子類型 API 的逐步輸入表示目前的註解隱含為 Any,因此將其變更為更嚴格的類型可能會向使用者顯示為重大變更。

廣泛來說,註解意圖更好地服務於層級 1 類型檢查,而註解實作更好地服務於層級 3,而層級 2 更像是混合體 (當涉及到 IDE 中的註解時,意圖和實作都很重要)。

JAX 類型註解路線圖#

考虑到這種框架 (層級 1/2/3) 和 JAX 特定的挑戰,我們可以開始制定我們的路線圖,以在整個 JAX 專案中實作一致的類型註解。

指導原則#

對於 JAX 類型註解,我們將遵循以下原則

類型註解的目的#

我們希望盡可能支援完整的層級 1、2 和 3 類型註解。特別是,這表示我們應該對公共 API 函式的輸入和輸出都具有限制性類型註解。

為意圖加上註解#

JAX 類型註解通常應指出 API 的意圖,而非實作方式,如此註解才能有效地傳達 API 的合約。這表示有時在執行階段有效的輸入,可能無法被靜態類型檢查器識別為有效(一個例子可能是任意迭代器,被傳遞以取代註解為 Shape = Sequence[int] 的形狀)。

輸入應寬鬆地輸入類型#

JAX 函數和方法的輸入應該盡可能寬鬆地輸入類型:例如,雖然形狀通常是元組,但接受形狀的函數應該接受任意序列。同樣地,接受 dtype 的函數不需要 np.dtype 類別的實例,而是任何可轉換為 dtype 的物件。這可能包括字串、內建純量類型,或純量物件建構子,例如 np.float64jnp.float64。為了使整個套件盡可能統一,我們將新增一個 jax.typing 模組,其中包含常見的類型規範,從廣泛的類別開始,例如

  • ArrayLike 將會是可以隱式轉換為陣列的任何事物的聯合:例如,jax 陣列、numpy 陣列、JAX 追蹤器,以及 python 或 numpy 純量

  • DTypeLike 將會是可以隱式轉換為 dtype 的任何事物的聯合:例如,numpy dtypes、numpy dtype 物件、jax dtype 物件、字串和內建類型。

  • ShapeLike 將會是可以轉換為形狀的任何事物的聯合:例如,整數或類整數物件的序列。

  • 等等。

請注意,這些通常會比 numpy.typing 中使用的等效協定更簡單。例如,在 DTypeLike 的情況下,JAX 不支援結構化 dtype,因此 JAX 可以使用更簡單的實作方式。同樣地,在 ArrayLike 中,JAX 通常不支援使用列表或元組輸入來代替陣列,因此類型定義將比 NumPy 的類似物更簡單。

輸出應嚴格地輸入類型#

相反地,函數和方法的輸出應該盡可能嚴格地輸入類型:例如,對於傳回陣列的 JAX 函數,輸出應該使用類似 jnp.ndarray 而非 ArrayLike 的東西進行註解。傳回 dtype 的函數應始終註解為 np.dtype,而傳回形狀的函數應始終為 Tuple[int] 或嚴格類型化的 NamedShape 等效物。為此,我們將在 jax.typing 中實作上述寬鬆類型的幾個嚴格類型化類似物,即

  • ArrayNDArray(見下文)用於類型註解目的,實際上等同於 Union[Tracer, jnp.ndarray],應用於註解陣列輸出。

  • DTypenp.dtype 的別名,或許也具有表示金鑰類型和 JAX 內部使用的其他泛化的能力。

  • Shape 本質上是 Tuple[int, ...],或許還有一些額外的彈性來考量動態形狀。

  • NamedShapeShape 的擴展,允許在 JAX 內部使用具名形狀。

  • 等等。

我們也將探討是否可以放棄目前 jax.numpy.ndarray 的實作,轉而使 ndarray 成為 Array 或類似物的別名。

傾向於簡潔#

除了在 jax.typing 中收集的常見類型協定外,我們應該傾向於簡潔。我們應避免為傳遞至 API 函數的引數建構過於複雜的協定,而是在 API 的完整類型規範無法簡潔指定的情況下,使用簡單的聯合,例如 Union[simple_type, Any]。這是一種折衷方案,在避免不必要的複雜性的前提下,實現了層級 1 和 2 註解的目標,同時放棄了層級 3。

避免不穩定的類型機制#

為了不增加不必要的開發摩擦(由於內部/外部 CI 的差異),我們希望在使用類型註解建構時保持保守:特別是,對於最近引入的機制,例如 ParamSpec (PEP 612) 和可變類型泛型 (PEP 646),我們希望等到 mypy 和其他工具的支援成熟且穩定後,再依賴它們。

其中一個影響是,目前,當函數被 JAX 轉換(如 jitvmapgrad 等)裝飾時,JAX 將有效地剝除裝飾函數的所有註解。雖然這很遺憾,但在撰寫本文時,mypy 在 ParamSpec 提供的潛在解決方案方面存在一長串不相容性(請參閱 ParamSpec mypy 錯誤追蹤器),因此我們判斷目前尚未準備好在 JAX 中完全採用。我們將在未來一旦此類功能的支援穩定後,重新檢視此問題。

同樣地,目前我們將避免新增 jaxtyping 專案提供的更複雜且更精細的陣列類型註解。這是一個我們可以在未來重新檢視的決定。

Array 類型設計考量#

如上所述,JAX 中陣列的類型註解帶來獨特的挑戰,因為 JAX 廣泛使用鴨子類型,即在 jax 轉換中傳遞和傳回 Tracer 物件以取代實際陣列。這變得越來越令人困惑,因為用於類型註解的物件通常與用於執行階段實例檢查的物件重疊,並且可能對應也可能不對應到相關物件的實際類型層次結構。對於 JAX,我們需要為兩個情境提供鴨子類型化的物件:靜態類型註解執行階段實例檢查

以下討論將假設 jax.Array 是裝置上陣列的執行階段類型,目前情況並非如此,但一旦 #12016 中的工作完成,情況就會如此。

靜態類型註解#

我們需要提供一個可用於鴨子類型化類型註解的物件。假設我們暫時將此物件稱為 ArrayAnnotation,我們需要一個滿足 mypypytype 的解決方案,用於以下情況

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
  assert isinstance(x, core.Tracer)
  return x

這可以透過多種方法實現,例如

  • 使用類型聯合:ArrayAnnotation = Union[Array, Tracer]

  • 建立一個介面檔案,宣告 TracerArray 應被視為 ArrayAnnotation 的子類別。

  • 重組 ArrayTracer,使 ArrayAnnotation 成為兩者的真實基底類別。

執行階段實例檢查#

我們也必須提供一個可用於鴨子類型化執行階段 isinstance 檢查的物件。假設我們暫時將此物件稱為 ArrayInstance,我們需要一個通過以下執行階段檢查的解決方案

def f(x):
  return isinstance(x, ArrayInstance)
x = jnp.array([1, 2, 3])
assert f(x)       # x will be an array
assert jit(f)(x)  # x will be a tracer

同樣地,有幾種機制可以用於此目的

  • 覆寫 type(ArrayInstance).__instancecheck__ 以針對 ArrayTracer 物件傳回 True;這就是目前 jnp.ndarray 的實作方式 (來源)。

  • ArrayInstance 定義為抽象基底類別,並動態地將其註冊到 ArrayTracer

  • 重組 ArrayTracer,使 ArrayInstance 成為 ArrayTracer 兩者的真實基底類別

我們需要做出的決定是 ArrayAnnotationArrayInstance 應該是相同還是不同的物件。這裡有一些先例;例如,在核心 Python 語言規範中,存在 typing.Dicttyping.List 是為了註解,而內建的 dictlist 則用於實例檢查的目的。然而,在較新的 Python 版本中,DictList棄用,轉而使用 dictlist 進行註解和實例檢查。

遵循 NumPy 的領先地位#

在 NumPy 的案例中,np.typing.NDArray 用於類型註解的目的,而 np.ndarray 則用於實例檢查(以及陣列類型身分)的目的。鑑於此,遵循 NumPy 的先例並實作以下內容可能是合理的

  • jax.Array 是裝置上陣列的實際類型。

  • jax.typing.NDArray 是用於鴨子類型化陣列註解的物件。

  • jax.numpy.ndarray 是用於鴨子類型化陣列實例檢查的物件。

對於 NumPy 的資深使用者來說,這可能感覺有些自然,但這種三分法可能會造成混淆:用於實例檢查和註解的選擇並不明確。

統一實例檢查和註解#

另一種方法是透過上述覆寫機制來統一類型檢查和註解。

選項 1:部分統一#

部分統一可能如下所示

  • jax.Array 是裝置上陣列的實際類型。

  • jax.typing.Array 是用於鴨子類型化陣列註解的物件(透過 ArrayTracer 上的 .pyi 介面)。

  • jax.typing.Array 也是用於鴨子類型化實例檢查的物件(透過其 metaclass 中的 __isinstance__ 覆寫)

在此方法中,jax.numpy.ndarray 將成為向後相容性的簡單別名 jax.typing.Array

選項 2:透過覆寫完全統一#

或者,我們可以選擇透過覆寫完全統一

  • jax.Array 是裝置上陣列的實際類型。

  • jax.Array 也是用於鴨子類型化陣列註解的物件(透過 Tracer 上的 .pyi 介面)

  • jax.Array 也是用於鴨子類型化實例檢查的物件(透過其 metaclass 中的 __isinstance__ 覆寫)

在這裡,jax.numpy.ndarray 將成為向後相容性的簡單別名 jax.Array

選項 3:透過類別層次結構完全統一#

最後,我們可以選擇透過重組類別層次結構並以 OOP 物件層次結構取代鴨子類型來完全統一

  • jax.Array 是裝置上陣列的實際類型

  • jax.Array 也是用於陣列類型註解的物件,透過確保 Tracer 繼承自 jax.Array

  • jax.Array 也是用於實例檢查的物件,透過相同的機制

在這裡,jnp.ndarray 可以是 jax.Array 的別名。從 OOP 設計的角度來看,最終方法在某種程度上是最純粹的,但它有些牽強(Tracer Array 嗎?)。

選項 4:透過類別層次結構部分統一#

我們可以透過使 Tracer 和裝置上陣列的類別繼承自共同基底類別,使類別層次結構更合理。因此,例如

  • jax.ArrayTracer 以及裝置上陣列的實際類型的基底類別,裝置上陣列的實際類型可能是 jax._src.ArrayImpl 或類似物。

  • jax.Array 是用於陣列類型註解的物件

  • jax.Array 也是用於實例檢查的物件

在這裡,jnp.ndarray 將是 Array 的別名。從 OOP 的角度來看,這可能更純粹,但與選項 2 和 3 相比,它捨棄了 type(x) is jax.Array 將評估為 True 的概念。

評估#

考量每種潛在方法的整體優點和缺點

  • 從使用者的角度來看,統一方法(選項 2 和 3)可以說是最好的,因為它們消除了記住要使用哪個物件進行實例檢查或註解所涉及的認知負荷:jax.Array 是您需要知道的一切。

  • 然而,選項 2 和 3 都引入了一些奇怪和/或令人困惑的行為。選項 2 依賴於可能令人困惑的實例檢查覆寫,這些覆寫對於在 pybind11 中定義的類別支援不佳。選項 3 要求 Tracer 成為陣列的子類別。這打破了繼承模型,因為它會要求 Tracer 物件攜帶 Array 物件的所有包袱(資料緩衝區、分片、裝置等)。

  • 選項 4 在 OOP 意義上更純粹,並且避免了對典型實例檢查或類型註解行為進行任何覆寫的需求。其權衡是裝置上陣列的實際類型變得獨立(此處為 jax._src.ArrayImpl)。但絕大多數使用者永遠不必直接接觸此私有實作。

這裡有不同的權衡,但在討論之後,我們已決定採用選項 4 作為我們的前進方向。

實作計畫#

為了推進類型註解,我們將執行以下操作

  1. 迭代此 JEP 文件,直到開發人員和利害關係人認同。

  2. 建立一個私有的 jax._src.typing(目前不提供任何公開 API),並在其中放入上述第一層簡單類型

    • 暫時將 Array = Any 作為別名,因為這需要更多思考。

    • ArrayLike:作為正常 jax.numpy 函數輸入的有效類型聯合

    • DType / DTypeLike(注意:numpy 使用駝峰式 DType;為了易於使用,我們應該遵循此慣例)

    • Shape / NamedShape / ShapeLike

    此工作的開端已在 #12300 中完成。

  3. 開始著手建立遵循前一節選項 4 的 jax.Array 基底類別。最初這將在 Python 中定義,並使用目前在 jnp.ndarray 實作中找到的動態註冊機制,以確保 isinstance 檢查的正確行為。每個追蹤器和類陣列類別的 pyi 覆寫將確保類型註解的正確行為。jnp.ndarray 隨後可以成為 jax.Array 的別名

  4. 作為測試,根據上述指南,使用這些新的類型定義來全面註解 jax.lax 中的函數。

  5. 一次一個模組地繼續新增其他註解,重點放在公開 API 函數上。

  6. 同時,開始在 pybind11 中重新實作 jax.Array 基底類別,以便 ArrayImplTracer 可以從其繼承。使用 pyi 定義來確保靜態類型檢查器識別類別的適當屬性。

  7. 一旦 jax.Arrayjax._src.ArrayImpl 完全完成,移除這些臨時 Python 實作。

  8. 當一切都完成後,建立一個公開的 jax.typing 模組,使上述類型可供使用者使用,並提供使用 JAX 的程式碼的註解最佳實務文件。

我們將在 #12049 中追蹤此工作,此 JEP 從中獲得其編號。