變更日誌#

最佳瀏覽方式請點擊這裡。如需實驗性 Pallas API 的特定變更,請參閱Pallas 變更日誌

JAX 遵循基於努力的版本控制;關於此項以及 JAX 的 API 相容性政策的討論,請參閱API 相容性。關於 Python 和 NumPy 版本支援政策,請參閱Python 和 NumPy 版本支援政策

未發布#

jax 0.5.0 (2025 年 1 月 17 日)#

在此版本中,JAX 現在使用基於努力的版本控制。由於此版本對 PRNG 金鑰語意進行了重大變更,可能需要使用者更新其程式碼,因此我們將 JAX 的「meso」版本升級以表示這一點。

  • 重大變更

    • 預設啟用 jax_threefry_partitionable (請參閱更新說明)。

    • 此版本不再支援 Mac x86 wheels。Mac ARM 當然仍然受到支援。如需近期討論,請參閱 https://github.com/jax-ml/jax/discussions/22936。

      促成此決策的兩個主要因素

      • Mac x86 版本(僅限)有許多測試失敗和崩潰。我們寧願不發布版本,也不願發布損壞的版本。

      • Mac x86 硬體已停產,目前開發人員無法輕易取得。因此,即使我們想解決這類問題,也很困難。

      如果社群願意協助支援該平台,我們願意重新新增對 Mac x86 的支援:特別是,在我們再次發布版本之前,我們需要 JAX 測試套件在 Mac x86 上乾淨俐落地通過。

  • 變更

    • NumPy 最低版本現在為 1.25。NumPy 1.25 將維持最低支援版本,直到 2025 年 6 月。

    • SciPy 最低版本現在為 1.11。SciPy 1.11 將維持最低支援版本,直到 2025 年 6 月。

    • jax.numpy.einsum() 現在預設為 optimize='auto' 而非 optimize='optimal'。這避免了在多個引數的情況下,trace-time 呈指數級擴展 (#25214)。

    • jax.numpy.linalg.solve() 不再支援右側的批次 1D 引數。若要在這些情況下恢復先前的行為,請使用 solve(a, b[..., None]).squeeze(-1)

  • 新功能

  • 棄用

    • jax.interpreters.xlaabstractifypytype_aval_mappings 現在已棄用,已被 jax.core 中同名的符號取代。

    • jax.scipy.special.lpmn()jax.scipy.special.lpmn_values() 已棄用,原因是它們在 SciPy v1.15.0 中已棄用。目前沒有計畫用新的 API 取代這些已棄用的函數。

    • jax.extend.ffi 子模組已移至 jax.ffi,先前的匯入路徑已棄用。

  • 刪除

    • jax_enable_memories 旗標已刪除,且該旗標的行為預設為開啟。

    • jax.lib.xla_client 中,先前已棄用的 DeviceXlaRuntimeError 符號已移除;請改用 jax.Devicejax.errors.JaxRuntimeError

    • 在 JAX v0.4.32 中棄用後,jax.experimental.array_api 模組已移除。自該版本以來,jax.numpy 直接支援 array API。

jax 0.4.38 (2024 年 12 月 17 日)#

  • 重大變更

    • XlaExecutable.cost_analysis 現在傳回 dict[str, float] (而不是單一元素的 list[dict[str, float]])。

  • 變更

    • jax.tree.flatten_with_pathjax.tree.map_with_path 已新增為對應 tree_util 函數的捷徑。

  • 棄用

    • 內部 jax.core 命名空間中的許多 API 已棄用。大多數是 no-ops、很少使用,或可以由 jax.extend.core 中同名的 API 取代;請參閱 jax.extend 的文件,以取得有關這些半公開擴充功能的相容性保證資訊。

    • 已移除數個先前已棄用的 API,包括

      • jax.corecheck_eqncheck_typecheck_valid_jaxtypenon_negative_dim

      • jax.lib.xla_bridgexla_clientdefault_backend

      • jax.lib.xla_client_xlabfloat16

      • jax.numpyround_

  • 新功能

jax 0.4.37 (2024 年 12 月 9 日)#

這是 jax 0.4.36 的修補程式版本。此版本僅發布「jax」。

jax 0.4.37#

  • 錯誤修復

    • 修正了如果引數命名為 f 時,jit 會發生錯誤的錯誤 (#25329)。

    • 修正了如果使用者為 flatten 和 flatten_with_path 註冊具有不同輔助資料的 pytree 節點類別,則在 jax.lax.while_loop() 中會擲回 index out of range 錯誤的錯誤。

    • 釘選了新的 libtpu 版本 (0.0.6),修正了 TPU v6e 上的編譯器錯誤。

jax 0.4.36 (2024 年 12 月 5 日)#

  • 重大變更

    • 此版本採用了「stackless」,這是 JAX 追蹤機制的內部變更。我們使追蹤分派純粹成為上下文的函數,而不是上下文和資料的函數。這讓我們刪除了大量用於管理資料相依追蹤的機制:levels、sublevels、post_process_callnew_base_maincustom_bind 等等。此變更應僅影響使用 JAX 內部機制的使用者。

      如果您使用了 JAX 內部元件,則可能需要更新您的程式碼 (請參閱 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f 以取得如何操作的線索)。使用 JAX 函式庫的內部元件也可能存在版本偏差問題。如果您發現此變更破壞了您未使用 JAX 內部元件的程式碼,請嘗試使用 config.jax_data_dependent_tracing_fallback 旗標作為權宜之計,如果您在更新程式碼時需要協助,請提交錯誤報告。

    • 自 2024 年 7 月起,JAX 版本 0.4.31 已棄用搭配 native_serialization=Falseenable_xla=Falsejax.experimental.jax2tf.convert()。現在我們已移除對這些使用案例的支援。jax2tf 搭配原生序列化仍將受到支援。

    • jax.interpreters.xla 中,xbxcxe 符號在 JAX v0.4.31 中被棄用後已移除。請改用 xb = jax.lib.xla_bridgexc = jax.lib.xla_clientxe = jax.lib.xla_extension

    • 已移除已棄用的模組 jax.experimental.export。它在 JAX v0.4.30 中被 jax.export 取代。請參閱遷移指南,以取得有關遷移至新 API 的資訊。

    • 已移除 jax.nn.softmax()jax.nn.log_softmax()initial 參數,該參數已在 v0.4.27 中棄用。

    • 現在對類型化的 PRNG 金鑰 (即由 :func:jax.random.key 產生的金鑰) 呼叫 np.asarray 會引發錯誤。 之前,這會傳回純量物件陣列。

    • 已移除 jax.export 中以下已棄用的方法和函式

      • jax.export.DisabledSafetyCheck.shape_assertions:它已經沒有任何作用。

      • jax.export.Exported.lowering_platforms:請改用 platforms

      • jax.export.Exported.mlir_module_serialization_version:請改用 calling_convention_version

      • jax.export.Exported.uses_shape_polymorphism:請改用 uses_global_constants

      • 用於 jax.export.export()lowering_platforms 關鍵字參數:請改用 platforms

    • 已移除 jax.export.symbolic_args_specs() 中的關鍵字參數 symbolic_scopesymbolic_constraints。它們已於 2024 年 6 月棄用。請改用 scopeconstraints

    • 追蹤器 (tracer) 的雜湊 (hashing) 功能已在 0.4.30 版本中棄用,現在會導致 TypeError

    • 重構:JAX 建置 CLI (build/build.py) 現在使用子命令結構,並取代先前的 build.py 用法。執行 python build/build.py --help 以取得更多詳細資訊。新子命令選項的簡要概述

      • build:建置 JAX wheel 套件。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt

      • requirements_update:更新 requirements_lock.txt 檔案。

    • jax.scipy.linalg.toeplitz() 現在對多維輸入執行隱含的批次處理 (batching)。若要恢復先前的行為,您可以對函式輸入呼叫 jax.numpy.ravel()

    • jax.scipy.special.gamma()jax.scipy.special.gammasgn() 現在針對負整數輸入傳回 NaN,以符合 SciPy 的行為 (來自 https://github.com/scipy/scipy/pull/21827)。

    • 在 v0.4.26 中棄用後,jax.clear_backends 已移除。

    • 我們從保證匯出穩定性的自訂呼叫清單中移除了自訂呼叫 “__gpu$xla.gpu.triton”。這是因為此自訂呼叫依賴於 Triton IR,而 Triton IR 並不保證穩定。如果您需要匯出使用此自訂呼叫的程式碼,您可以使用 disabled_checks 參數。 請參閱文件以取得更多詳細資訊。

  • 新功能

  • 錯誤修復

    • 修正了 GPU 實作的 LU 和 QR 分解在批次大小接近 int32 最大值時會導致索引溢位的錯誤。請參閱 #24843 以取得更多詳細資訊。

  • 棄用

    • jax.lib.xla_extension.ArrayImpljax.lib.xla_client.ArrayImpl 已棄用;請改用 jax.Array

    • jax.lib.xla_extension.XlaRuntimeError 已棄用;請改用 jax.errors.JaxRuntimeError

jax 0.4.35 (2024 年 10 月 22 日)#

  • 重大變更

    • jax.numpy.isscalar() 現在針對任何零維的類陣列物件傳回 True。先前,它僅針對具有弱 dtype 的零維類陣列物件傳回 True。

    • 自 2024 年 3 月起,JAX 版本 0.4.26 已棄用 jax.experimental.host_callback。現在我們已將其移除。請參閱 #20385 以取得替代方案的討論。

  • 變更

    • jax.lax.FftType 作為 FFT 運算的列舉公開名稱引入。jax.lib.xla_client.FftType 這個半公開 API 已棄用。

    • TPU:JAX 現在從 libtpu 套件而不是 libtpu-nightly 套件安裝 TPU 支援。在接下來的幾個版本中,JAX 將釘住 libtpu-nightlylibtpu 的空版本,以簡化過渡;該依賴性將在 2025 年第一季移除。

  • 棄用

    • jax.lib.xla_client.PaddingType 這個半公開 API 已棄用。沒有 JAX API 使用此類型,因此沒有替代方案。

    • vmap 下,jax.pure_callback()jax.extend.ffi.ffi_call() 的預設行為已被棄用,並且這些函式的 vectorized 參數也被棄用。應改用 vmap_method 參數以獲得更明確定義的行為。請參閱 #23881 中的討論以取得更多詳細資訊。

    • jax.lib.xla_client.register_custom_call_target 這個半公開 API 已棄用。請改用 JAX FFI。

    • jax.lib.xla_client.dtype_to_etypejax.lib.xla_client.opsjax.lib.xla_client.shape_from_pyvaljax.lib.xla_client.PrimitiveTypejax.lib.xla_client.Shapejax.lib.xla_client.XlaBuilderjax.lib.xla_client.XlaComputation 這些半公開 API 已棄用。請改用 StableHLO。

jax 0.4.34 (2024 年 10 月 4 日)#

  • 新功能

    • 此版本包含適用於 Python 3.13 的 wheel 檔案。目前尚不支援自由執行緒模式。

    • 已新增 jax.errors.JaxRuntimeError 作為先前私有的 XlaRuntimeError 類型的公開別名。

  • 重大變更

    • jax_pmap_no_rank_reduction 旗標預設設定為 True

      • 現在對 pmap 結果執行 array[0] 會引入 reshape (請改用 array[0:1])。

      • 每個分片形狀 (可透過 jax_array.addressable_shards 或 jax_array.addressable_data(0) 存取) 現在具有前導 (1, …)。請相應地更新直接存取分片的程式碼。每個分片形狀的秩 (rank) 現在與全域形狀的秩相符,這與 jit 的行為相同。這避免了將結果從 pmap 傳遞到 jit 時產生高成本的 reshape。

    • 自 2024 年 3 月起,JAX 版本 0.4.26 已棄用 jax.experimental.host_callback。現在我們將 --jax_host_callback_legacy 組態值的預設值設定為 True,這表示如果您的程式碼使用 jax.experimental.host_callback API,則這些 API 呼叫將以新的 jax.experimental.io_callback API 實作。如果這破壞了您的程式碼,在非常有限的時間內,您可以將 --jax_host_callback_legacy 設定為 True。我們很快就會移除該組態選項,因此您應該改為轉換為使用新的 JAX 回呼 API。請參閱 #20385 以取得討論。

  • 棄用

    • jax.numpy.trim_zeros() 中,非類陣列引數或 ndim != 1 的類陣列引數現在已棄用,並且在未來將導致錯誤。

    • 在 JAX v0.4.30 中棄用後,已移除內部美化列印工具 jax.core.pp_*

    • jax.lib.xla_client.Device 已棄用;請改用 jax.Device

    • jax.lib.xla_client.XlaRuntimeError 已棄用。請改用 jax.errors.JaxRuntimeError

  • 刪除

    • 已刪除 jax.xla_computation。自 0.4.30 JAX 版本中棄用以來已 3 個月。請使用 AOT API 以取得與 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以替換為 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您也可以使用 jax.stages.Lowered.out_info 屬性來取得輸出資訊 (例如樹狀結構、形狀和 dtype)。

      • 對於跨後端降低 (lowering),您可以將 jax.xla_computation(fn, backend='tpu')(*args, **kwargs) 替換為 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

    • jax.ShapeDtypeStruct 不再接受 named_shape 參數。該參數僅由 xmap 使用,而 xmap 已在 0.4.31 中移除。

    • jax.tree.map(f, None, non-None) 先前發出 DeprecationWarning,現在在未來版本的 jax 中會引發錯誤。None 僅是其自身的樹狀結構前綴 (tree-prefix)。若要保留目前的行為,您可以要求 jax.tree.mapNone 視為葉節點值,方法是寫入:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

    • jax.sharding.XLACompatibleSharding 已移除。請使用 jax.sharding.Sharding

  • 錯誤修復

    • 修正了如果提供非布林值輸入並指定 dtype=booljax.numpy.cumsum() 會產生不正確輸出的錯誤。

    • 編輯 jax.numpy.ldexp() 的實作以取得正確的梯度。

jax 0.4.33 (2024 年 9 月 16 日)#

這是 jax 0.4.32 之上的修補程式版本,修正了該版本中發現的兩個錯誤。

在 JAX 0.4.32 釘選的 libtpu 版本中發現了一個僅限於 TPU 的資料損壞錯誤,該錯誤僅在同一個作業中存在多個 TPU 切片時才會顯現,例如,如果在多個 v5e 切片上進行訓練。此版本透過釘選 libtpu 的固定版本來修正該問題。

此版本修正了 CPU 上 F64 tanh 的不準確結果 (#23590)。

jax 0.4.32 (2024 年 9 月 11 日)#

注意:此版本已從 PyPi 撤下,因為 TPU 上存在資料損壞錯誤。請參閱 0.4.33 版本注意事項以取得更多詳細資訊。

  • 新功能

  • 變更

    • jax_enable_memories 旗標預設設定為 True

    • jax.numpy 現在支援 Python Array API Standard 的 v2023.12 版本。請參閱 Python Array API 標準 以取得更多資訊。

    • 在更多情況下,CPU 後端上的計算現在可以非同步分派。先前,非平行計算始終同步分派。您可以透過設定 jax.config.update('jax_cpu_enable_async_dispatch', False) 來恢復舊的行為。

    • 新增了新的 jax.process_indices() 函式,以取代在 JAX v0.2.13 中棄用的 jax.host_ids() 函式。

    • 為了與 numpy.fabs 的行為對齊,jax.numpy.fabs 已修改為不再支援 complex dtypes

    • 如果 nodetype 是 dataclass,則 jax.tree_util.register_dataclass 現在會檢查 data_fieldsmeta_fields 是否包含所有具有 init=True 的 dataclass 欄位,並且僅包含這些欄位。

    • 多個 jax.numpy 函式現在具有完整的 ufunc 介面,包括 addmultiplybitwise_andbitwise_orbitwise_xorlogical_andlogical_andlogical_and。在未來的版本中,我們計劃將這些擴展到其他 ufunc。

    • 新增了 jax.lax.optimization_barrier(),允許使用者防止編譯器最佳化 (例如常見子表達式消除) 並控制排程。

  • 重大變更

    • MHLO MLIR 方言 (jax.extend.mlir.mhlo) 已移除。請改用 stablehlo 方言。

  • 棄用

    • 在 JAX v0.4.27 中棄用後,不再允許 jax.numpy.clip()jax.numpy.hypot() 的複數輸入。

    • 已棄用以下 API

      • jax.lib.xla_bridge.xla_client:請直接使用 jax.lib.xla_client

      • jax.lib.xla_bridge.get_backend:請使用 jax.extend.backend.get_backend()

      • jax.lib.xla_bridge.default_backend:請使用 jax.extend.backend.default_backend()

    • jax.experimental.array_api 模組已棄用,不再需要匯入它才能使用 Array API。jax.numpy 直接支援 array API;請參閱 Python Array API 標準 以取得更多資訊。

    • 內部工具程式 jax.core.check_eqnjax.core.check_typejax.core.check_valid_jaxtype 現在已棄用,並將在未來版本中移除。

    • 在 NumPy 2.0 中移除對應的 API 後,jax.numpy.round_ 已棄用。請改用 jax.numpy.round()

    • 將 DLPack capsule 傳遞給 jax.dlpack.from_dlpack() 已棄用。jax.dlpack.from_dlpack() 的引數應該是來自另一個實作 __dlpack__ 協定的框架的陣列。

jaxlib 0.4.32 (2024 年 9 月 11 日)#

注意:此版本已從 PyPi 撤下,因為 TPU 上存在資料損壞錯誤。請參閱 0.4.33 版本注意事項以取得更多詳細資訊。

  • 重大變更

    • 此版本的 jaxlib 切換到新的 CPU 後端版本,該版本應能更快地編譯並更好地利用平行處理。如果您因本次變更而遇到任何問題,可以暫時啟用舊的 CPU 後端,方法是設定環境變數 XLA_FLAGS=--xla_cpu_use_thunk_runtime=false。如果您需要執行此操作,請提交 JAX 錯誤報告,並附上重現步驟。

    • 新增了 Hermetic CUDA 支援。Hermetic CUDA 使用特定的可下載 CUDA 版本,而不是使用者本機安裝的 CUDA。 Bazel 將下載 CUDA、CUDNN 和 NCCL 發行版本,然後在各種 Bazel 目標中使用 CUDA 函式庫和工具作為依賴項。這為 JAX 及其支援的 CUDA 版本實現了更可重現的建置。

  • 變更

    • 新增了 SparseCore 分析功能。

      • JAX 現在支援在 TPUv5p 晶片上分析 SparseCore。這些追蹤將可在 Tensorboard Profiler 的 TraceViewer 中檢視。

jax 0.4.31 (2024 年 7 月 29 日)#

  • 刪除

    • xmap 已刪除。請使用 shard_map() 作為替代方案。

  • 變更

    • 最低 CuDNN 版本為 v9.1。這在先前的版本中也是如此,但我們現在正式宣告此版本約束。

    • 最低 Python 版本現在為 3.10。3.10 將保持最低支援版本直到 2025 年 7 月。

    • 最低 NumPy 版本現在為 1.24。NumPy 1.24 將保持最低支援版本直到 2024 年 12 月。

    • 最低 SciPy 版本現在為 1.10。SciPy 1.10 將保持最低支援版本直到 2025 年 1 月。

    • 現在 jax.numpy.ceil()jax.numpy.floor()jax.numpy.trunc() 會回傳與輸入相同資料類型 (dtype) 的輸出,也就是說,不再將整數或布林值輸入向上轉型 (upcast) 為浮點數。

    • libdevice.10.bc 不再與 CUDA wheels 捆綁。它必須作為本地 CUDA 安裝的一部分安裝,或透過 NVIDIA 的 CUDA pip wheels 安裝。

    • 現在 jax.experimental.pallas.BlockSpec 預期 block_shapeindex_map 之前 傳遞。舊的參數順序已被棄用,並將在未來版本中移除。

    • 更新了 GPU 裝置的 repr 表示方式,使其與 TPU/CPU 更一致。例如,cuda(id=0) 現在會顯示為 CudaDevice(id=0)

    • jax.Array 新增了 device 屬性和 to_device 方法,作為 JAX Array API 支援的一部分。

  • 棄用

    • 移除了許多先前已棄用的與多型形狀 (polymorphic shapes) 相關的內部 API。從 jax.core 中:移除了 canonicalize_shapedimension_as_valuedefinitely_equalsymbolic_equal_dim

    • HLO lowering 規則不應再將 singleton ir.Values 包裹在元組中。而是回傳未包裹的 singleton ir.Values。未來版本的 JAX 將移除對包裹值的支援。

    • 使用 native_serialization=Falseenable_xla=Falsejax.experimental.jax2tf.convert() 現在已被棄用,此支援將在未來版本中移除。自 JAX 0.4.16 (2023 年 9 月) 以來,原生序列化一直是預設設定。

    • 先前已棄用的函式 jax.random.shuffle 已被移除;請改用 jax.random.permutation 並搭配 independent=True

jaxlib 0.4.31 (2024 年 7 月 29 日)#

  • 錯誤修復

    • 修正了一個錯誤,該錯誤導致 jit 調度快速路徑錯誤處理了 jit 的負 static_argnums。

    • 修正了一個錯誤,該錯誤導致奇異矩陣批次的三角解產生無意義的有限值,而不是 inf 或 nan (#3589, #15429)。

jax 0.4.30 (2024 年 6 月 18 日)#

  • 變更

    • JAX 支援 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已升級至 0.4.0,但在此版本中已回退,以便讓 TensorFlow 和 JAX 的使用者有更多時間遷移到較新的 TensorFlow 版本。

    • jax.experimental.mesh_utils 現在可以為 TPU v5e 建立有效率的網格 (mesh)。

    • jax 現在直接依賴 jaxlib。此變更由 CUDA 外掛程式切換啟用:不再有多個 jaxlib 變體。您可以使用 pip install jax 安裝僅限 CPU 的 jax,無需額外套件。

    • 新增了用於匯出和序列化 JAX 函式的 API。此功能以前存在於 jax.experimental.export (正在棄用),現在將位於 jax.export 中。請參閱文件

  • 棄用

    • 內部美化列印 (pretty-printing) 工具 jax.core.pp_* 已被棄用,並將在未來版本中移除。

    • tracer 的雜湊 (Hashing) 已被棄用,並將在未來 JAX 版本中導致 TypeError。先前情況如此,但在最近幾個 JAX 版本中出現了無意的回歸。

    • jax.experimental.export 已被棄用。請改用 jax.export。請參閱遷移指南

    • 在大多數情況下,現在不建議使用陣列來代替 dtype 傳遞;例如,對於陣列 xyx.astype(y) 將引發警告。若要靜音警告,請使用 x.astype(y.dtype)

    • jax.xla_computation 已被棄用,並將在未來版本中移除。請使用 AOT API 來取得與 jax.xla_computation 相同的功能。

      • jax.xla_computation(fn)(*args, **kwargs) 可以替換為 jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')

      • 您也可以使用 jax.stages.Lowered.out_info 屬性來取得輸出資訊 (例如樹狀結構、形狀和 dtype)。

      • 對於跨後端降低 (lowering),您可以將 jax.xla_computation(fn, backend='tpu')(*args, **kwargs) 替換為 jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')

jaxlib 0.4.30 (2024 年 6 月 18 日)#

  • 已移除對 monolithic CUDA jaxlibs 的支援。您必須使用基於外掛程式的安裝 (pip install jax[cuda12]pip install jax[cuda12_local])。

jax 0.4.29 (2024 年 6 月 10 日)#

  • 變更

    • 我們預期這將是 JAX 和 jaxlib 最後一個支援 monolithic CUDA jaxlib 的版本。未來版本將使用 CUDA 外掛程式 jaxlib (例如 pip install jax[cuda12])。

    • JAX 現在需要 ml_dtypes 版本 0.4.0 或更新版本。

    • 移除了對舊版 jax.experimental.export API 用法的向後相容性支援。不再可能使用 from jax.experimental.export import export,您應該改用 from jax.experimental import export。移除的功能自 0.4.24 以來已被棄用。

    • jax.tree.all()jax.tree_util.tree_all() 新增了 is_leaf 參數。

  • 棄用

    • jax.sharding.XLACompatibleSharding 已被棄用。請使用 jax.sharding.Sharding

    • jax.experimental.Exported.in_shardings 已重新命名為 jax.experimental.Exported.in_shardings_hloout_shardings 也做了相同的更名。舊名稱將在 3 個月後移除。

    • 移除了許多先前已棄用的 API

      • jax.core 中:non_negative_dimDimSizeShape

      • jax.lax 中:tie_in

      • jax.nn 中:normalize

      • jax.interpreters.xla 中:backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXlaOp

    • jax.numpy.linalg.matrix_rank()tol 參數已被棄用,並將很快移除。請改用 rtol

    • jax.numpy.linalg.pinv()rcond 參數已被棄用,並將很快移除。請改用 rtol

    • 已移除已棄用的 jax.config 子模組。若要設定 JAX,請使用 import jax,然後透過 jax.config 參考 config 物件。

    • jax.random API 不再接受批次處理的金鑰 (keys),儘管先前有些 API 無意中接受了。展望未來,我們建議在這種情況下明確使用 jax.vmap()

    • jax.scipy.special.beta() 中,為了與其他 beta API 一致,xy 參數已重新命名為 ab

  • 新功能

    • 新增了 jax.experimental.Exported.in_shardings_jax(),以從儲存在 Exported 物件中的 HloShardings 建構可用於 JAX API 的分片 (shardings)。

jaxlib 0.4.29 (2024 年 6 月 10 日)#

  • 錯誤修復

    • 修正了一個錯誤,該錯誤導致 XLA 錯誤地分片了一些串聯 (concatenation) 操作,這表現為累積歸約 (cumulative reductions) 的輸出不正確 (#21403)。

    • 修正了一個錯誤,該錯誤導致 XLA:CPU 錯誤編譯了某些矩陣乘法融合 (matmul fusions) (https://github.com/openxla/xla/pull/13301)。

    • 修正了 GPU 上的編譯器崩潰問題 (https://github.com/jax-ml/jax/issues/21396)。

  • 棄用

    • jax.tree.map(f, None, non-None) 現在會發出 DeprecationWarning,並將在未來版本的 jax 中引發錯誤。None 僅是其自身的樹狀結構前綴 (tree-prefix)。若要保留目前行為,您可以要求 jax.tree.mapNone 視為葉值,方法是寫入:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)

jax 0.4.28 (2024 年 5 月 9 日)#

  • 錯誤修復

    • 還原了對 make_jaxpr 的變更,該變更破壞了 Equinox (#21116)。

  • 棄用與移除

  • 變更

    • 此版本的最低 jaxlib 版本為 0.4.27。

jaxlib 0.4.28 (2024 年 5 月 9 日)#

  • 錯誤修復

    • 修正了 Python 3.10 或更早版本中 Array 和 JIT Python 物件的類型名稱中的記憶體損壞錯誤。

    • 修正了 CUDA 12.4 下的警告 '+ptx84' is not a recognized feature for this target

    • 修正了 CPU 上編譯速度緩慢的問題。

  • 變更

    • Windows 版本現在使用 Clang 而非 MSVC 建置。

jax 0.4.27 (2024 年 5 月 7 日)#

  • 新功能

    • 新增了 jax.numpy.unstack()jax.numpy.cumulative_sum(),遵循它們在 array API 2023 標準中的新增,該標準即將被 NumPy 採用。

    • 新增了一個新的 config 選項 jax_cpu_collectives_implementation,以選擇 CPU 後端使用的跨進程集體運算 (cross-process collective operations) 的實作方式。可用的選項包括 'none' (預設)、'gloo''mpi' (需要 jaxlib 0.4.26)。如果設定為 'none',則會停用跨進程集體運算。

  • 變更

    • 現在 jax.pure_callback()jax.experimental.io_callback()jax.debug.callback() 使用 jax.Array 而非 np.ndarray。您可以透過在將引數傳遞給 callback 之前,透過 jax.tree.map(np.asarray, args) 轉換引數來恢復舊行為。

    • 現在 complex_arr.astype(bool) 遵循與 NumPy 相同的語義,當 complex_arr 等於 0 + 0j 時回傳 False,否則回傳 True。

    • 現在 core.Token 是一個非平凡類別,它包裹了 jax.Array。它可以被建立並在計算中傳入和傳出,以建立依賴關係。singleton 物件 core.token 已被移除,使用者現在應該建立並使用新的 core.Token 物件。

    • 在 GPU 上,Threefry PRNG 實作預設不再降低到核心呼叫 (kernel call)。此選擇可以改善執行時記憶體使用量,但會犧牲編譯時間成本。可以使用 jax.config.update('jax_threefry_gpu_kernel_lowering', True) 恢復產生核心呼叫的先前行為。如果新的預設設定導致問題,請提交錯誤報告。否則,我們打算在未來版本中移除此旗標。

  • 棄用與移除

    • Pallas 現在專門使用 XLA 在 GPU 上編譯核心。透過 Triton Python API 的舊 lowering pass 已被移除,JAX_TRITON_COMPILE_VIA_XLA 環境變數不再有任何作用。

    • jax.numpy.clip() 有一個新的引數簽章:aa_mina_max 已被棄用,改用 x (僅限位置引數)、minmax (#20550)。

    • JAX 陣列的 device() 方法已移除,自 JAX v0.4.21 以來已被棄用。請改用 arr.devices()

    • 已棄用 jax.nn.softmax()jax.nn.log_softmax()initial 參數;現在支援 softmax 的空輸入,而無需設定此參數。

    • jax.jit() 中,傳遞無效的 static_argnumsstatic_argnames 現在會導致錯誤,而不是警告。

    • 現在最低 jaxlib 版本為 0.4.23。

    • 當傳遞複數值輸入給 jax.numpy.hypot() 函式時,現在會發出棄用警告。當棄用完成時,這將引發錯誤。

    • 現在 jax.numpy.nonzero()jax.numpy.where() 和相關函式的純量引數會引發錯誤,這與 NumPy 中的類似變更一致。

    • config 選項 jax_cpu_enable_gloo_collectives 已被棄用。請改用 jax.config.update('jax_cpu_collectives_implementation', 'gloo')

    • 在 JAX v0.4.22 中棄用後,已移除 jax.Array.device_bufferjax.Array.device_buffers 方法。請改用 jax.Array.addressable_shardsjax.Array.addressable_data()

    • 現在 jax.numpy.whereconditionxy 參數僅限於位置引數,這與 JAX v0.4.21 中棄用關鍵字參數一致。

    • 現在必須透過關鍵字指定 jax.lax.linalg 中函式的非陣列引數。先前,這會引發 DeprecationWarning。

    • 現在在幾個 :func:jax.numpy API 中需要類陣列引數,包括 apply_along_axis()apply_over_axes()inner()outer()cross()kron()lexsort()

  • 錯誤修復

    • copy=True 時,jax.numpy.astype() 現在始終會回傳副本。先前,當輸出陣列與輸入陣列具有相同的 dtype 時,不會建立副本。這可能會導致一些記憶體使用量增加。預設值設定為 copy=False,以保留向後相容性。

jaxlib 0.4.27 (2024 年 5 月 7 日)#

jax 0.4.26 (2024 年 4 月 3 日)#

  • 新功能

  • 變更

    • 現在複數值 jax.numpy.geomspace() 選擇與 NumPy 2.0 一致的對數螺旋分支 (logarithmic spiral branch)。

    • lax.rng_bit_generator 的行為,進而 'rbg''unsafe_rbg' PRNG 實作在 jax.vmap 下的行為 已變更,因此對金鑰進行映射 (mapping) 只會從批次中的第一個金鑰產生隨機數。

    • 文件現在使用 jax.random.key 來建構 PRNG 金鑰陣列,而不是 jax.random.PRNGKey

  • 棄用與移除

    • jax.tree_map() 已被棄用;請改用 jax.tree.map,或為了與舊版 JAX 向後相容,請使用 jax.tree_util.tree_map()

    • jax.clear_backends() 已被棄用,因為它不一定會執行其名稱所暗示的操作,並且可能會導致意想不到的後果,例如,它不會銷毀現有的後端並釋放相應的擁有資源。如果您只想清理編譯快取,請使用 jax.clear_caches()。為了向後相容性,或者如果您真的需要切換/重新初始化預設後端,請使用 jax.extend.backend.clear_backends()

    • jax.experimental.maps 模組和 jax.experimental.maps.xmap 已被棄用。請使用 jax.experimental.shard_mapjax.vmap 以及 spmd_axis_name 引數來表達 SPMD 裝置平行計算。

    • jax.experimental.host_callback 模組已被棄用。請改用新的 JAX 外部 callback。新增了 JAX_HOST_CALLBACK_LEGACY 旗標,以協助轉換到新的 callback。請參閱 #20385 以進行討論。

    • 現在將無法轉換為 JAX 陣列的引數傳遞給 jax.numpy.array_equal()jax.numpy.array_equiv() 會導致例外。

    • 已移除已棄用的旗標 jax_parallel_functions_output_gda。此旗標早已被棄用,且沒有任何作用;使用它是空操作 (no-op)。

    • 先前已棄用的匯入 jax.interpreters.ad.configjax.interpreters.ad.source_info_util 現在已被移除。請改用 jax.configjax.extend.source_info_util

    • JAX 匯出不再支援舊版的序列化版本。自 2023 年 10 月 27 日起已支援版本 9,並自 2024 年 2 月 1 日起成為預設版本。請參閱 版本說明。此變更可能會破壞設定低於 9 的特定 JAX 序列化版本的用戶端。

jaxlib 0.4.26 (2024 年 4 月 3 日)#

  • 變更

    • JAX 現在僅支援 CUDA 12.1 或更新版本。已停止支援 CUDA 11.8。

    • JAX 現在支援 NumPy 2.0。

jax 0.4.25 (2024 年 2 月 26 日)#

  • 新功能

  • 變更

    • Pallas 現在使用 XLA 而非 Triton Python API 來編譯 Triton 核心。您可以將 JAX_TRITON_COMPILE_VIA_XLA 環境變數設定為 "0",以還原為舊的行為。

    • 在 v0.4.24 版本中移除的 jax.interpreters.xla 中數個已棄用的 API,已在 v0.4.25 版本中重新加入,包括 backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextXLAOp。這些 API 仍被視為已棄用,並將在未來有更好的替代方案時再次移除。請參閱 #19816 以了解更多討論。

  • 棄用與移除

    • jax.numpy.linalg.solve() 現在針對 b.ndim > 1 的批次 1D 求解顯示棄用警告。未來,這些將被視為批次 2D 求解。

    • 將非純量陣列轉換為 Python 純量現在會引發錯誤,無論陣列大小為何。先前,對於大小為 1 的非純量陣列,會引發棄用警告。這遵循 NumPy 中類似的棄用。

    • 先前已棄用的組態 API 已按照標準的 3 個月棄用週期移除(請參閱 API 相容性)。這些包括

      • jax.config.config 物件,以及

      • jax.configdefine_*_stateDEFINE_* 方法。

    • 透過 import jax.config 匯入 jax.config 子模組已被棄用。若要設定 JAX,請使用 import jax,然後透過 jax.config 參考組態物件。

    • 最低 jaxlib 版本現在為 0.4.20。

jaxlib 0.4.25 (2024 年 2 月 26 日)#

jax 0.4.24 (2024 年 2 月 6 日)#

  • 變更

    • JAX 降級至 StableHLO 不再依賴實體裝置。如果您的基本運算包裝了自訂分割 (custom_partitioning) 或 JAX 回呼 (callbacks) 在降級規則中,即傳遞給 mlir.register_loweringrule 參數的函式,則將您的基本運算新增至 jax._src.dispatch.prim_requires_devices_during_lowering 集合。這是必要的,因為自訂分割和 JAX 回呼需要實體裝置才能在降級期間建立 Sharding。這是一個暫時狀態,直到我們可以在沒有實體裝置的情況下建立 Sharding

    • jax.numpy.argsort()jax.numpy.sort() 現在支援 stabledescending 引數。

    • 形狀多型 (shape polymorphism) 處理的一些變更(用於 jax.experimental.jax2tfjax.experimental.export 中)

      • 更清晰的符號運算式美觀列印 (#19227)

      • 新增了在維度變數上指定符號約束的功能。這使得形狀多型更具表現力,並提供了一種方法來解決關於不等式推理的限制。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 隨著符號約束的加入 (#19235),我們現在認為來自不同範圍的維度變數是不同的,即使它們具有相同的名稱。來自不同範圍的符號運算式無法互動,例如在算術運算中。jax.experimental.jax2tf.convert()jax.experimental.export.symbolic_shape()jax.experimental.export.symbolic_args_specs() 引入了範圍。符號運算式 e 的範圍可以使用 e.scope 讀取,並傳遞到上述函式中,以指示它們在給定範圍內建構符號運算式。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。

      • 簡化且更快速的相等性比較,當我們考慮兩個符號維度的差的正規化形式簡化為 0 時,則認為它們相等 (#19231;請注意,這可能會導致使用者可見的行為變更)

      • 改進了不確定不等式比較的錯誤訊息 (#19235)。

      • 已棄用 core.non_negative_dim API (最近引入),並引入了 core.max_dimcore.min_dim (#18953) 以表示符號維度的 maxmin。您可以使用 core.max_dim(d, 0) 來代替 core.non_negative_dim(d)

      • 已棄用 shape_poly.is_poly_dim,改用 export.is_symbolic_dim (#19282)。

      • 已棄用 export.args_specs,改用 export.symbolic_args_specs ({jax-issue}#19283`)。

      • 已棄用 shape_poly.PolyShapejax2tf.PolyShape,針對多型形狀規格使用字串 (#19284)。

      • JAX 預設原生序列化版本現在為 9。這與 jax.experimental.jax2tfjax.experimental.export 相關。請參閱 版本號碼描述

    • 重構了 jax.experimental.export 的 API。現在您應該使用 from jax.experimental import export,而不是 from jax.experimental.export import export。舊的匯入方式在 3 個月的棄用期內將繼續運作。

    • 新增了 jax.scipy.stats.sem()

    • 具有 return_inverse = Truejax.numpy.unique() 會傳回重新塑形為輸入維度的反向索引,這遵循 NumPy 2.0 中 numpy.unique() 的類似變更。

    • jax.numpy.sign() 現在針對非零複數輸入傳回 x / abs(x)。這與 NumPy 2.0 版本中 numpy.sign() 的行為一致。

    • 具有 return_sign=Truejax.scipy.special.logsumexp() 現在針對複數符號使用 NumPy 2.0 慣例 x / abs(x)。這與 SciPy v1.13 中 scipy.special.logsumexp() 的行為一致。

    • JAX 現在支援匯入和匯出布林值 DLPack 類型。先前,布林值無法匯入,且會匯出為整數。

  • 棄用與移除

    • 許多先前已棄用的函式已按照標準的 3+ 個月棄用週期移除(請參閱 API 相容性)。這包括

      • 來自 jax.core 的:TracerArrayConversionErrorTracerIntegerConversionErrorUnexpectedTracerErroras_hashable_functioncollectionsdtypeslumapnamedtuplepartialpprefsafe_zipsafe_mapsource_info_utiltotal_orderingtraceback_utiltuple_deletetuple_insertzip

      • 來自 jax.lax 的:dtypesitertoolsnaryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule

      • jax.linear_util 子模組及其所有內容。

      • jax.prng 子模組及其所有內容。

      • 來自 jax.random 的:PRNGKeyArrayKeyArraydefault_prng_implthreefry_2x32threefry2x32_keythreefry2x32_prbg_keyunsafe_rbg_key

      • 來自 jax.tree_util 的:register_keypathsAttributeKeyPathEntryGetItemKeyPathEntry

      • 來自 jax.interpreters.xla 的:backend_specific_translationstranslationsregister_translationxla_destructureTranslationRuleTranslationContextaxis_groupsShapedArrayConcreteArrayAxisEnvbackend_compileXLAOp

      • 來自 jax.numpy 的:NINFNZEROPZEROrow_stackissubsctypetrapzin1d

      • 來自 jax.scipy.linalg 的:triltriu

    • 先前已棄用的方法 PRNGKeyArray.unsafe_raw_array 已移除。改為使用 jax.random.key_data()

    • bool(empty_array) 現在會引發錯誤,而不是傳回 False。這先前會引發棄用警告,並且遵循 NumPy 中類似的變更。

    • 已棄用對 mhlo MLIR 方言的支援。JAX 不再使用 mhlo 方言,而是改用 stablehlo。未來將移除引用 “mhlo” 的 API。請改用 “stablehlo” 方言。

    • jax.random:直接將批次金鑰傳遞給隨機數產生函式,例如 bits()gamma() 等等,已被棄用,並將發出 FutureWarning。針對明確的批次處理,請使用 jax.vmap

    • jax.lax.tie_in() 已棄用:自 JAX v0.2.0 以來,它已不再執行任何操作。

jaxlib 0.4.24 (2024 年 2 月 6 日)#

  • 變更

    • JAX 現在支援 CUDA 12.3 和 CUDA 11.8。已移除對 CUDA 12.2 的支援。

    • cost_analysis 現在可與跨編譯的 Compiled 物件搭配使用(即當使用具有拓撲物件的 .lower().compile() 時,例如從非 TPU 電腦編譯以用於 Cloud TPU)。

    • 新增 CUDA 陣列介面 匯入支援 (需要 jax 0.4.25)。

jax 0.4.23 (2023 年 12 月 13 日)#

jaxlib 0.4.23 (2023 年 12 月 13 日)#

  • 修正了在編譯期間導致 GPU 編譯器產生詳細記錄的錯誤。

jax 0.4.22 (2023 年 12 月 13 日)#

  • 棄用

    • JAX 陣列的 device_bufferdevice_buffers 屬性已棄用。明確的緩衝區已被更彈性的陣列分片介面取代,但之前的輸出可以透過以下方式恢復

      • arr.device_buffer 變為 arr.addressable_data(0)

      • arr.device_buffers 變為 [x.data for x in arr.addressable_shards]

jaxlib 0.4.22 (2023 年 12 月 13 日)#

jax 0.4.21 (2023 年 12 月 4 日)#

  • 新功能

  • 變更

    • 最低 jaxlib 版本現在為 0.4.19。

    • 發布的 wheel 現在使用 clang 而非 gcc 建置。

    • 強制在呼叫 jax.distributed.initialize() 之前,裝置後端尚未初始化。

    • 在 Cloud TPU 環境中自動化 jax.distributed.initialize() 的引數。

  • 棄用

    • 先前已棄用的 sym_pos 引數已從 jax.scipy.linalg.solve() 中移除。改為使用 assume_a='pos'

    • None 傳遞給 jax.array()jax.asarray(),無論是直接傳遞還是清單或元組中傳遞,都已被棄用,現在會引發 FutureWarning。目前它會轉換為 NaN,未來將引發 TypeError

    • 以關鍵字引數方式將 conditionxy 參數傳遞給 jax.numpy.where 已被棄用,以符合 numpy.where

    • 將無法轉換為 JAX 陣列的引數傳遞給 jax.numpy.array_equal()jax.numpy.array_equiv() 已被棄用,現在會引發 DeprecationWaning。目前,函式傳回 False,未來將引發例外狀況。

    • JAX 陣列的 device() 方法已棄用。根據上下文,它可能會被以下其中之一取代

      • jax.Array.devices() 傳回陣列使用的所有裝置的集合。

      • jax.Array.sharding 提供陣列使用的分片組態。

jaxlib 0.4.21 (2023 年 12 月 4 日)#

  • 變更

    • 為了準備新增分散式 CPU 支援,JAX 現在將 CPU 裝置視為與 GPU 和 TPU 裝置相同,也就是說

      • jax.devices() 包括分散式作業中的所有裝置,即使是那些非本機程序的裝置。jax.local_devices() 仍然只包括本機程序的裝置,因此如果對 jax.devices() 的變更中斷了您的程式碼,您很可能想要改用 jax.local_devices()

      • CPU 裝置現在在分散式作業中收到全域唯一 ID 號碼;先前 CPU 裝置會收到程序本機 ID 號碼。

      • 每個 CPU 裝置的 process_index 現在將與同一程序中的任何 GPU 或 TPU 裝置相符;先前 CPU 裝置的 process_index 始終為 0。

    • 在 NVIDIA GPU 上,JAX 現在針對最大 1024x1024 的矩陣優先使用 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。

  • 錯誤修復

    • 修復了當具有非有限值的陣列傳遞給非對稱特徵值分解 (#18226) 時的錯誤/掛起。具有非有限值的陣列現在會產生充滿 NaN 的陣列作為輸出。

jax 0.4.20 (2023 年 11 月 2 日)#

jaxlib 0.4.20 (2023 年 11 月 2 日)#

  • 錯誤修復

    • 修復了 E4M3 和 E5M2 float8 類型之間的一些類型混淆。

jax 0.4.19 (2023 年 10 月 19 日)#

  • 新功能

    • 新增了 jax.typing.DTypeLike,可用於註解可轉換為 JAX dtype 的物件。

    • 新增了 jax.numpy.fill_diagonal

  • 變更

    • JAX 現在需要 SciPy 1.9 或更新版本。

  • 錯誤修復

    • 在多控制器分散式 JAX 程式中,只有程序 0 會寫入持久編譯快取項目。這修正了當快取放置在網路檔案系統 (例如 GCS) 上時的寫入競爭。

    • cusolver 和 cufft 的版本檢查不再考慮修補程式版本,以判斷已安裝的這些程式庫版本是否至少與建置 JAX 時使用的版本一樣新。

jaxlib 0.4.19 (2023 年 10 月 19 日)#

  • 變更

    • jaxlib 現在將始終優先選擇 pip 安裝的 NVIDIA CUDA 程式庫 (nvidia-… 套件) 而不是任何其他 CUDA 安裝(如果已安裝),包括在 LD_LIBRARY_PATH 中命名的安裝。如果這導致問題且意圖是使用系統安裝的 CUDA,則修復方法是移除 pip 安裝的 CUDA 程式庫套件。

jax 0.4.18 (2023 年 10 月 6 日)#

jaxlib 0.4.18 (2023 年 10 月 6 日)#

  • 變更

    • CUDA jaxlib 現在依賴使用者安裝相容的 NCCL 版本。如果使用建議的 cuda12_pip 安裝,則應自動安裝 NCCL。目前,需要 NCCL 2.16 或更新版本。

    • 我們現在提供 Linux aarch64 wheel,包含和不包含 NVIDIA GPU 支援。

    • jax.Array.item() 現在支援選用的索引引數。

  • 棄用

    • 已棄用 jax.lax 中的許多內部公用程式和意外匯出,並將在未來版本中移除。

      • jax.lax.dtypes:改為使用 jax.dtypes

      • jax.lax.itertools:改為使用 itertools

      • naryopnaryop_dtype_rulestandard_abstract_evalstandard_naryopstandard_primitivestandard_unopunopunop_dtype_rule 是內部公用程式,現在已棄用,沒有替代方案。

  • 錯誤修復

    • 修復了 Cloud TPU 迴歸,其中編譯會因 smem 而導致 OOM。

jax 0.4.17 (2023 年 10 月 3 日)#

  • 新功能

  • 棄用

    • 移除了已棄用的模組 jax.abstract_arrays 及其所有內容。

    • jax.random 中的具名金鑰建構函式已棄用。改為將 impl 引數傳遞給 jax.random.PRNGKey()jax.random.key()

      • random.threefry2x32_key(seed) 變為 random.PRNGKey(seed, impl='threefry2x32')

      • random.rbg_key(seed) 變為 random.PRNGKey(seed, impl='rbg')

      • random.unsafe_rbg_key(seed) 變為 random.PRNGKey(seed, impl='unsafe_rbg')

  • 變更

    • CUDA:JAX 現在驗證其找到的 CUDA 程式庫是否至少與建置 JAX 時使用的 CUDA 程式庫一樣新。如果找到較舊的程式庫,JAX 會引發例外狀況,因為這比神秘的失敗和崩潰更可取。

    • 移除了「找不到 GPU/TPU」警告。而是改為在 Linux 上,如果找到 NVIDIA GPU 或 Google TPU 但未使用,且未指定 --jax_platforms 時發出警告。

    • jax.scipy.stats.mode() 現在,如果跨大小為 0 的軸取得模式,則會傳回 0 計數,這與 SciPy 1.11 中 scipy.stats.mode 的行為相符。

    • 大多數 jax.numpy 函式和屬性現在都具有完整定義的類型存根。先前,這些函式和屬性中的許多都被靜態類型檢查器 (例如 mypypytype) 視為 Any

jaxlib 0.4.17 (Oct 3, 2023)#

  • 變更

    • 此版本新增了 Python 3.12 wheel。

    • CUDA 12 wheel 現在需要 CUDA 12.2 或更新版本,以及 cuDNN 8.9.4 或更新版本。

  • 錯誤修復

    • 修正了初始化 JAX CPU 後端時,ABSL 產生的過多日誌訊息。

jax 0.4.16 (Sept 18, 2023)#

  • 變更

    • 新增了 jax.numpy.ufunc,以及 jax.numpy.frompyfunc(),它可以將任何純量值函數轉換為類似 numpy.ufunc() 的物件,並具有 outer()reduce()accumulate()at()reduceat() 等方法 (#17054)。

    • 新增了 jax.scipy.integrate.trapezoid()

    • 當不在 IPython 下執行時:當引發例外時,JAX 現在會從回溯中過濾掉其所有內部框架。(不包含先前出現的「未過濾堆疊追蹤」。)這應該會產生更友善的回溯外觀。請參閱此處的範例。此行為可以透過設定 JAX_TRACEBACK_FILTERING=remove_frames (用於兩個獨立的未過濾/已過濾回溯,這是舊行為)或 JAX_TRACEBACK_FILTERING=off (用於一個未過濾回溯)來變更。

    • jax2tf 預設序列化版本現在為 7,它引入了新的形狀安全斷言

    • 傳遞至 jax.sharding.Mesh 的裝置應該是可雜湊的。這特別適用於模擬裝置或使用者建立的裝置。jax.devices() 已經是可雜湊的。

  • 重大變更

    • jax2tf 現在預設使用原生序列化。請參閱 jax2tf 文件以瞭解詳細資訊和覆寫預設機制的資訊。

    • 選項 --jax_coordination_service 已移除。現在始終為 True

    • jax.jaxpr_util 已從公開的 JAX 命名空間中移除。

    • JAX_USE_PJRT_C_API_ON_TPU 不再起作用(即,始終預設為 true)。

    • 2021 年 12 月引入的回溯相容性旗標 --jax_host_callback_ad_transforms 已移除。

  • 棄用

    • 依照 NumPy NEP-52,已棄用數個 jax.numpy API。

      • jax.numpy.NINF 已棄用。請改用 -jax.numpy.inf

      • jax.numpy.PZERO 已棄用。請改用 0.0

      • jax.numpy.NZERO 已棄用。請改用 -0.0

      • jax.numpy.issubsctype(x, t) 已棄用。請改用 jax.numpy.issubdtype(x.dtype, t)

      • jax.numpy.row_stack 已棄用。請改用 jax.numpy.vstack

      • jax.numpy.in1d 已棄用。請改用 jax.numpy.isin

      • jax.numpy.trapz 已棄用。請改用 jax.scipy.integrate.trapezoid

    • jax.scipy.linalg.triljax.scipy.linalg.triu 已依照 SciPy 棄用。請改用 jax.numpy.triljax.numpy.triu

    • jax.lax.prod 在 JAX v0.4.11 中棄用後已移除。請改用內建的 math.prod

    • 與定義自訂 JAX 原始運算的 HLO 降低規則相關的 jax.interpreters.xla 中的一些匯出已棄用。自訂原始運算應改為使用 jax.interpreters.mlir 中的 StableHLO 降低公用程式來定義。

    • 以下先前已棄用的函式在三個月的棄用期後已移除

      • jax.abstract_arrays.ShapedArray:請使用 jax.core.ShapedArray

      • jax.abstract_arrays.raise_to_shaped:請使用 jax.core.raise_to_shaped

      • jax.numpy.alltrue:請使用 jax.numpy.all

      • jax.numpy.sometrue:請使用 jax.numpy.any

      • jax.numpy.product:請使用 jax.numpy.prod

      • jax.numpy.cumproduct:請使用 jax.numpy.cumprod

  • 已棄用/移除

    • 內部子模組 jax.prng 現在已棄用。其內容可在 jax.extend.random 中找到。

    • 內部子模組路徑 jax.linear_util 已棄用。請改用 jax.extend.linear_utiljax.extend:擴充模組的一部分)

    • jax.random.PRNGKeyArrayjax.random.KeyArray 已棄用。請使用 jax.Array 進行類型註釋,並使用 jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key) 進行類型化 prng 金鑰的執行階段偵測。

    • 方法 PRNGKeyArray.unsafe_raw_array 已棄用。請改用 jax.random.key_data()

    • jax.experimental.pjit.with_sharding_constraint 已棄用。請改用 jax.lax.with_sharding_constraint

    • 內部公用程式 jax.core.is_opaque_dtypejax.core.has_opaque_dtype 已移除。不透明 dtype 已重新命名為擴充 dtype;請改用 jnp.issubdtype(dtype, jax.dtypes.extended) (自 jax v0.4.14 起可用)。

    • 公用程式 jax.interpreters.xla.register_collective_primitive 已移除。此公用程式在最近的 JAX 版本中沒有任何作用,可以安全地移除對它的呼叫。

    • 內部子模組路徑 jax.linear_util 已棄用。請改用 jax.extend.linear_utiljax.extend:擴充模組的一部分)

jaxlib 0.4.16 (Sept 18, 2023)#

  • 變更

    • 透過實驗性 jax sparse API 進行的稀疏 CSR 矩陣乘法,在 NVIDIA GPU 上不再使用確定性演算法。進行此變更是為了提高與 CUDA 12.2.1 的相容性。

  • 錯誤修復

    • 修正了 Windows 上由於與錯序區段和 IMAGE_REL_AMD64_ADDR32NB 重定位相關的嚴重 LLVM 錯誤而導致的當機問題 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。

jax 0.4.14 (July 27, 2023)#

  • 變更

    • jax.jit 接受 donate_argnames 作為引數。它的語意與 static_argnames 類似。如果未提供 donate_argnums 和 donate_argnames,則不會捐贈任何引數。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 會使用 inspect.signature(fun) 來尋找與 donate_argnames 對應的任何位置引數(反之亦然)。如果同時提供了 donate_argnums 和 donate_argnames,則不會使用 inspect.signature,並且只會捐贈 donate_argnums 或 donate_argnames 中列出的實際參數。

    • jax.random.gamma() 已重新設計為更有效率的演算法,具有更穩健的端點行為 (#16779)。這表示對於給定的 key,JAX v0.4.13 和 v0.4.14 之間 gamma 和相關取樣器(包括 jax.random.ball()jax.random.beta()jax.random.chisquare()jax.random.dirichlet()jax.random.generalized_normal()jax.random.loggamma()jax.random.t())傳回的值序列將會變更。

  • 刪除

    • in_axis_resourcesout_axis_resources 已從 pjit 中刪除,因為它們已棄用超過 3 個月。請使用 in_shardingsout_shardings 作為替代。這是安全且微不足道的名稱替換。它不會變更任何目前的 pjit 語意,也不會破壞任何程式碼。您仍然可以將 PartitionSpecs 傳遞至 in_shardings 和 out_shardings。

  • 棄用

    • 已依照 https://jax.dev.org.tw/en/latest/deprecation.html 停止支援 Python 3.8

    • JAX 現在依照 https://jax.dev.org.tw/en/latest/deprecation.html 需要 NumPy 1.22 或更新版本

    • 不再支援透過位置將選用引數傳遞至 jax.numpy.ndarray.at(),此功能已在 JAX 版本 0.4.7 中棄用。例如,請使用 x.at[i].get(indices_are_sorted=True) 而非 x.at[i].get(True)

    • 以下 jax.Array 方法已在 JAX v0.4.5 中棄用後移除

    • 以下 API 在先前棄用後已移除

      • jax.ad:請使用 jax.interpreters.ad

      • jax.curry:請使用 curry = lambda f: partial(partial, f)

      • jax.partial_eval:請使用 jax.interpreters.partial_eval

      • jax.pxla:請使用 jax.interpreters.pxla

      • jax.xla:請使用 jax.interpreters.xla

      • jax.ShapedArray:請使用 jax.core.ShapedArray

      • jax.interpreters.pxla.device_put:請使用 jax.device_put()

      • jax.interpreters.pxla.make_sharded_device_array:請使用 jax.make_array_from_single_device_arrays()

      • jax.interpreters.pxla.ShardedDeviceArray:請使用 jax.Array

      • jax.numpy.DeviceArray:請使用 jax.Array

      • jax.stages.Compiled.compiler_ir:請使用 jax.stages.Compiled.as_text()

  • 重大變更

    • JAX 現在需要 ml_dtypes 版本 0.2.0 或更新版本。

    • 為了修正邊角案例,如果第二個和第三個引數是可呼叫的,即使其他運算元也是可呼叫的,則對具有五個引數的 jax.lax.cond() 的呼叫將始終解析為「通用運算元」cond 行為(如文件中所述)。請參閱 #16413

    • 已移除已棄用的組態選項 jax_arrayjax_jit_pjit_api_merge,它們沒有任何作用。這些選項在許多版本中預設為 true。

  • 新功能

    • JAX 現在支援組態旗標 –jax_serialization_version 和 JAX_SERIALIZATION_VERSION 環境變數,以控制序列化版本 (#16746)。

    • 如果序列化版本至少為 7,則 jax2tf 在存在形狀多型性的情況下,現在會產生檢查特定形狀約束的程式碼。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。

jaxlib 0.4.14 (July 27, 2023)#

  • 棄用

    • 已依照 https://jax.dev.org.tw/en/latest/deprecation.html 停止支援 Python 3.8

jax 0.4.13 (June 22, 2023)#

  • 變更

    • jax.jit 現在允許將 None 傳遞至 in_shardingsout_shardings。語意如下

      • 對於 in_shardings,JAX 會將其標記為已複製,但此行為在未來可能會變更。

      • 對於 out_shardings,我們將依賴 XLA GSPMD 分割器來判斷輸出分片。

    • jax.experimental.pjit.pjit 也允許將 None 傳遞至 in_shardingsout_shardings。語意如下

      • 如果提供網格上下文管理器,則 JAX 可以自由選擇它想要的任何分片。

        • 對於 in_shardings,JAX 會將其標記為已複製,但此行為在未來可能會變更。

        • 對於 out_shardings,我們將依賴 XLA GSPMD 分割器來判斷輸出分片。

      • 如果提供了網格上下文管理器,則 None 將表示值將在網格的所有裝置上複製。

    • Executable.cost_analysis() 在 Cloud TPU 上運作

    • 如果使用未列入允許清單的 jaxlib 外掛程式,則新增警告。

    • 新增了 jax.tree_util.tree_leaves_with_path

    • None 不是 jax.experimental.multihost_utils.host_local_array_to_global_arrayjax.experimental.multihost_utils.global_array_to_host_local_array 的有效輸入。如果您想要複製輸入,請使用 jax.sharding.PartitionSpec()

  • 錯誤修復

    • 修正了 CUDA 12 版本中不正確的 wheel 名稱 (#16362);正確的 wheel 名稱是 cudnn89 而不是 cudnn88

  • 棄用

    • jax.experimental.jax2tf.convert()native_serialization_strict_checks 參數已棄用,改為新的 native_serializaation_disabled_checks (#16347)。

jaxlib 0.4.13 (June 22, 2023)#

  • 變更

    • 將 Windows 僅 CPU wheel 新增至 jaxlib Pypi 版本。

  • 錯誤修復

    • __cuda_array_interface__ 在先前的 jaxlib 版本中已損壞,現在已修正 (#16440)。

    • 現在預設在 NVIDIA GPU 上啟用並行 CUDA 核心追蹤。

jax 0.4.12 (June 8, 2023)#

  • 變更

  • 棄用

    • jax.abstract_arrays 及其內容現在已棄用。請參閱 :mod:jax.core 中的相關功能。

    • jax.numpy.alltrue:請使用 jax.numpy.all。這遵循 NumPy 版本 1.25.0 中 numpy.alltrue 的棄用。

    • jax.numpy.sometrue:請使用 jax.numpy.any。這遵循 NumPy 版本 1.25.0 中 numpy.sometrue 的棄用。

    • jax.numpy.product:請使用 jax.numpy.prod。這遵循 NumPy 版本 1.25.0 中 numpy.product 的棄用。

    • jax.numpy.cumproduct:請使用 jax.numpy.cumprod。這遵循 NumPy 版本 1.25.0 中 numpy.cumproduct 的棄用。

    • jax.sharding.OpShardingSharding 已移除,因為它已棄用 3 個月。

jaxlib 0.4.12 (June 8, 2023)#

  • 變更

    • 包含適用於 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。舊版 jaxlib 應可在 Hopper 上運作,但在第一次執行 JAX 運算時,JIT 編譯延遲時間會很長。

  • 錯誤修復

    • 修正了 Python 3.11 下 JAX 產生的 Python 回溯中不正確的原始程式碼行資訊。

    • 修正了在 JAX 產生的 Python 回溯中列印框架的區域變數時發生的當機問題 (#16027)。

jax 0.4.11 (May 31, 2023)#

  • 棄用

    • 以下 API 已依照 API 相容性 政策,在 3 個月的棄用期後移除

      • jax.experimental.PartitionSpec:請使用 jax.sharding.PartitionSpec

      • jax.experimental.maps.Mesh:請使用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding:請使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec:請使用 jax.sharding.PartitionSpec

      • jax.experimental.pjit.FROM_GDA。請改為傳遞分片的 jax.Array 物件作為輸入,並移除 pjit 的選用 in_shardings 引數。

      • jax.interpreters.pxla.PartitionSpec:請使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh:請使用 jax.sharding.Mesh

      • jax.interpreters.xla.Buffer:請使用 jax.Array

      • jax.interpreters.xla.Device:請使用 jax.Device

      • jax.interpreters.xla.DeviceArray:請使用 jax.Array

      • jax.interpreters.xla.device_put:請使用 jax.device_put

      • jax.interpreters.xla.xla_call_p:請使用 jax.experimental.pjit.pjit_p

      • 已移除 with_sharding_constraintaxis_resources 引數。請改用 shardings

jaxlib 0.4.11 (May 31, 2023)#

  • 變更

    • Device 新增了 memory_stats() 方法。如果支援,這會傳回字串統計名稱的 dict 和 int 值,例如 "bytes_in_use",如果平台不支援記憶體統計資訊,則傳回 None。傳回的確切統計資訊可能因平台而異。目前僅在 Cloud TPU 上實作。

    • 重新新增了對 CPU 裝置上 Python 緩衝區協定 (memoryview) 的支援。

jax 0.4.10 (May 11, 2023)#

jaxlib 0.4.10 (May 11, 2023)#

  • 變更

    • 修正了 'apple-m1' is not a recognized processor for this target (ignoring processor) 問題,該問題阻止了先前版本在 Mac M1 上執行。

jax 0.4.9 (May 9, 2023)#

  • 變更

    • 已移除旗標 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap。它們現在始終處於開啟狀態。

    • 已改善 TPU 上奇異值分解 (SVD) 的準確性 (需要 jaxlib 0.4.9)。

  • 棄用

    • jax.experimental.gda_serialization 已棄用,並已重新命名為 jax.experimental.array_serialization。請變更您的匯入以使用 jax.experimental.array_serialization

    • 已棄用 pjit 的 in_axis_resourcesout_axis_resources 引數。請分別使用 in_shardingsout_shardings

    • 函式 jax.numpy.msort 已移除。它自 JAX v0.4.1 起已棄用。請改用 jnp.sort(a, axis=0)

    • 已從 jax.xla_computation 中移除 in_partsout_parts 引數,因為它們僅用於 sharded_jit,而 sharded_jit 早已消失。

    • 已從 jax.xla_computation 中移除 instantiate_const_outputs 引數,因為它已閒置很長時間。

jaxlib 0.4.9 (May 9, 2023)#

jax 0.4.8 (March 29, 2023)#

  • 重大變更

    • Cloud TPU 執行階段的主要元件已升級。這在 Cloud TPU 上啟用了以下新功能

      jax.experimental.host_callback() 在具有新執行階段元件的 Cloud TPU 上不再受支援。如果新的 jax.debug API 不足以滿足您的使用案例,請在 JAX 問題追蹤器上提交問題。

      舊的執行階段元件將在至少未來三個月內透過設定環境變數 JAX_USE_PJRT_C_API_ON_TPU=false 提供。如果您發現需要因任何原因停用新執行階段,請在 JAX 問題追蹤器上告知我們。

  • 變更

    • 最低 jaxlib 版本已從 0.4.6 提升至 0.4.7。

  • 棄用

    • 已停止支援 CUDA 11.4。JAX GPU wheel 僅支援 CUDA 11.8 和 CUDA 12。如果 jaxlib 是從原始碼建置的,則舊版 CUDA 可能會運作。

    • pmap 的 global_arg_shapes 引數僅適用於 sharded_jit,並且已從 pmap 中移除。請移轉至 pjit 並從 pmap 中移除 global_arg_shapes。

jax 0.4.7 (March 27, 2023)#

  • 變更

    • 依照 https://jax.dev.org.tw/en/latest/jax_array_migration.html#jax-array-migration jax.config.jax_array 無法再停用。

    • jax.config.jax_jit_pjit_api_merge 無法再停用。

    • jax.experimental.jax2tf.convert() 現在支援 native_serialization 參數,以使用 JAX 的原生降低至 StableHLO,來取得整個 JAX 函式的 StableHLO 模組,而不是將每個 JAX 原始運算降低為 TensorFlow 運算。這簡化了內部結構,並提高了您序列化的內容與 JAX 原生語意相符的信心。請參閱 文件。作為此變更的一部分,組態旗標 --jax2tf_default_experimental_native_lowering 已重新命名為 --jax2tf_native_serialization

    • JAX 現在依賴 ml_dtypes,其中包含 NumPy 類型(例如 bfloat16)的定義。這些定義先前是 JAX 的內部定義,但已拆分為個別套件,以方便與其他專案共用。

    • JAX 現在需要 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。

  • 棄用

    • 類型 jax.numpy.DeviceArray 已棄用。請改用 jax.Array,它是它的別名。

    • 類型 jax.interpreters.pxla.ShardedDeviceArray 已棄用。請改用 jax.Array

    • 不再建議依位置將其他引數傳遞至 jax.numpy.ndarray.at()。例如,請使用 x.at[i].get(indices_are_sorted=True) 而非 x.at[i].get(True)

    • jax.interpreters.xla.device_put 已棄用。請使用 jax.device_put

    • jax.interpreters.pxla.device_put 已棄用。請使用 jax.device_put

    • jax.experimental.pjit.FROM_GDA 已棄用。請傳入分片的 jax.Arrays 作為輸入,並移除 pjit 的 in_shardings 引數,因為它是選用的。

jaxlib 0.4.7 (March 27, 2023)#

變更

  • jaxlib 現在依賴 ml_dtypes,其中包含 NumPy 類型(例如 bfloat16)的定義。這些定義先前是 JAX 的內部定義,但已拆分為個別套件,以方便與其他專案共用。

jax 0.4.6 (Mar 9, 2023)#

  • 變更

    • jax.tree_util 現在包含一組 API,允許使用者為其自訂 pytree 節點定義鍵。這包括:

      • tree_flatten_with_path,其會扁平化樹狀結構,且不僅傳回每個葉節點,還會傳回其鍵路徑。

      • tree_map_with_path,其可以映射一個將鍵路徑作為參數的函式。

      • register_pytree_with_keys,用於註冊自訂 pytree 節點中的鍵路徑和葉節點應呈現的樣貌。

      • keystr,用於美觀地列印鍵路徑。

    • jax2tf.call_tf() 有一個新的參數 output_shape_dtype (預設為 None),可用於宣告結果的輸出形狀和資料型別。這使得 jax2tf.call_tf() 能夠在形狀多態的情況下運作。( #14734 )。

  • 棄用

    • jax.tree_util 中的舊版鍵路徑 API 已被棄用,並將在 2023 年 3 月 10 日起 3 個月後移除。

jaxlib 0.4.6 (2023 年 3 月 9 日)#

jax 0.4.5 (2023 年 3 月 2 日)#

  • 棄用

    • jax.sharding.OpShardingSharding 已重新命名為 jax.sharding.GSPMDShardingjax.sharding.OpShardingSharding 將在 2023 年 2 月 17 日起 3 個月後移除。

    • 以下 jax.Array 方法已被棄用,並將在 2023 年 2 月 23 日起 3 個月後移除:

jax 0.4.4 (2023 年 2 月 16 日)#

  • 變更

    • jitpjit 的實作已合併。合併 jit 和 pjit 變更了 JAX 的內部結構,但不影響 JAX 的公開 API。先前,jit 是一種最終樣式基本運算。最終樣式表示 jaxpr 的建立會盡可能延遲,且轉換會堆疊在彼此之上。透過 jit-pjit 實作合併,jit 變成一種初始樣式基本運算,這表示我們會盡可能提早追蹤至 jaxpr。如需更多資訊,請參閱 autodidax 中的此章節。移至初始樣式應可簡化 JAX 的內部結構,並使動態形狀等功能的開發更加容易。您只能透過環境變數停用它,例如 os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'。必須透過環境變數停用合併,因為它會在匯入時影響 JAX,因此需要在匯入 jax 之前停用。

    • with_sharding_constraintaxis_resources 引數已被棄用。請改用 shardings。如果您將 axis_resources 作為 arg 使用,則無需變更。如果您將其作為 kwarg 使用,請改用 shardingsaxis_resources 將在 2023 年 2 月 13 日起 3 個月後移除。

    • 新增了 jax.typing 模組,其中包含 JAX 函式型別註解的工具。

    • 以下名稱已被棄用:

      • jax.xla.Devicejax.interpreters.xla.Device:請使用 jax.Device

      • jax.experimental.maps.Mesh。請改用 jax.sharding.Mesh

      • jax.experimental.pjit.NamedSharding:請使用 jax.sharding.NamedSharding

      • jax.experimental.pjit.PartitionSpec:請使用 jax.sharding.PartitionSpec

      • jax.interpreters.pxla.Mesh:請使用 jax.sharding.Mesh

      • jax.interpreters.pxla.PartitionSpec:請使用 jax.sharding.PartitionSpec

  • 重大變更

    • reduction 函式 (如 :func:jax.numpy.sum) 的 initial 引數現在必須是純量,這與對應的 NumPy API 一致。先前針對非純量 initial 值廣播輸出的行為是一種非預期的實作細節 ( #14446 )。

jaxlib 0.4.4 (2023 年 2 月 16 日)#

  • 重大變更

    • 預設 jaxlib 組建已移除對 NVIDIA Kepler 系列 GPU 的支援。如果需要 Kepler 支援,仍然可以從來源組建具有 Kepler 支援的 jaxlib (透過 build.py--cuda_compute_capabilities=sm_35 選項),但請注意 CUDA 12 已完全停止對 Kepler GPU 的支援。

jax 0.4.3 (2023 年 2 月 8 日)#

jaxlib 0.4.3 (2023 年 2 月 8 日)#

  • jax.Array 現在具有非封鎖 is_ready() 方法,如果陣列已就緒,則傳回 True (另請參閱 jax.block_until_ready() )。

jax 0.4.2 (2023 年 1 月 24 日)#

  • 重大變更

    • 已刪除 jax.experimental.callback

    • 在 jax2tf 形狀多態存在的情況下,具有維度的運算已通用化,可在更多情境中運作,方法是將符號維度轉換為 JAX 陣列。當結果用作形狀值時,涉及符號維度和 np.ndarray 的運算現在可能會引發錯誤 ( #14106 )。

    • jaxpr 物件現在會在屬性設定時引發錯誤,以避免有問題的變動 ( #14102 )

  • 變更

    • jax2tf.call_tf() 有一個新的參數 has_side_effects (預設為 True),可用於宣告執行個體是否可由 JAX 優化 (例如無效程式碼消除) 移除或複製 ( #13980 )。

    • 為 jax2tf 形狀多態新增了更多對 floordiv 和 mod 的支援。先前,某些除法運算在符號維度存在的情況下會導致錯誤 ( #14108 )。

jaxlib 0.4.2 (2023 年 1 月 24 日)#

  • 變更

    • 設定 JAX_USE_PJRT_C_API_ON_TPU=1 以啟用新的 Cloud TPU 執行階段,其具有自動裝置記憶體重組功能。

jax 0.4.1 (2022 年 12 月 13 日)#

  • 變更

    • 已停止支援 Python 3.7,這符合 JAX 的 Python 和 NumPy 版本支援政策

    • 我們推出了 jax.Array,這是一種統一的陣列型別,包含 JAX 中的 DeviceArrayShardedDeviceArrayGlobalDeviceArray 型別。jax.Array 型別有助於使平行處理成為 JAX 的核心功能、簡化和統一 JAX 內部結構,並讓我們能夠統一 jitpjitjax.Array 已在 JAX 0.4 中預設啟用,並對 pjit API 進行了一些重大變更。jax.Array 移轉指南 可以協助您將程式碼庫移轉至 jax.Array。您也可以查看 分散式陣列和自動平行處理 教學課程,以瞭解新的概念。

    • PartitionSpecMesh 現在已脫離實驗階段。新的 API 端點為 jax.sharding.PartitionSpecjax.sharding.Meshjax.experimental.maps.Meshjax.experimental.PartitionSpec 已被棄用,並將在 3 個月後移除。

    • with_sharding_constraint 的新公開端點為 jax.lax.with_sharding_constraint

    • 如果將 ABSL 旗標與 jax.config 一起使用,則在最初從 ABSL 旗標填入 JAX 組態選項後,將不再讀取或寫入 ABSL 旗標值。此變更改善了讀取 jax.config 選項的效能,這些選項在 JAX 中被廣泛使用。

    • jax2tf.call_tf 函式現在針對 TF 降低使用與嵌入 JAX 計算所用平台相同的平台的第一個 TF 裝置。先前,它是針對 JAX 預設後端使用第 0 個裝置。

    • 許多 jax.numpy 函式現在已將其引數標記為僅限位置引數,與 NumPy 相符。

    • jnp.msort 現在已被棄用,原因是 numpy 1.24 中已棄用 np.msort。它將在未來的版本中移除,這符合 API 相容性 政策。它可以使用 jnp.sort(a, axis=0) 替代。

jaxlib 0.4.1 (2022 年 12 月 13 日)#

  • 變更

    • 已停止支援 Python 3.7,這符合 JAX 的 Python 和 NumPy 版本支援政策

    • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX 的行為已變更為配置 XX% 的 GPU 總記憶體,而不是先前使用目前可用的 GPU 記憶體來計算預先配置的行為。如需更多詳細資訊,請參閱 GPU 記憶體分配

    • 已移除已棄用的方法 .block_host_until_ready()。請改用 .block_until_ready()

jax 0.4.0 (2022 年 12 月 12 日)#

  • 此版本已撤回。

jaxlib 0.4.0 (2022 年 12 月 12 日)#

  • 此版本已撤回。

jax 0.3.25 (2022 年 11 月 15 日)#

jaxlib 0.3.25 (2022 年 11 月 15 日)#

  • 變更

    • 新增了對 CPU 和 GPU 上三對角歸約的支援。

    • 新增了對 CPU 上上黑森堡歸約的支援。

  • 錯誤修正

    • 修正了一個錯誤,該錯誤表示 JAX 擷取的追溯中的框架會錯誤地映射到 Python 3.10+ 下的原始碼行

jax 0.3.24 (2022 年 11 月 4 日)#

  • 變更

    • JAX 的匯入速度應該更快。我們現在延遲匯入 scipy,這佔 JAX 匯入時間的很大一部分。

    • 設定環境變數 JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N 可用於限制寫入持續性快取的快取項目數量。依預設,編譯時間超過 1 秒的計算將會快取。

    • 如果未指定順序,則 pmap 在 TPU 上使用的預設裝置順序現在與單一程序作業的 jax.devices() 相符。先前,這兩個順序不同,這可能會導致不必要的複製或記憶體不足錯誤。要求順序一致可簡化問題。

  • 重大變更

  • 棄用

    • jax.sharding.MeshPspecSharding 已重新命名為 jax.sharding.NamedShardingjax.sharding.MeshPspecSharding 名稱將在 3 個月後移除。

jaxlib 0.3.24 (2022 年 11 月 4 日)#

  • 變更

    • 緩衝區捐贈現在可在 CPU 上運作。這可能會破壞在 CPU 上標記緩衝區以進行捐贈,但依賴捐贈未實作的程式碼。

jax 0.3.23 (2022 年 10 月 12 日)#

  • 變更

    • 更新 Colab TPU 驅動程式版本以用於新的 jaxlib 版本。

jax 0.3.22 (2022 年 10 月 11 日)#

  • 變更

    • 在 TPU 初始化中新增 JAX_PLATFORMS=tpu,cpu 作為預設設定,因此如果無法初始化 TPU,JAX 將引發錯誤,而不是回復為 CPU。設定 JAX_PLATFORMS='' 以覆寫此行為並自動選擇可用的後端 (原始預設值),或設定 JAX_PLATFORMS=cpu 以始終使用 CPU,無論 TPU 是否可用。

  • 棄用

    • JAX v0.3.8 中棄用的數個測試公用程式現在已從 jax.test_util 中移除。

jaxlib 0.3.22 (2022 年 10 月 11 日)#

jax 0.3.21 (2022 年 9 月 30 日)#

  • GitHub 提交.

  • 變更

    • 持續性編譯快取現在會在發生錯誤時發出警告而不是引發例外狀況 ( #12582 ),因此如果快取發生問題,程式執行可以繼續。設定 JAX_RAISE_PERSISTENT_CACHE_ERRORS=true 以還原此行為。

jax 0.3.20 (2022 年 9 月 28 日)#

  • 錯誤修復

    • 新增了先前版本中遺失的遺失 .pyi 檔案 ( #12536 )。

    • 修正了 jax 0.3.19 與其釘選的 libtpu 版本之間的不相容性 ( #12550 )。需要 jaxlib 0.3.20。

    • 修正了 setup.py 註解中不正確的 pip URL ( #12528 )。

jaxlib 0.3.20 (2022 年 9 月 28 日)#

  • GitHub 提交.

  • 錯誤修復

    • 修正了透過分散式作業中的 jax_cuda_visible_devices 限制可見 CUDA 裝置的支援。GPU 上 JAX/SLURM 整合需要此功能 ( #12533 )。

jax 0.3.19 (2022 年 9 月 27 日)#

jax 0.3.18 (2022 年 9 月 26 日)#

  • GitHub 提交.

  • 變更

    • 預先 (AOT) 降低和編譯功能 (在 #7733 中追蹤) 是穩定且公開的。請參閱 概觀jax.stages 的 API 文件。

    • 推出了 jax.Array,旨在用於 JAX 中陣列型別的 isinstance 檢查和型別註解。請注意,這包含對 jax.numpy.ndarrayisinstance 運作方式的一些細微變更,因為 jax.numpy.ndarray 現在是 jax.Array 的簡單別名。

  • 重大變更

    • jax._src 不再匯入到公開 jax 命名空間中。這可能會破壞正在使用 JAX 內部結構的使用者。

    • jax.soft_pmap 已刪除。請改用 pjitxmapjax.soft_pmap 未記載。如果已記載,則會提供棄用期。

jax 0.3.17 (2022 年 8 月 31 日)#

  • GitHub 提交.

  • 錯誤修正

    • 修正了指數為零時 lax.pow 梯度的邊角案例問題 ( #12041 )

  • 重大變更

    • jax.checkpoint() (也稱為 jax.remat()) 不再支援 concrete 選項,這是繼先前版本的棄用之後;請參閱 JEP 11830

  • 變更

    • 新增了 jax.pure_callback(),可從已編譯的函式 (例如,以 jax.jitjax.pmap 裝飾的函式) 回呼至純 Python 函式。

  • 棄用

    • 已移除已棄用的 DeviceArray.tile() 方法。請使用 jax.numpy.tile() ( #11944 )。

    • DeviceArray.to_py() 已被棄用。請改用 np.asarray(x)

jax 0.3.16#

jax 0.3.15 (2022 年 7 月 22 日)#

jaxlib 0.3.15 (2022 年 7 月 22 日)#

jax 0.3.14 (2022 年 6 月 27 日)#

  • GitHub 提交.

  • 重大變更

    • jax.experimental.compilation_cache.initialize_cache() 不再支援 max_cache_size_  bytes,且不會再將其作為輸入。

    • 當平台初始化失敗時,JAX_PLATFORMS 現在會引發例外。

  • 變更

    • 修正了與 NumPy 1.23 的相容性問題。

    • jax.numpy.linalg.slogdet() 現在接受可選的 method 引數,允許在基於 LU 分解的實作和基於 QR 分解的實作之間進行選擇。

    • jax.numpy.linalg.qr() 現在支援 mode="raw"

    • picklecopy.copycopy.deepcopy 在用於 jax 陣列時(#10659)現在有更完整的支援。特別是:

      • 先前,當對 DeviceArray 使用 pickledeepcopy 時,會傳回 np.ndarray 物件;現在會傳回 DeviceArray 物件。對於 deepcopy,複製的陣列與原始陣列位於相同的裝置上。對於 pickle,還原序列化的陣列將位於預設裝置上。

      • 在函式轉換(即追蹤程式碼)中,deepcopycopy 先前為無操作。現在它們使用與 DeviceArray.copy() 相同的機制。

      • 在追蹤陣列上呼叫 pickle 現在會導致明確的 ConcretizationTypeError

    • 奇異值分解 (SVD) 和對稱/埃爾米特特徵值分解的實作在 TPU 上應顯著加快,尤其是對於 1000x1000 或更大的矩陣。兩者現在都使用頻譜分治演算法進行特徵值分解 (QDWH-eig)。

    • jax.numpy.ldexp() 不再靜默地將所有輸入提升為 float64,而是對於 int32 或更小尺寸的整數輸入,它會提升為 float32 (#10921)。

    • jax.profiler.start_trace()jax.profiler.start_trace() 新增 create_perfetto_link 選項。使用時,效能分析器將產生 Perfetto UI 的連結以檢視追蹤。

    • 變更了 jax.profiler.start_server(...)() 的語意,將 keepalive 全域儲存,而不是要求使用者保留對它的參考。

    • 新增 jax.random.generalized_normal()

    • 新增 jax.random.ball()

    • 新增 jax.default_device()

    • 新增 python -m jax.collect_profile 指令碼,以手動擷取程式追蹤,作為 TensorBoard UI 的替代方案。

    • 新增 jax.named_scope 環境管理器,將效能分析器中繼資料新增至 Python 程式(類似於 jax.named_call)。

    • 在分散更新操作(即 :attr:jax.numpy.ndarray.at)中,不安全的隱式 dtype 轉換已被棄用,現在會導致 FutureWarning。在未來的版本中,這將變成錯誤。不安全隱式轉換的一個範例是 jnp.zeros(4, dtype=int).at[0].set(1.5),其中 1.5 先前會被靜默截斷為 1

    • jax.experimental.compilation_cache.initialize_cache() 現在支援 gcs 儲存桶路徑作為輸入。

    • 新增 jax.scipy.stats.gennorm()

    • 當係數具有前導零時,strip_zeros=Falsejax.numpy.roots() 現在表現更好 (#11215)。

jaxlib 0.3.14 (2022 年 6 月 27 日)#

  • GitHub 提交.

    • x86-64 Mac wheels 現在需要 Mac OS 10.14 (Mojave) 或更新版本。Mac OS 10.14 於 2018 年發布,因此這應該不是一個非常繁瑣的要求。

    • 捆綁的 NCCL 版本已更新至 2.12.12,修復了一些死鎖。

    • Python flatbuffers 套件不再是 jaxlib 的依賴項。

jax 0.3.13 (2022 年 5 月 16 日)#

jax 0.3.12 (2022 年 5 月 15 日)#

jax 0.3.11 (2022 年 5 月 15 日)#

  • GitHub 提交.

  • 變更

    • jax.lax.eigh() 現在接受可選的 sort_eigenvalues 引數,允許使用者選擇在 TPU 上停用特徵值排序。

  • 棄用

    • jax.lax.linalg 中函式的非陣列引數現在標記為僅限關鍵字。作為向後相容性步驟,以位置方式傳遞僅限關鍵字的引數會產生警告,但在未來的 JAX 版本中,以位置方式傳遞僅限關鍵字的引數將會失敗。但是,大多數使用者應優先使用 jax.numpy.linalg

    • 作為 scipy API 的 JAX 擴充功能,jax.scipy.linalg.polar_unitary() 已被棄用。請改用 jax.scipy.linalg.polar()

jax 0.3.10 (2022 年 5 月 3 日)#

jaxlib 0.3.10 (2022 年 5 月 3 日)#

  • GitHub 提交.

  • 變更

    • TF commit 修復了 MHLO canonicalizer 中的一個問題,該問題導致常數摺疊對於某些程式耗時過長或崩潰。

jax 0.3.9 (2022 年 5 月 2 日)#

  • GitHub 提交.

  • 變更

    • 新增了對 GlobalDeviceArray 的完全非同步檢查點的支援。

jax 0.3.8 (2022 年 4 月 29 日)#

  • GitHub 提交.

  • 變更

    • TPU 上的 jax.numpy.linalg.svd() 使用 qdwh-svd 求解器。

    • TPU 上的 jax.numpy.linalg.cond() 現在接受複數輸入。

    • TPU 上的 jax.numpy.linalg.pinv() 現在接受複數輸入。

    • TPU 上的 jax.numpy.linalg.matrix_rank() 現在接受複數輸入。

    • 已新增 jax.scipy.cluster.vq.vq()

    • jax.experimental.maps.mesh 已被刪除。請使用 jax.experimental.maps.Mesh。請參閱 https://jax.dev.org.tw/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh 以取得更多資訊。

    • mode='r' 時,jax.scipy.linalg.qr() 現在傳回長度為 1 的元組,而不是原始陣列,以便與 scipy.linalg.qr 的行為相符 (#10452)

    • jax.numpy.take_along_axis() 現在接受可選的 mode 參數,用於指定超出邊界索引的行為。預設情況下,超出邊界索引將傳回無效值(例如,NaN)。在先前的 JAX 版本中,無效索引會被鉗制在範圍內。先前的行為可以透過傳遞 mode="clip" 來恢復。

    • jax.numpy.take() 現在預設為 mode="fill",這會為超出邊界索引傳回無效值(例如,NaN)。

    • 分散操作(例如 x.at[...].set(...))現在具有 "drop" 語意。這對分散操作本身沒有影響,但這表示當微分分散時,超出邊界索引的梯度將產生零餘切。先前,超出邊界索引在梯度中被鉗制在範圍內,這在數學上是不正確的。

    • 如果 jax.numpy.take_along_axis() 的索引不是整數類型,則現在會引發 TypeError,與 numpy.take_along_axis() 的行為相符。先前,非整數索引會被靜默轉換為整數。

    • 如果 jax.numpy.ravel_multi_index()dims 引數不是整數類型,則現在會引發 TypeError,與 numpy.ravel_multi_index() 的行為相符。先前,非整數 dims 會被靜默轉換為整數。

    • 如果 jax.numpy.split()axis 引數不是整數類型,則現在會引發 TypeError,與 numpy.split() 的行為相符。先前,非整數 axis 會被靜默轉換為整數。

    • 如果 jax.numpy.indices() 的維度不是整數類型,則現在會引發 TypeError,與 numpy.indices() 的行為相符。先前,非整數維度會被靜默轉換為整數。

    • 如果 jax.numpy.diag()k 引數不是整數類型,則現在會引發 TypeError,與 numpy.diag() 的行為相符。先前,非整數 k 會被靜默轉換為整數。

    • 新增 jax.random.orthogonal()

  • 棄用

    • jax.test_util 中提供的許多函式和物件現在已被棄用,並且在匯入時會引發警告。這包括 cases_from_listcheck_closecheck_eqdevice_under_testformat_shape_dtype_stringrand_uniformskip_on_deviceswith_configxla_bridge_default_tolerance (#10389)。這些以及先前已棄用的 JaxTestCaseJaxTestLoaderBufferDonationTestCase 將在未來的 JAX 版本中移除。這些公用程式中的大多數都可以透過呼叫標準 python 和 numpy 測試公用程式來取代,例如在 unittestabsl.testingnumpy.testing 等中找到。JAX 特定功能(例如裝置檢查)可以透過使用公用 API(例如 jax.devices())來取代。許多已棄用的公用程式仍將存在於 jax._src.test_util 中,但這些不是公用 API,因此可能會在未來的版本中更改或移除,恕不另行通知。

jax 0.3.7 (2022 年 4 月 15 日)#

jaxlib 0.3.7 (2022 年 4 月 15 日)#

  • 變更

    • Linux wheels 現在根據 manylinux2014 標準而不是 manylinux2010 標準建置。

jax 0.3.6 (2022 年 4 月 12 日)#

  • GitHub 提交.

  • 變更

    • 升級了 libtpu wheel 到修復初始化 TPU pod 時掛起問題的版本。修復了 #10218

  • 棄用

    • jax.experimental.loops 正在被棄用。請參閱 #10278 以取得替代 API。

jax 0.3.5 (2022 年 4 月 7 日)#

jaxlib 0.3.5 (2022 年 4 月 7 日)#

  • 錯誤修復

    • 修復了雙精度複數到實數 IRFFT 會在 GPU 上變更其輸入緩衝區的錯誤 (#9946)。

    • 修復了複數分散的不正確常數摺疊問題 (#10159)

jax 0.3.4 (2022 年 3 月 18 日)#

jax 0.3.3 (2022 年 3 月 17 日)#

jax 0.3.2 (2022 年 3 月 16 日)#

  • GitHub 提交.

  • 變更

    • 在 0.2.22 中已棄用的函式 jax.ops.index_updatejax.ops.index_add 已被移除。請改用 JAX 陣列上的 .at 屬性,例如 x.at[idx].set(y)

    • jax.experimental.ann.approx_*_k 移至 jax.lax 中。這些函式是 jax.lax.top_k 的最佳化替代方案。

    • jax.numpy.broadcast_arrays()jax.numpy.broadcast_to() 現在需要純量或類陣列輸入,如果傳遞列表,則會失敗(#7737 的一部分)。

    • 標準 jax[tpu] 安裝現在可用於 Cloud TPU v4 VM。

    • pjit 現在可在 CPU 上運作(除了先前的 TPU 和 GPU 支援)。

jaxlib 0.3.2 (2022 年 3 月 16 日)#

  • 變更

    • XlaComputation.as_hlo_text() 現在支援透過傳遞布林標誌 print_large_constants=True 來列印大型常數。

  • 棄用

    • JAX 陣列上的 .block_host_until_ready() 方法已被棄用。請改用 .block_until_ready()

jax 0.3.1 (2022 年 2 月 18 日)#

jax 0.3.0 (2022 年 2 月 10 日)#

jaxlib 0.3.0 (2022 年 2 月 10 日)#

  • 變更

    • 現在需要 Bazel 5.0.0 才能建置 jaxlib。

    • jaxlib 版本已升級至 0.3.0。請參閱 設計文件 以取得說明。

jax 0.2.28 (2022 年 2 月 1 日)#

  • GitHub 提交.

    • 如果未傳遞 dialect=jax.jit(f).lower(...).compiler_ir() 現在預設為 MHLO 方言。

    • jax.jit(f).lower(...).compiler_ir(dialect='mhlo') 現在傳回 MLIR ir.Module 物件,而不是其字串表示形式。

jaxlib 0.1.76 (2022 年 1 月 27 日)#

  • 新功能

    • 包含適用於 NVidia 計算能力 8.0 GPU (例如 A100) 的預編譯 SASS。移除了適用於計算能力 6.1 的預編譯 SASS,以免增加計算能力的數量:計算能力為 6.1 的 GPU 可以使用 6.0 SASS。

    • 使用 jaxlib 0.1.76,JAX 預設使用 MHLO MLIR 方言作為其主要目標編譯器 IR。

  • 重大變更

    • 已停止支援 NumPy 1.18,根據廢棄政策。請升級至支援的 NumPy 版本。

  • 錯誤修復

    • 修正了由不同路徑建構,但外觀上相同的 pytreedef 物件,無法比較為相等的錯誤 (#9066)。

    • JAX jit 快取需要兩個靜態引數具有相同的類型,才能命中快取 (#9311)。

jax 0.2.27 (2022 年 1 月 18 日)#

  • GitHub 提交.

  • 重大變更

    • 已停止支援 NumPy 1.18,根據廢棄政策。請升級至支援的 NumPy 版本。

    • 已簡化 host_callback 原始運算,以移除針對 hcb.id_tap 和 id_print 的特殊自動微分處理。從現在起,僅會點擊(tapped)原始值。可以透過設定 JAX_HOST_CALLBACK_AD_TRANSFORMS 環境變數,或 --jax_host_callback_ad_transforms 旗標(flag),來取得舊的行為(在有限的時間內)。此外,新增了關於如何使用 JAX 自訂 AD API 實作舊行為的文件 (#8678)。

    • 現在排序行為與 NumPy 針對 0.0NaN 的行為一致,無論其位元表示法為何。特別是,0.0-0.0 現在被視為相等,而先前 -0.0 被視為小於 0.0。此外,所有 NaN 表示法現在都被視為相等,並排序至陣列的末尾。先前,負 NaN 值會排序至陣列的前端,且具有不同內部位元表示法的 NaN 值不會被視為相等,而是根據這些位元模式進行排序 (#9178)。

    • jax.numpy.unique() 現在以與 NumPy 版本 1.21 及更新版本中的 np.unique 相同的方式處理 NaN 值:最多只有一個 NaN 值會出現在唯一化輸出中 (#9184)。

  • 錯誤修復

    • host_callback 現在支援 ad_checkpoint.checkpoint (#8907)。

  • 新功能

    • 新增 jax.block_until_ready ({jax-issue}`#8941)

    • 新增了一個新的偵錯旗標/環境變數 JAX_DUMP_IR_TO=/path。如果設定,JAX 會將其為每個計算產生的 MHLO/HLO IR 傾印到給定路徑下的檔案中。

    • jax.ensure_compile_time_eval 新增至公開 API (#7987)。

    • jax2tf 現在支援旗標 jax2tf_associative_scan_reductions,以變更關聯性歸約的降低(lowering),例如 jnp.cumsum,使其在 CPU 和 GPU 上的行為類似於 JAX(使用關聯性掃描)。請參閱 jax2tf README 以取得更多詳細資訊 (#9189)。

jaxlib 0.1.75 (2021 年 12 月 8 日)#

  • 新功能

    • 支援 python 3.10。

jax 0.2.26 (2021 年 12 月 8 日)#

  • GitHub 提交.

  • 錯誤修復

    • 超出邊界的索引至 jax.ops.segment_sum 現在將會以 FILL_OR_DROP 語意處理,如文件所述。這主要影響反向模式導數,其中對應於超出邊界索引的梯度現在將會回傳為 0。(#8634)。

    • jax2tf 將強制轉換後的程式碼使用 XLA 處理 jax.jit 下的程式碼片段,例如,大多數 jax.numpy 函數 (#7839)。

jaxlib 0.1.74 (2021 年 11 月 17 日)#

  • 啟用了 GPU 之間的點對點複製。先前,GPU 複製會透過主機彈回,通常速度較慢。

  • 新增了實驗性的 MLIR Python 綁定,供 JAX 使用。

jax 0.2.25 (2021 年 11 月 10 日)#

  • GitHub 提交.

  • 新功能

    • (實驗性) jax.distributed.initialize 公開了多主機 GPU 後端。

    • jax.random.permutation 支援新的 independent 關鍵字引數 (#8430)

  • 重大變更

    • jax.experimental.stax 移至 jax.example_libraries.stax

    • jax.experimental.optimizers 移至 jax.example_libraries.optimizers

  • 新功能

    • 新增了 jax.lax.linalg.qdwh

jax 0.2.24 (2021 年 10 月 19 日)#

  • GitHub 提交.

  • 新功能

    • jax.random.choicejax.random.permutation 現在支援多維陣列和可選的 axis 引數 (#8158)

  • 重大變更

    • jax.numpy.takejax.numpy.take_along_axis 現在需要類陣列輸入(請參閱 #7737

jaxlib 0.1.73 (2021 年 10 月 18 日)#

  • jaxlib GPU cuda11 輪子(wheels)現在支援多個 cuDNN 版本。

    • cuDNN 8.2 或更新版本。如果您的 cuDNN 安裝夠新,我們建議使用 cuDNN 8.2 輪子,因為它支援額外功能。

    • cuDNN 8.0.5 或更新版本。

  • 重大變更

    • GPU jaxlib 的安裝指令如下

      pip install --upgrade pip
      
      # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
      pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
      pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      
      # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
      pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
      

jax 0.2.22 (2021 年 10 月 12 日)#

  • GitHub 提交.

  • 重大變更

    • jax.pmap 的靜態引數現在必須是可雜湊的(hashable)。

      長期以來,jax.jit 上一直不允許使用不可雜湊的靜態引數,但 jax.pmap 仍然允許使用;jax.pmap 使用物件識別來比較不可雜湊的靜態引數。

      這種行為是一個隱藏的陷阱,因為使用物件識別來比較引數,每次物件識別變更時都會導致重新編譯。相反地,我們現在禁止不可雜湊的引數:如果 jax.pmap 的使用者想要透過物件識別來比較靜態引數,他們可以在物件上定義 __hash____eq__ 方法來執行此操作,或者將他們的物件包裝在具有使用物件識別語意的這些操作的物件中。另一個選項是使用 functools.partial 將不可雜湊的靜態引數封裝到函數物件中。

    • jax.util.partial 是一個意外匯出的項目,現在已移除。請改用 Python 標準函式庫中的 functools.partial

  • 棄用

    • 函數 jax.ops.index_updatejax.ops.index_add 等已棄用,並將在未來的 JAX 版本中移除。請改用JAX 陣列上的 .at 屬性,例如,x.at[idx].set(y)。目前,這些函數會產生 DeprecationWarning

  • 新功能

    • 當使用 jaxlib 0.1.72 或更新版本時,改善 pmap 的調度時間的優化 C++ 程式碼路徑現在是預設值。可以使用 --experimental_cpp_pmap 旗標(或 JAX_CPP_PMAP 環境變數)停用此功能。

    • jax.numpy.unique 現在支援可選的 fill_value 引數 (#8121)

jaxlib 0.1.72 (2021 年 10 月 12 日)#

  • 重大變更

    • 已停止支援 CUDA 10.2 和 CUDA 10.1。Jaxlib 現在支援 CUDA 11.1+。

  • 錯誤修復

    • 修正了 https://github.com/jax-ml/jax/issues/7461,該問題由於 XLA 編譯器內部的緩衝區別名錯誤,導致在所有平台上產生錯誤輸出。

jax 0.2.21 (2021 年 9 月 23 日)#

  • GitHub 提交.

  • 重大變更

    • jax.api 已移除。作為 jax.api.* 提供的函數是 jax.* 中函數的別名;請改用 jax.* 中的函數。

    • jax.partialjax.lax.partial 是意外匯出的項目,現在已移除。請改用 Python 標準函式庫中的 functools.partial

    • 布林純量索引現在會引發 TypeError;先前這會靜默地回傳錯誤結果 (#7925)。

    • 更多 jax.numpy 函數現在需要類陣列輸入,如果傳入列表(list)將會發生錯誤 (#7747 #7802 #7907)。請參閱 #7737 以了解此變更背後的理由。

    • 當在諸如 jax.jit 之類的轉換內部時,jax.numpy.array 總是將其產生的陣列分段到追蹤計算中。先前,即使在 jax.jit 裝飾器下,jax.numpy.array 有時也會產生裝置上的陣列。此變更可能會破壞使用 JAX 陣列執行形狀或索引計算的程式碼,這些計算必須是靜態已知的;解決方法是改用經典的 NumPy 陣列執行這些計算。

    • jnp.ndarray 現在是 JAX 陣列的真正基底類別。特別是,這表示對於標準 numpy 陣列 xisinstance(x, jnp.ndarray) 現在將會回傳 False (#7927)。

  • 新功能

jax 0.2.20 (2021 年 9 月 2 日)#

  • GitHub 提交.

  • 重大變更

    • jnp.poly* 函數現在需要類陣列輸入 (#7732)

    • jnp.unique 和其他類似集合的操作現在需要類陣列輸入 (#7662)

jaxlib 0.1.71 (2021 年 9 月 1 日)#

  • 重大變更

    • 已停止支援 CUDA 11.0 和 CUDA 10.1。Jaxlib 現在支援 CUDA 10.2 和 CUDA 11.1+。

jax 0.2.19 (2021 年 8 月 12 日)#

  • GitHub 提交.

  • 重大變更

    • 已停止支援 NumPy 1.17,根據廢棄政策。請升級至支援的 NumPy 版本。

    • 已在 JAX 陣列上的一些運算子的實作周圍新增了 jit 裝飾器。這加快了常見運算子(例如 +)的調度時間。

      此變更對大多數使用者來說應該基本上是透明的。但是,有一個已知的行為變更,那就是大型整數常數在直接傳遞給 JAX 運算子時(例如,x + 2**40)現在可能會產生錯誤。解決方法是將常數轉換為明確的類型(例如,np.float64(2**40))。

  • 新功能

    • 改善了 jax2tf 中對形狀多型性的支援,以用於需要在陣列計算中使用維度大小的操作,例如 jnp.mean。 (#7317)

  • 錯誤修復

    • 先前版本中的一些洩漏追蹤錯誤 (#7613)

jaxlib 0.1.70 (2021 年 8 月 9 日)#

  • 重大變更

    • 已停止支援 Python 3.6,根據廢棄政策。請升級至支援的 Python 版本。

    • 已停止支援 NumPy 1.17,根據廢棄政策。請升級至支援的 NumPy 版本。

    • host_callback 機制現在為每個本機裝置使用一個執行緒,以進行對 Python 回呼的呼叫。先前,所有裝置都只有一個執行緒。這表示回呼現在可能會交錯呼叫。對應於一個裝置的回呼仍將依序呼叫。

jax 0.2.18 (2021 年 7 月 21 日)#

  • GitHub 提交.

  • 重大變更

    • 已停止支援 Python 3.6,根據廢棄政策。請升級至支援的 Python 版本。

    • 最低 jaxlib 版本現在為 0.1.69。

    • 已移除 jax.dlpack.from_dlpack()backend 引數。

  • 新功能

  • 錯誤修復

    • 收緊了對 lax.argmin 和 lax.argmax 的檢查,以確保它們不會與無效的 axis 值或空的歸約維度一起使用。 (#7196)

jaxlib 0.1.69 (2021 年 7 月 9 日)#

  • 修正了 TFRT CPU 後端中導致結果不正確的錯誤。

jax 0.2.17 (2021 年 7 月 9 日)#

  • GitHub 提交.

  • 錯誤修復

    • 對於 jaxlib <= 0.1.68,預設為較舊的 “stream_executor” CPU 執行階段,以解決 #7229,該問題由於並行問題導致 CPU 上產生錯誤輸出。

  • 新功能

jax 0.2.16 (2021 年 6 月 23 日)#

jax 0.2.15 (2021 年 6 月 23 日)#

  • GitHub 提交.

  • 新功能

  • 重大變更

  • 錯誤修復

    • 修正了阻止從 JAX 到 TF 再返回的往返錯誤:jax2tf.call_tf(jax2tf.convert) (#6947)。

jaxlib 0.1.68 (2021 年 6 月 23 日)#

  • 錯誤修復

    • 修正了 TFRT CPU 後端中將 TPU 緩衝區傳輸到 CPU 時會出現 nan 的錯誤。

jax 0.2.14 (2021 年 6 月 10 日)#

  • GitHub 提交.

  • 新功能

    • jax2tf.convert() 現在支援 pjitsharded_jit

    • 新的組態選項 JAX_TRACEBACK_FILTERING 控制 JAX 如何篩選回溯(tracebacks)。

    • 在足夠新版本的 IPython 中,預設情況下已啟用使用 __tracebackhide__ 的新回溯篩選模式。

    • jax2tf.convert() 即使在算術運算中使用未知維度,也支援形狀多型性,例如 jnp.reshape(-1) (#6827)。

    • jax2tf.convert() 在 TF 運算中產生具有位置資訊的自訂屬性。jax2tf 之後 XLA 產生的程式碼與 JAX/XLA 具有相同的位置資訊。

    • 新的 SciPy 函數 jax.scipy.special.lpmn()

  • 錯誤修復

    • jax2tf.convert() 現在確保它對 Python 純量使用相同的類型規則,並為選擇 32 位元與 64 位元計算使用與 JAX 相同的規則 (#6883)。

    • jax2tf.convert() 現在正確地將 enable_xla 轉換參數的作用域設定為僅在即時轉換期間套用 (#6720)。

    • jax2tf.convert() 現在使用 XlaDot TensorFlow 運算來轉換 lax.dot_general,以獲得更好的關於 JAX 數值精確度的保真度 (#6717)。

    • jax2tf.convert() 現在支援複數的不等式比較和 min/max (#6892)。

jaxlib 0.1.67 (2021 年 5 月 17 日)#

jaxlib 0.1.66 (2021 年 5 月 11 日)#

  • 新功能

    • CUDA 11.1 輪子現在在所有 CUDA 11 版本 11.1 或更高版本上都受到支援。

      NVidia 現在承諾 CUDA 次要版本之間的相容性,從 CUDA 11.1 開始。這表示 JAX 可以發布單個 CUDA 11.1 輪子,該輪子與 CUDA 11.2 和 11.3 相容。

      不再有針對 CUDA 11.2(或更高版本)的單獨 jaxlib 版本;對於這些版本,請使用 CUDA 11.1 輪子 (cuda111)。

    • Jaxlib 現在在 CUDA 輪子中捆綁了 libdevice.10.bc。應該不再需要指示 JAX CUDA 安裝位置來尋找此檔案。

    • jit() 實作新增了對靜態關鍵字引數的自動支援。

    • 新增了對預轉換例外追蹤的支援。

    • 初步支援從 jit() 轉換的計算中修剪未使用的引數。修剪仍在進行中。

    • 改善了 PyTreeDef 物件的字串表示形式。

    • 新增了對 XLA 的可變參數 ReduceWindow 的支援。

  • 錯誤修復

    • 修正了遠端雲端 TPU 支援中,將大量引數傳遞給計算時的錯誤。

    • 修正了一個錯誤,該錯誤表示 JAX 垃圾回收未由 jit() 轉換的函數觸發。

jax 0.2.13 (2021 年 5 月 3 日)#

  • GitHub 提交.

  • 新功能

    • 當與 jaxlib 0.1.66 結合使用時,jax.jit() 現在支援靜態關鍵字引數。新增了一個新的 static_argnames 選項,以將關鍵字引數指定為靜態。

    • jax.nonzero() 有一個新的可選 size 引數,使其可以在 jit 內使用 (#6501)

    • jax.numpy.unique() 現在支援 axis 引數 (#6532)。

    • jax.experimental.host_callback.call() 現在支援 pjit.pjit (#6569)。

    • 新增了 jax.scipy.linalg.eigh_tridiagonal(),用於計算三對角矩陣的特徵值。目前僅支援特徵值。

    • 例外狀況中經過篩選和未經篩選的堆疊追蹤的順序已變更。附加到從 JAX 轉換的程式碼擲回的例外狀況的回溯現在已篩選,其中 UnfilteredStackTrace 例外狀況包含原始追蹤作為篩選例外狀況的 __cause__。篩選的堆疊追蹤現在也適用於 Python 3.6。

    • 如果由反向模式自動微分轉換的程式碼擲回例外狀況,JAX 現在會嘗試將 JaxStackTraceBeforeTransformation 物件作為例外狀況的 __cause__ 附加,該物件包含在前向傳遞中建立原始運算的堆疊追蹤。需要 jaxlib 0.1.66。

  • 重大變更

    • 以下函數名稱已變更。仍然有別名,因此這不應破壞現有程式碼,但別名最終將被移除,因此請變更您的程式碼。

    • 同樣地,local_devices() 的引數已從 host_id 重新命名為 process_index

    • 除了函數之外,jax.jit() 的引數現在標記為僅限關鍵字。此變更是為了防止在將引數新增至 jit 時意外中斷。

  • 錯誤修復

    • jax2tf.convert() 現在在具有整數輸入的函數的梯度存在的情況下也能運作 (#6360)。

    • 修正了與捕獲的 tf.Variable 一起使用時,jax2tf.call_tf() 中的斷言失敗問題 (#6572)。

jaxlib 0.1.65 (2021 年 4 月 7 日)#

jax 0.2.12 (2021 年 4 月 1 日)#

  • GitHub 提交.

  • 新功能

  • 重大變更

    • 最低 jaxlib 版本現在為 0.1.64。

    • 一些分析器 API 名稱已變更。仍然有別名,因此這不應破壞現有程式碼,但別名最終將被移除,因此請變更您的程式碼。

    • Omnistaging 功能已無法停用。請參閱 omnistaging 以了解更多資訊。

    • Python 整數若大於 int64 的最大值,在所有情況下都會導致溢位 (overflow),而不會像過去在某些情況下靜默轉換為 uint64 (#6047)。

    • 在 X64 模式之外,Python 整數若超出 int32 可表示的範圍,現在會導致 OverflowError 錯誤,而不是靜默截斷其值。

  • 錯誤修復

    • host_callback 現在支援在引數和結果中使用空陣列 (#6262)。

    • jax.random.randint() 對超出範圍的限制值進行裁剪 (clip) 而非環繞 (wrap),並且現在可以產生指定 dtype 完整範圍內的整數 (#5868)。

jax 0.2.11 (2021 年 3 月 23 日)#

  • GitHub 提交.

  • 新功能

    • #6112 新增了上下文管理器 (context manager):jax.enable_checksjax.check_tracer_leaksjax.debug_nansjax.debug_infsjax.log_compiles

    • #6085 新增了 jnp.delete

  • 錯誤修復

    • #6136jax.flatten_util.ravel_pytree 泛化,以處理整數 dtype。

    • #6129 修復了處理某些常數 (例如 enum.IntEnums) 時的錯誤。

    • #6145 修復了不完全 beta 函數的批次處理問題。

    • #6014 修復了追蹤 (tracing) 期間的 H2D 傳輸問題。

    • #6165 避免了將某些大型 Python 整數轉換為浮點數時的 OverflowError 錯誤。

  • 重大變更

    • 最低 jaxlib 版本現在為 0.1.62。

jaxlib 0.1.64 (2021 年 3 月 18 日)#

jaxlib 0.1.63 (2021 年 3 月 17 日)#

jax 0.2.10 (2021 年 3 月 5 日)#

  • GitHub 提交.

  • 新功能

    • jax.scipy.stats.chi2() 現在可以作為具有 logpdf 和 pdf 方法的分布 (distribution) 使用。

    • jax.scipy.stats.betabinom() 現在可以作為具有 logpmf 和 pmf 方法的分布 (distribution) 使用。

    • 新增了 jax.experimental.jax2tf.call_tf(),以便從 JAX 呼叫 TensorFlow 函數 (#5627) 和 README)。

    • 擴展了 lax.pad 的批次處理規則,以支援填充值 (padding values) 的批次處理。

  • 錯誤修復

  • 重大變更

    • JAX 的型別提升規則已調整,使提升更一致且不受 JIT 影響。特別是,二元運算現在可以在適當的情況下產生弱型別的值。此變更主要在使用者可見的影響是,某些運算會產生與之前不同精度的輸出;例如,運算式 jnp.bfloat16(1) + 0.1 * jnp.arange(10) 先前傳回 float64 陣列,現在則傳回 bfloat16 陣列。JAX 的型別提升行為請參閱 型別提升語意

    • jax.numpy.linspace() 現在計算整數值的下限值,即朝 -inf 而非 0 捨入。此變更旨在與 NumPy 1.20.0 相符。

    • jax.numpy.i0() 不再接受複數。先前,此函數會計算複數引數的絕對值。此變更旨在與 NumPy 1.20.0 的語意相符。

    • 數個 jax.numpy 函數不再接受以元組或列表取代陣列引數:jax.numpy.pad()、:funcjax.numpy.raveljax.numpy.repeat()jax.numpy.reshape()。一般而言,jax.numpy 函數應搭配純量或陣列引數使用。

jaxlib 0.1.62 (2021 年 3 月 9 日)#

  • 新功能

    • jaxlib wheels 現在預設建置為在 x86-64 機台上需要 AVX 指令。如果您想在不支援 AVX 的機器上使用 JAX,可以使用 build.py--target_cpu_features 旗標從原始碼建置 jaxlib。--target_cpu_features 也取代了 --enable_march_native

jaxlib 0.1.61 (2021 年 2 月 12 日)#

jaxlib 0.1.60 (2021 年 2 月 3 日)#

  • 錯誤修復

    • 修復了將 CPU DeviceArrays 轉換為 NumPy 陣列時的記憶體洩漏問題。記憶體洩漏存在於 jaxlib 版本 0.1.58 和 0.1.59 中。

    • boolint8uint8 現在被視為可以安全地轉換為 bfloat16 NumPy 擴充型別。

jax 0.2.9 (2021 年 1 月 26 日)#

jaxlib 0.1.59 (2021 年 1 月 15 日)#

jax 0.2.8 (2021 年 1 月 12 日)#

  • GitHub 提交.

  • 新功能

    • 新增了 jax.closure_convert(),用於高階自訂導數函數。(<#5244>)

    • 新增了 jax.experimental.host_callback.call(),以便在主機上呼叫自訂 Python 函數,並將結果傳回裝置運算。(<#5243>)

  • 錯誤修復

    • jax.numpy.arccosh 現在針對複數輸入傳回與 numpy.arccosh 相同的分支 (#5156)

    • host_callback.id_tap 現在也適用於 jax.pmap。針對 id_tapid_print 有一個選用參數,可要求將輕觸值的裝置作為關鍵字引數傳遞至輕觸函數 (#5182)。

  • 重大變更

    • jax.numpy.pad 現在接受關鍵字引數。位置引數 constant_values 已移除。此外,傳遞不支援的關鍵字引數會引發錯誤。

    • jax.experimental.host_callback.id_tap() 的變更 (#5243)

      • 移除了 jax.experimental.host_callback.id_tap()kwargs 支援。(此支援已棄用數個月。)

      • 變更了 jax.experimental.host_callback.id_print() 的元組列印方式,從 ‘[‘ 改為使用 ‘(‘。

      • 變更了 JVP 存在時的 jax.experimental.host_callback.id_print(),以列印原始值和切線值組。先前,原始值和切線值有兩個個別的列印運算。

      • host_callback.outfeed_receiver 已移除 (它不是必要的,且數個月前已棄用)。

  • 新功能

    • 用於偵錯 inf 的新旗標,類似於 NaN 的旗標 (#5224)。

jax 0.2.7 (2020 年 12 月 4 日)#

  • GitHub 提交.

  • 新功能

    • 新增了 jax.device_put_replicated

    • 將多主機支援新增至 jax.experimental.sharded_jit

    • 新增了對 jax.numpy.linalg.eig 計算出的特徵值進行微分的支援

    • 新增了在 Windows 平台上建置的支援

    • jax.pmap 中新增了對一般 in_axes 和 out_axes 的支援

    • 針對 jax.numpy.linalg.slogdet 新增了複數支援

  • 錯誤修復

    • 修復了 jax.numpy.sinc 在零點的高於二階的導數

    • 修復了轉置規則中一些難以命中的符號零錯誤

  • 重大變更

    • jax.experimental.optix 已刪除,改用獨立的 optax Python 套件。

    • 現在使用非元組序列對 JAX 陣列進行索引會引發 TypeError 錯誤。自 Numpy v1.16 起已棄用此類型的索引,自 JAX v0.2.4 起亦同。請參閱 #4564

jax 0.2.6 (2020 年 11 月 18 日)#

  • GitHub 提交.

  • 新功能

    • 為 jax.experimental.jax2tf 轉換器新增了形狀多型追蹤的支援。請參閱 README.md

  • 重大變更清理

    • 針對 jax.jit 和 xla_computation 的不可雜湊靜態引數引發錯誤。請參閱 cb48f42

    • 改進了型別提升行為的一致性 (#4744)

      • 將複數 Python 純量新增至 JAX 浮點數時,會遵守 JAX 浮點數的精確度。例如,jnp.float32(1) + 1j 現在會傳回 complex64,先前則傳回 complex128

      • 涉及 uint64、帶正負號整數和第三種類型的 3 個或更多項目的型別提升結果,現在與引數的順序無關。例如:jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)jnp.result_type(jnp.float16, jnp.uint64, jnp.int64) 都會傳回 float16,先前前者傳回 float64,後者則傳回 float16

    • 未記載的 jax.lax_linalg 線性代數模組的內容現在公開為 jax.lax.linalg

    • jax.random.PRNGKey 現在在 JIT 編譯內外產生相同的結果 (#4877)。這需要在少數特定情況下變更給定種子的結果

      • 使用 jax_enable_x64=False 時,作為 Python 整數傳遞的負種子現在會在 JIT 模式外傳回不同的結果。例如,jax.random.PRNGKey(-1) 先前傳回 [4294967295, 4294967295],現在則傳回 [0, 4294967295]。這與 JIT 中的行為相符。

      • 超出 JIT 外 int64 可表示範圍的種子,現在會導致 OverflowError 錯誤,而不是 TypeError 錯誤。這與 JIT 中的行為相符。

      若要還原先前針對 JIT 外使用 jax_enable_x64=False 的負整數傳回的金鑰,您可以使用

      key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
      
    • DeviceArray 現在會在嘗試存取已刪除的值時,引發 RuntimeError 錯誤,而不是 ValueError 錯誤。

jaxlib 0.1.58 (約 2021 年 1 月 12 日)#

  • 修復了一個錯誤,該錯誤表示 JAX 有時會傳回平台特定的型別 (例如 np.cint),而不是標準型別 (例如 np.int32)。(#4903)

  • 修復了常數摺疊某些 int16 運算時發生的當機問題。(#4971)

  • pytree.flatten() 新增了 is_leaf 述詞。

jaxlib 0.1.57 (2020 年 11 月 12 日)#

  • 修復了 GPU wheels 中的 manylinux2010 相容性問題。

  • 將 CPU FFT 實作從 Eigen 切換為 PocketFFT。

  • 修復了 bfloat16 值的雜湊未正確初始化且可能會變更的錯誤 (#4651)。

  • 新增了在將陣列傳遞至 DLPack 時保留擁有權的支援 (#4636)。

  • 修復了批次三角解算大小大於 128 但不是 128 倍數時的錯誤。

  • 修復了在多個 GPU 上執行並行 FFT 時的錯誤 (#3518)。

  • 修復了工具遺失的分析器錯誤 (#4427)。

  • 已停止支援 CUDA 10.0。

jax 0.2.5 (2020 年 10 月 27 日)#

jax 0.2.4 (2020 年 10 月 19 日)#

  • GitHub 提交.

  • 改進

    • 為 jax.experimental.host_callback 新增了 remat 的支援。請參閱 #4608

  • 棄用

    • 現在已棄用使用非元組序列進行索引,這與 Numpy 中的類似棄用相同。在未來的版本中,這會導致 TypeError 錯誤。請參閱 #4564

jaxlib 0.1.56 (2020 年 10 月 14 日)#

jax 0.2.3 (2020 年 10 月 14 日)#

  • GitHub 提交.

  • 這麼快再次發布的原因是,我們需要暫時回溯新的 jit 快速路徑,同時調查效能降低問題

jax 0.2.2 (2020 年 10 月 13 日)#

jax 0.2.1 (2020 年 10 月 6 日)#

  • GitHub 提交.

  • 改進

    • 作為 omnistaging 的優點,即使 jax.experimental.host_callback.id_print()/ jax.experimental.host_callback.id_tap() 的結果未在運算中使用,host_callback 函數仍會 (依程式順序) 執行。

jax (0.2.0) (2020 年 9 月 23 日)#

jax (0.1.77) (2020 年 9 月 15 日)#

  • 重大變更

    • jax.experimental.host_callback.id_tap() 的新簡化介面 (#4101)

jaxlib 0.1.55 (2020 年 9 月 8 日)#

  • 更新 XLA

    • 修復 DLPackManagedTensorToBuffer 中的錯誤 (#4196)

jax 0.1.76 (2020 年 9 月 8 日)#

jax 0.1.75 (2020 年 7 月 30 日)#

  • GitHub 提交.

  • 錯誤修正

    • 讓 jnp.abs() 適用於未帶正負號的輸入 (#3914)

  • 改進

    • 在旗標後方新增 “Omnistaging” 行為,預設為停用 (#3370)

jax 0.1.74 (2020 年 7 月 29 日)#

  • GitHub 提交.

  • 新功能

    • BFGS (#3101)

    • TPU 支援半精度算術 (#3878)

  • 錯誤修正

    • 防止一些意外的 dtype 警告 (#3874)

    • 修復自訂導數中的多執行緒錯誤 (#3845, #3869)

  • 改進

    • 更快速的 searchsorted 實作 (#3873)

    • 改善 jax.numpy 排序演算法的測試涵蓋率 (#3836)

jaxlib 0.1.52 (2020 年 7 月 22 日)#

  • 更新 XLA。

jax 0.1.73 (2020 年 7 月 22 日)#

  • GitHub 提交.

  • 最低 jaxlib 版本現在為 0.1.51。

  • 新功能

    • jax.image.resize。 (#3703)

    • hfft 和 ihfft (#3664)

    • jax.numpy.intersect1d (#3726)

    • jax.numpy.lexsort (#3812)

    • lax.scanscan 原始項目支援 unroll 參數,以便在降低至 XLA 時進行迴圈展開 (#3738)。

  • 錯誤修正

    • 修復縮減重複軸錯誤 (#3618)

    • 修復 lax.pad 針對大小為 0 的輸入維度的形狀規則。(#3608)

    • 讓 psum 轉置處理零餘切 (#3653)

    • 修復在大小為 0 軸上取得 reduce-prod 的 JVP 時的形狀錯誤。(#3729)

    • 支援透過 jax.lax.all_to_all 進行微分 (#3733)

    • 解決 jax.scipy.special.zeta 中的 nan 問題 (#3777)

  • 改進

    • 對 jax2tf 進行了許多改進

    • 使用單遍可變引數縮減重新實作 argmin/argmax。 (#3611)

    • 預設啟用 XLA SPMD 分割。 (#3151)

    • 新增了對 0d 轉置捲積的支援 (#3643)

    • 讓 LU 梯度適用於低秩矩陣 (#3610)

    • 支援 jet 中的 multiple_results 和自訂 JVP (#3657)

    • 將 reduce-window 填充廣泛化,以支援 (lo, hi) 配對。 (#3728)

    • 在 CPU 和 GPU 上實作複數捲積。 (#3735)

    • 讓 jnp.take 適用於空陣列的空切片。 (#3751)

    • 放寬 dot_general 的維度排序規則。 (#3778)

    • 為 GPU 啟用緩衝區捐贈。 (#3800)

    • 為 reduce window 運算新增了對基本擴張和視窗擴張的支援… (#3803)

jaxlib 0.1.51 (2020 年 7 月 2 日)#

  • 更新 XLA。

  • 為 host_callback 新增了新的執行階段支援。

jax 0.1.72 (2020 年 6 月 28 日)#

  • GitHub 提交.

  • 錯誤修復

    • 修復了先前版本中引入的 odeint 錯誤,請參閱 #3587

jax 0.1.71 (2020 年 6 月 25 日)#

  • GitHub 提交.

  • 最低 jaxlib 版本現在為 0.1.48。

  • 錯誤修復

    • 允許 jax.experimental.ode.odeint 動態函數關閉關於我們正在微分的值 #3562

jaxlib 0.1.50 (2020 年 6 月 25 日)#

  • 新增了 CUDA 11.0 的支援。

  • 停止支援 CUDA 9.2 (我們僅維護對最新四個 CUDA 版本的支援。)

  • 更新 XLA。

jaxlib 0.1.49 (2020 年 6 月 19 日)#

jaxlib 0.1.48 (2020 年 6 月 12 日)#

  • 新功能

    • 新增了對快速追蹤收集的支援。

    • 新增了對裝置上堆積分析的初步支援。

    • bfloat16 型別實作 np.nextafter

    • CPU 和 GPU 上 FFT 的 Complex128 支援。

  • 錯誤修復

    • 改進了 GPU 上 float64 tanh 的準確性。

    • GPU 上的 float64 散佈速度更快。

    • CPU 上的複數矩陣乘法應該更快。

    • CPU 上的穩定排序現在應該確實穩定。

    • CPU 後端的並行錯誤修正。

jax 0.1.70 (2020 年 6 月 8 日)#

  • GitHub 提交.

  • 新功能

    • lax.switch 引入了具有多個分支的索引條件,以及 cond 原始項目的廣泛化 #3318

jax 0.1.69 (2020 年 6 月 3 日)#

jax 0.1.68 (2020 年 5 月 21 日)#

  • GitHub 提交.

  • 新功能

    • lax.cond() 支援單一運算元形式,作為兩個分支的引數 #2993

  • 值得注意的變更

    • jax.experimental.host_callback.id_tap() 原始項目的 transforms 關鍵字的格式已變更 #3132

jax 0.1.67 (2020 年 5 月 12 日)#

  • GitHub 提交.

  • 新功能

    • 使用 axis_index_groups 支援在 pmapped 軸的子集上進行縮減 #2382

    • 從編譯碼進行列印和呼叫主機端 Python 函數的實驗性支援。請參閱 id_print 和 id_tap (#3006)。

  • 值得注意的變更

    • 已收緊從 jax.numpy 匯出的名稱的能見度。這可能會中斷使用先前意外匯出的名稱的程式碼。

jaxlib 0.1.47 (2020 年 5 月 8 日)#

  • 修復了 outfeed 的當機問題。

jax 0.1.66 (2020 年 5 月 5 日)#

jaxlib 0.1.46 (2020 年 5 月 5 日)#

  • 修復了 Mac OS X 上線性代數函數的當機問題 (#432)。

  • 修復了在作業系統或 Hypervisor 停用 AVX512 指令時,因使用 AVX512 指令而導致的非法指令當機問題 (#2906)。

jax 0.1.65 (2020 年 4 月 30 日)#

  • GitHub 提交.

  • 新功能

    • 奇異矩陣行列式的微分 #2809

  • 錯誤修復

    • 修復了具有時間相關動態的 ODE odeint() 相對於時間的微分 #2817,同時新增 ODE CI 測試。

    • 修復了 lax_linalg.qr() 微分 #2867

jaxlib 0.1.45 (2020 年 4 月 21 日)#

  • 修復了區段錯誤:#2755

  • 將 Sort HLO 上的 is_stable 選項貫穿至 Python。

jax 0.1.64 (2020 年 4 月 21 日)#

jaxlib 0.1.44 (2020 年 4 月 16 日)#

  • 修復了一個錯誤,該錯誤導致當存在多個不同型號的 GPU 時,JAX 只會編譯適用於第一個 GPU 的程式。

  • 修復了 batch_group_count 卷積的錯誤。

  • 為更多 GPU 版本新增了預編譯的 SASS,以避免啟動 PTX 編譯卡頓。

jax 0.1.63 (2020 年 4 月 12 日)#

  • GitHub 提交.

  • #2026 新增了 jax.custom_jvpjax.custom_vjp,請參閱教學筆記本。已棄用 jax.custom_transforms 並從文件中移除(但它仍然有效)。

  • 新增 scipy.sparse.linalg.cg #2566

  • 變更了 Tracers 的列印方式,以顯示更多有用的除錯資訊 #2591

  • 使 jax.numpy.isclose 正確處理 naninf #2501

  • jax.experimental.jet 新增了幾個新規則 #2537

  • 修復了當未提供 scale/center 時的 jax.experimental.stax.BatchNorm

  • 修復了 jax.numpy.einsum 中一些遺失的廣播案例 #2512

  • 實作了以平行前綴掃描表示的 jax.numpy.cumsumjax.numpy.cumprod #2596,並使 reduce_prod 可微分至任意階 #2597

  • batch_group_count 新增至 conv_general_dilated #2635

  • test_util.check_grads 新增了文件字串 #2656

  • 新增 callback_transform #2665

  • 實作了 rollaxisconvolve/correlate 1d & 2d、copysigntruncroots,以及 quantile/percentile 插值選項。

jaxlib 0.1.43 (2020 年 3 月 31 日)#

  • 修復了 GPU 上 Resnet-50 的效能衰退問題。

jax 0.1.62 (2020 年 3 月 21 日)#

  • GitHub 提交.

  • JAX 已停止支援 Python 3.5。請升級至 Python 3.6 或更新版本。

  • 移除了內部函數 lax._safe_mul,它實作了 0. * nan == 0. 的慣例。此變更意味著某些程式在微分時會產生 nan 值,而之前它們會產生正確的值,儘管它確保為其他程式產生 nan 值而不是靜默地產生不正確的結果。詳情請參閱 #2447 和 #1052。

  • 新增了 all_gather 平行便利函數。

  • 核心程式碼中新增了更多類型註釋。

jaxlib 0.1.42 (2020 年 3 月 19 日)#

  • jaxlib 0.1.41 由於 API 不相容而中斷了雲端 TPU 支援。此版本再次修復了它。

  • JAX 已停止支援 Python 3.5。請升級至 Python 3.6 或更新版本。

jax 0.1.61 (2020 年 3 月 17 日)#

  • GitHub 提交.

  • 修復了 Python 3.5 支援。這將是最後一個支援 Python 3.5 的 JAX 或 jaxlib 版本。

jax 0.1.60 (2020 年 3 月 17 日)#

  • GitHub 提交.

  • 新功能

    • jax.pmap() 具有 static_broadcast_argnums 參數,允許使用者指定應視為編譯時期常數並應廣播到所有裝置的參數。它的運作方式與 jax.jit() 中的 static_argnums 類似。

    • 改進了當追蹤器錯誤地儲存在全域狀態時的錯誤訊息。

    • 新增了 jax.nn.one_hot() 實用函數。

    • 新增了用於指數級加速高階自動微分的 jax.experimental.jet

    • jax.lax.broadcast_in_dim() 的引數新增了更多正確性檢查。

  • 最低 jaxlib 版本現在為 0.1.41。

jaxlib 0.1.40 (2020 年 3 月 4 日)#

  • 在 Jaxlib 中新增了對 TensorFlow profiler 的實驗性支援,允許從 TensorBoard 追蹤 CPU 和 GPU 計算。

  • 包含透過 NCCL 通訊的多主機 GPU 計算的原型支援。

  • 提高了 GPU 上 NCCL 集體運算的效能。

  • 新增了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 實作。

  • 支援在 XLA 編譯時已知的裝置分配。

jax 0.1.59 (2020 年 2 月 11 日)#

  • GitHub 提交.

  • 重大變更

    • 最低 jaxlib 版本現在為 0.1.38。

    • 透過移除 Jaxpr.freevarsJaxpr.bound_subjaxprs 簡化了 Jaxpr。呼叫基本運算 (xla_callxla_pmapsharded_callremat_call) 取得了一個新的參數 call_jaxpr,其中包含完全封閉的(無 constvars)jaxpr。此外,為基本運算新增了一個新的欄位 call_primitive

  • 新功能

    • lax.cond 的反向模式自動微分 (例如 grad),使其現在在兩種模式下都可微分 (#2091)

    • JAX 現在支援 DLPack,它允許以零複製方式與其他程式庫(例如 PyTorch)共用 CPU 和 GPU 陣列。

    • JAX GPU DeviceArrays 現在支援 __cuda_array_interface__,這是另一個用於與其他程式庫(例如 CuPy 和 Numba)共用 GPU 陣列的零複製協定。

    • JAX CPU 裝置緩衝區現在實作了 Python 緩衝區協定,允許在 JAX 和 NumPy 之間進行零複製緩衝區共用。

    • 新增了 JAX_SKIP_SLOW_TESTS 環境變數,以跳過已知速度較慢的測試。

jaxlib 0.1.39 (2020 年 2 月 11 日)#

  • 更新 XLA。

jaxlib 0.1.38 (2020 年 1 月 29 日)#

  • 不再支援 CUDA 9.0。

  • CUDA 10.2 wheels 現在預設為建置。

jax 0.1.58 (2020 年 1 月 28 日)#

值得注意的錯誤修復#

  • 隨著 Python 3 的升級,JAX 不再依賴 fastcache,這應有助於安裝。