變更日誌#
最佳瀏覽方式請點擊這裡。如需實驗性 Pallas API 的特定變更,請參閱Pallas 變更日誌。
JAX 遵循基於努力的版本控制;關於此項以及 JAX 的 API 相容性政策的討論,請參閱API 相容性。關於 Python 和 NumPy 版本支援政策,請參閱Python 和 NumPy 版本支援政策。
未發布#
jax 0.5.0 (2025 年 1 月 17 日)#
在此版本中,JAX 現在使用基於努力的版本控制。由於此版本對 PRNG 金鑰語意進行了重大變更,可能需要使用者更新其程式碼,因此我們將 JAX 的「meso」版本升級以表示這一點。
重大變更
預設啟用
jax_threefry_partitionable
(請參閱更新說明)。此版本不再支援 Mac x86 wheels。Mac ARM 當然仍然受到支援。如需近期討論,請參閱 https://github.com/jax-ml/jax/discussions/22936。
促成此決策的兩個主要因素
Mac x86 版本(僅限)有許多測試失敗和崩潰。我們寧願不發布版本,也不願發布損壞的版本。
Mac x86 硬體已停產,目前開發人員無法輕易取得。因此,即使我們想解決這類問題,也很困難。
如果社群願意協助支援該平台,我們願意重新新增對 Mac x86 的支援:特別是,在我們再次發布版本之前,我們需要 JAX 測試套件在 Mac x86 上乾淨俐落地通過。
變更
NumPy 最低版本現在為 1.25。NumPy 1.25 將維持最低支援版本,直到 2025 年 6 月。
SciPy 最低版本現在為 1.11。SciPy 1.11 將維持最低支援版本,直到 2025 年 6 月。
jax.numpy.einsum()
現在預設為optimize='auto'
而非optimize='optimal'
。這避免了在多個引數的情況下,trace-time 呈指數級擴展 (#25214)。jax.numpy.linalg.solve()
不再支援右側的批次 1D 引數。若要在這些情況下恢復先前的行為,請使用solve(a, b[..., None]).squeeze(-1)
。
新功能
jax.numpy.fft.fftn()
、jax.numpy.fft.rfftn()
、jax.numpy.fft.ifftn()
和jax.numpy.fft.irfftn()
現在支援超過 3 個維度的轉換,這在以前是上限。請參閱 #25606 以取得更多詳細資訊。透過新的
jax.ffi.register_ffi_type_id()
函數,新增了對 FFI 中使用者定義狀態的支援。AOT 降低
.as_text()
方法現在支援debug_info
選項,以在輸出中包含除錯資訊,例如原始碼位置。
棄用
從
jax.interpreters.xla
,abstractify
和pytype_aval_mappings
現在已棄用,已被jax.core
中同名的符號取代。jax.scipy.special.lpmn()
和jax.scipy.special.lpmn_values()
已棄用,原因是它們在 SciPy v1.15.0 中已棄用。目前沒有計畫用新的 API 取代這些已棄用的函數。jax.extend.ffi
子模組已移至jax.ffi
,先前的匯入路徑已棄用。
刪除
jax_enable_memories
旗標已刪除,且該旗標的行為預設為開啟。從
jax.lib.xla_client
中,先前已棄用的Device
和XlaRuntimeError
符號已移除;請改用jax.Device
和jax.errors.JaxRuntimeError
。在 JAX v0.4.32 中棄用後,
jax.experimental.array_api
模組已移除。自該版本以來,jax.numpy
直接支援 array API。
jax 0.4.38 (2024 年 12 月 17 日)#
重大變更
XlaExecutable.cost_analysis
現在傳回dict[str, float]
(而不是單一元素的list[dict[str, float]]
)。
變更
jax.tree.flatten_with_path
和jax.tree.map_with_path
已新增為對應tree_util
函數的捷徑。
棄用
內部
jax.core
命名空間中的許多 API 已棄用。大多數是 no-ops、很少使用,或可以由jax.extend.core
中同名的 API 取代;請參閱jax.extend
的文件,以取得有關這些半公開擴充功能的相容性保證資訊。已移除數個先前已棄用的 API,包括
從
jax.core
:check_eqn
、check_type
、check_valid_jaxtype
和non_negative_dim
。從
jax.lib.xla_bridge
:xla_client
和default_backend
。從
jax.lib.xla_client
:_xla
和bfloat16
。從
jax.numpy
:round_
。
新功能
jax.export.export()
可以與使用jax.sharding.AbstractMesh()
建構的分片一起用於裝置多型匯出。請參閱 jax.export 文件。新增
jax.lax.split()
。這是jax.numpy.split()
的基本版本,新增的原因是它在自動微分期間產生更緊湊的轉置。
jax 0.4.37 (2024 年 12 月 9 日)#
這是 jax 0.4.36 的修補程式版本。此版本僅發布「jax」。
jax 0.4.37#
錯誤修復
修正了如果引數命名為
f
時,jit
會發生錯誤的錯誤 (#25329)。修正了如果使用者為 flatten 和 flatten_with_path 註冊具有不同輔助資料的 pytree 節點類別,則在
jax.lax.while_loop()
中會擲回index out of range
錯誤的錯誤。釘選了新的 libtpu 版本 (0.0.6),修正了 TPU v6e 上的編譯器錯誤。
jax 0.4.36 (2024 年 12 月 5 日)#
重大變更
此版本採用了「stackless」,這是 JAX 追蹤機制的內部變更。我們使追蹤分派純粹成為上下文的函數,而不是上下文和資料的函數。這讓我們刪除了大量用於管理資料相依追蹤的機制:levels、sublevels、
post_process_call
、new_base_main
、custom_bind
等等。此變更應僅影響使用 JAX 內部機制的使用者。如果您使用了 JAX 內部元件,則可能需要更新您的程式碼 (請參閱 https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f 以取得如何操作的線索)。使用 JAX 函式庫的內部元件也可能存在版本偏差問題。如果您發現此變更破壞了您未使用 JAX 內部元件的程式碼,請嘗試使用
config.jax_data_dependent_tracing_fallback
旗標作為權宜之計,如果您在更新程式碼時需要協助,請提交錯誤報告。自 2024 年 7 月起,JAX 版本 0.4.31 已棄用搭配
native_serialization=False
或enable_xla=False
的jax.experimental.jax2tf.convert()
。現在我們已移除對這些使用案例的支援。jax2tf
搭配原生序列化仍將受到支援。在
jax.interpreters.xla
中,xb
、xc
和xe
符號在 JAX v0.4.31 中被棄用後已移除。請改用xb = jax.lib.xla_bridge
、xc = jax.lib.xla_client
和xe = jax.lib.xla_extension
。已移除已棄用的模組
jax.experimental.export
。它在 JAX v0.4.30 中被jax.export
取代。請參閱遷移指南,以取得有關遷移至新 API 的資訊。已移除
jax.nn.softmax()
和jax.nn.log_softmax()
的initial
參數,該參數已在 v0.4.27 中棄用。現在對類型化的 PRNG 金鑰 (即由 :func:
jax.random.key
產生的金鑰) 呼叫np.asarray
會引發錯誤。 之前,這會傳回純量物件陣列。已移除
jax.export
中以下已棄用的方法和函式jax.export.DisabledSafetyCheck.shape_assertions
:它已經沒有任何作用。jax.export.Exported.lowering_platforms
:請改用platforms
。jax.export.Exported.mlir_module_serialization_version
:請改用calling_convention_version
。jax.export.Exported.uses_shape_polymorphism
:請改用uses_global_constants
。用於
jax.export.export()
的lowering_platforms
關鍵字參數:請改用platforms
。
已移除
jax.export.symbolic_args_specs()
中的關鍵字參數symbolic_scope
和symbolic_constraints
。它們已於 2024 年 6 月棄用。請改用scope
和constraints
。追蹤器 (tracer) 的雜湊 (hashing) 功能已在 0.4.30 版本中棄用,現在會導致
TypeError
。重構:JAX 建置 CLI (build/build.py) 現在使用子命令結構,並取代先前的 build.py 用法。執行
python build/build.py --help
以取得更多詳細資訊。新子命令選項的簡要概述build
:建置 JAX wheel 套件。例如,python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
requirements_update
:更新 requirements_lock.txt 檔案。
jax.scipy.linalg.toeplitz()
現在對多維輸入執行隱含的批次處理 (batching)。若要恢復先前的行為,您可以對函式輸入呼叫jax.numpy.ravel()
。jax.scipy.special.gamma()
和jax.scipy.special.gammasgn()
現在針對負整數輸入傳回 NaN,以符合 SciPy 的行為 (來自 https://github.com/scipy/scipy/pull/21827)。在 v0.4.26 中棄用後,
jax.clear_backends
已移除。我們從保證匯出穩定性的自訂呼叫清單中移除了自訂呼叫 “__gpu$xla.gpu.triton”。這是因為此自訂呼叫依賴於 Triton IR,而 Triton IR 並不保證穩定。如果您需要匯出使用此自訂呼叫的程式碼,您可以使用
disabled_checks
參數。 請參閱文件以取得更多詳細資訊。
新功能
jax.jit()
新增了一個compiler_options: dict[str, Any]
參數,用於將編譯選項傳遞給 XLA。目前它尚未記錄在文件中,並且可能會變動。jax.tree_util.register_dataclass()
現在允許透過dataclasses.field()
以內聯方式宣告元數據欄位。請參閱函式文件以取得範例。GPU 現在支援
jax.lax.linalg.eig()
和相關的jax.numpy
函式 (jax.numpy.linalg.eig()
和jax.numpy.linalg.eigvals()
)。請參閱 #24663 以取得更多詳細資訊。新增了兩個新的組態旗標
jax_exec_time_optimization_effort
和jax_memory_fitting_effort
,以控制編譯器在最小化執行時間和記憶體使用量上花費的精力。有效值介於 -1.0 和 1.0 之間,預設值為 0.0。
錯誤修復
修正了 GPU 實作的 LU 和 QR 分解在批次大小接近 int32 最大值時會導致索引溢位的錯誤。請參閱 #24843 以取得更多詳細資訊。
棄用
jax.lib.xla_extension.ArrayImpl
和jax.lib.xla_client.ArrayImpl
已棄用;請改用jax.Array
。jax.lib.xla_extension.XlaRuntimeError
已棄用;請改用jax.errors.JaxRuntimeError
。
jax 0.4.35 (2024 年 10 月 22 日)#
重大變更
jax.numpy.isscalar()
現在針對任何零維的類陣列物件傳回 True。先前,它僅針對具有弱 dtype 的零維類陣列物件傳回 True。自 2024 年 3 月起,JAX 版本 0.4.26 已棄用
jax.experimental.host_callback
。現在我們已將其移除。請參閱 #20385 以取得替代方案的討論。
變更
jax.lax.FftType
作為 FFT 運算的列舉公開名稱引入。jax.lib.xla_client.FftType
這個半公開 API 已棄用。TPU:JAX 現在從
libtpu
套件而不是libtpu-nightly
套件安裝 TPU 支援。在接下來的幾個版本中,JAX 將釘住libtpu-nightly
和libtpu
的空版本,以簡化過渡;該依賴性將在 2025 年第一季移除。
棄用
jax.lib.xla_client.PaddingType
這個半公開 API 已棄用。沒有 JAX API 使用此類型,因此沒有替代方案。在
vmap
下,jax.pure_callback()
和jax.extend.ffi.ffi_call()
的預設行為已被棄用,並且這些函式的vectorized
參數也被棄用。應改用vmap_method
參數以獲得更明確定義的行為。請參閱 #23881 中的討論以取得更多詳細資訊。jax.lib.xla_client.register_custom_call_target
這個半公開 API 已棄用。請改用 JAX FFI。jax.lib.xla_client.dtype_to_etype
、jax.lib.xla_client.ops
、jax.lib.xla_client.shape_from_pyval
、jax.lib.xla_client.PrimitiveType
、jax.lib.xla_client.Shape
、jax.lib.xla_client.XlaBuilder
和jax.lib.xla_client.XlaComputation
這些半公開 API 已棄用。請改用 StableHLO。
jax 0.4.34 (2024 年 10 月 4 日)#
新功能
此版本包含適用於 Python 3.13 的 wheel 檔案。目前尚不支援自由執行緒模式。
已新增
jax.errors.JaxRuntimeError
作為先前私有的XlaRuntimeError
類型的公開別名。
重大變更
jax_pmap_no_rank_reduction
旗標預設設定為True
。現在對 pmap 結果執行 array[0] 會引入 reshape (請改用 array[0:1])。
每個分片形狀 (可透過 jax_array.addressable_shards 或 jax_array.addressable_data(0) 存取) 現在具有前導 (1, …)。請相應地更新直接存取分片的程式碼。每個分片形狀的秩 (rank) 現在與全域形狀的秩相符,這與 jit 的行為相同。這避免了將結果從 pmap 傳遞到 jit 時產生高成本的 reshape。
自 2024 年 3 月起,JAX 版本 0.4.26 已棄用
jax.experimental.host_callback
。現在我們將--jax_host_callback_legacy
組態值的預設值設定為True
,這表示如果您的程式碼使用jax.experimental.host_callback
API,則這些 API 呼叫將以新的jax.experimental.io_callback
API 實作。如果這破壞了您的程式碼,在非常有限的時間內,您可以將--jax_host_callback_legacy
設定為True
。我們很快就會移除該組態選項,因此您應該改為轉換為使用新的 JAX 回呼 API。請參閱 #20385 以取得討論。
棄用
在
jax.numpy.trim_zeros()
中,非類陣列引數或ndim != 1
的類陣列引數現在已棄用,並且在未來將導致錯誤。在 JAX v0.4.30 中棄用後,已移除內部美化列印工具
jax.core.pp_*
。jax.lib.xla_client.Device
已棄用;請改用jax.Device
。jax.lib.xla_client.XlaRuntimeError
已棄用。請改用jax.errors.JaxRuntimeError
。
刪除
已刪除
jax.xla_computation
。自 0.4.30 JAX 版本中棄用以來已 3 個月。請使用 AOT API 以取得與jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替換為jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您也可以使用
jax.stages.Lowered
的.out_info
屬性來取得輸出資訊 (例如樹狀結構、形狀和 dtype)。對於跨後端降低 (lowering),您可以將
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替換為jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jax.ShapeDtypeStruct
不再接受named_shape
參數。該參數僅由xmap
使用,而xmap
已在 0.4.31 中移除。jax.tree.map(f, None, non-None)
先前發出DeprecationWarning
,現在在未來版本的 jax 中會引發錯誤。None
僅是其自身的樹狀結構前綴 (tree-prefix)。若要保留目前的行為,您可以要求jax.tree.map
將None
視為葉節點值,方法是寫入:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。jax.sharding.XLACompatibleSharding
已移除。請使用jax.sharding.Sharding
。
錯誤修復
修正了如果提供非布林值輸入並指定
dtype=bool
,jax.numpy.cumsum()
會產生不正確輸出的錯誤。編輯
jax.numpy.ldexp()
的實作以取得正確的梯度。
jax 0.4.33 (2024 年 9 月 16 日)#
這是 jax 0.4.32 之上的修補程式版本,修正了該版本中發現的兩個錯誤。
在 JAX 0.4.32 釘選的 libtpu 版本中發現了一個僅限於 TPU 的資料損壞錯誤,該錯誤僅在同一個作業中存在多個 TPU 切片時才會顯現,例如,如果在多個 v5e 切片上進行訓練。此版本透過釘選 libtpu 的固定版本來修正該問題。
此版本修正了 CPU 上 F64 tanh 的不準確結果 (#23590)。
jax 0.4.32 (2024 年 9 月 11 日)#
注意:此版本已從 PyPi 撤下,因為 TPU 上存在資料損壞錯誤。請參閱 0.4.33 版本注意事項以取得更多詳細資訊。
新功能
新增了
jax.extend.ffi.ffi_call()
和jax.extend.ffi.ffi_lowering()
,以支援使用新的 Foreign function interface (FFI) 從 JAX 與自訂 C++ 和 CUDA 程式碼介接。
變更
jax_enable_memories
旗標預設設定為True
。jax.numpy
現在支援 Python Array API Standard 的 v2023.12 版本。請參閱 Python Array API 標準 以取得更多資訊。在更多情況下,CPU 後端上的計算現在可以非同步分派。先前,非平行計算始終同步分派。您可以透過設定
jax.config.update('jax_cpu_enable_async_dispatch', False)
來恢復舊的行為。新增了新的
jax.process_indices()
函式,以取代在 JAX v0.2.13 中棄用的jax.host_ids()
函式。為了與
numpy.fabs
的行為對齊,jax.numpy.fabs
已修改為不再支援complex dtypes
。如果
nodetype
是 dataclass,則jax.tree_util.register_dataclass
現在會檢查data_fields
和meta_fields
是否包含所有具有init=True
的 dataclass 欄位,並且僅包含這些欄位。多個
jax.numpy
函式現在具有完整的ufunc
介面,包括add
、multiply
、bitwise_and
、bitwise_or
、bitwise_xor
、logical_and
、logical_and
和logical_and
。在未來的版本中,我們計劃將這些擴展到其他 ufunc。新增了
jax.lax.optimization_barrier()
,允許使用者防止編譯器最佳化 (例如常見子表達式消除) 並控制排程。
重大變更
MHLO MLIR 方言 (
jax.extend.mlir.mhlo
) 已移除。請改用stablehlo
方言。
棄用
在 JAX v0.4.27 中棄用後,不再允許
jax.numpy.clip()
和jax.numpy.hypot()
的複數輸入。已棄用以下 API
jax.lib.xla_bridge.xla_client
:請直接使用jax.lib.xla_client
。jax.lib.xla_bridge.get_backend
:請使用jax.extend.backend.get_backend()
。jax.lib.xla_bridge.default_backend
:請使用jax.extend.backend.default_backend()
。
jax.experimental.array_api
模組已棄用,不再需要匯入它才能使用 Array API。jax.numpy
直接支援 array API;請參閱 Python Array API 標準 以取得更多資訊。內部工具程式
jax.core.check_eqn
、jax.core.check_type
和jax.core.check_valid_jaxtype
現在已棄用,並將在未來版本中移除。在 NumPy 2.0 中移除對應的 API 後,
jax.numpy.round_
已棄用。請改用jax.numpy.round()
。將 DLPack capsule 傳遞給
jax.dlpack.from_dlpack()
已棄用。jax.dlpack.from_dlpack()
的引數應該是來自另一個實作__dlpack__
協定的框架的陣列。
jaxlib 0.4.32 (2024 年 9 月 11 日)#
注意:此版本已從 PyPi 撤下,因為 TPU 上存在資料損壞錯誤。請參閱 0.4.33 版本注意事項以取得更多詳細資訊。
重大變更
此版本的 jaxlib 切換到新的 CPU 後端版本,該版本應能更快地編譯並更好地利用平行處理。如果您因本次變更而遇到任何問題,可以暫時啟用舊的 CPU 後端,方法是設定環境變數
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
。如果您需要執行此操作,請提交 JAX 錯誤報告,並附上重現步驟。新增了 Hermetic CUDA 支援。Hermetic CUDA 使用特定的可下載 CUDA 版本,而不是使用者本機安裝的 CUDA。 Bazel 將下載 CUDA、CUDNN 和 NCCL 發行版本,然後在各種 Bazel 目標中使用 CUDA 函式庫和工具作為依賴項。這為 JAX 及其支援的 CUDA 版本實現了更可重現的建置。
變更
新增了 SparseCore 分析功能。
JAX 現在支援在 TPUv5p 晶片上分析 SparseCore。這些追蹤將可在 Tensorboard Profiler 的 TraceViewer 中檢視。
jax 0.4.31 (2024 年 7 月 29 日)#
刪除
xmap 已刪除。請使用
shard_map()
作為替代方案。
變更
最低 CuDNN 版本為 v9.1。這在先前的版本中也是如此,但我們現在正式宣告此版本約束。
最低 Python 版本現在為 3.10。3.10 將保持最低支援版本直到 2025 年 7 月。
最低 NumPy 版本現在為 1.24。NumPy 1.24 將保持最低支援版本直到 2024 年 12 月。
最低 SciPy 版本現在為 1.10。SciPy 1.10 將保持最低支援版本直到 2025 年 1 月。
現在
jax.numpy.ceil()
、jax.numpy.floor()
和jax.numpy.trunc()
會回傳與輸入相同資料類型 (dtype) 的輸出,也就是說,不再將整數或布林值輸入向上轉型 (upcast) 為浮點數。libdevice.10.bc
不再與 CUDA wheels 捆綁。它必須作為本地 CUDA 安裝的一部分安裝,或透過 NVIDIA 的 CUDA pip wheels 安裝。現在
jax.experimental.pallas.BlockSpec
預期block_shape
在index_map
之前 傳遞。舊的參數順序已被棄用,並將在未來版本中移除。更新了 GPU 裝置的 repr 表示方式,使其與 TPU/CPU 更一致。例如,
cuda(id=0)
現在會顯示為CudaDevice(id=0)
。為
jax.Array
新增了device
屬性和to_device
方法,作為 JAX Array API 支援的一部分。
棄用
移除了許多先前已棄用的與多型形狀 (polymorphic shapes) 相關的內部 API。從
jax.core
中:移除了canonicalize_shape
、dimension_as_value
、definitely_equal
和symbolic_equal_dim
。HLO lowering 規則不應再將 singleton ir.Values 包裹在元組中。而是回傳未包裹的 singleton ir.Values。未來版本的 JAX 將移除對包裹值的支援。
使用
native_serialization=False
或enable_xla=False
的jax.experimental.jax2tf.convert()
現在已被棄用,此支援將在未來版本中移除。自 JAX 0.4.16 (2023 年 9 月) 以來,原生序列化一直是預設設定。先前已棄用的函式
jax.random.shuffle
已被移除;請改用jax.random.permutation
並搭配independent=True
。
jaxlib 0.4.31 (2024 年 7 月 29 日)#
錯誤修復
修正了一個錯誤,該錯誤導致 jit 調度快速路徑錯誤處理了 jit 的負 static_argnums。
修正了一個錯誤,該錯誤導致奇異矩陣批次的三角解產生無意義的有限值,而不是 inf 或 nan (#3589, #15429)。
jax 0.4.30 (2024 年 6 月 18 日)#
變更
JAX 支援 ml_dtypes >= 0.2。在 0.4.29 版本中,ml_dtypes 版本已升級至 0.4.0,但在此版本中已回退,以便讓 TensorFlow 和 JAX 的使用者有更多時間遷移到較新的 TensorFlow 版本。
jax.experimental.mesh_utils
現在可以為 TPU v5e 建立有效率的網格 (mesh)。jax 現在直接依賴 jaxlib。此變更由 CUDA 外掛程式切換啟用:不再有多個 jaxlib 變體。您可以使用
pip install jax
安裝僅限 CPU 的 jax,無需額外套件。新增了用於匯出和序列化 JAX 函式的 API。此功能以前存在於
jax.experimental.export
(正在棄用),現在將位於jax.export
中。請參閱文件。
棄用
內部美化列印 (pretty-printing) 工具
jax.core.pp_*
已被棄用,並將在未來版本中移除。tracer 的雜湊 (Hashing) 已被棄用,並將在未來 JAX 版本中導致
TypeError
。先前情況如此,但在最近幾個 JAX 版本中出現了無意的回歸。jax.experimental.export
已被棄用。請改用jax.export
。請參閱遷移指南。在大多數情況下,現在不建議使用陣列來代替 dtype 傳遞;例如,對於陣列
x
和y
,x.astype(y)
將引發警告。若要靜音警告,請使用x.astype(y.dtype)
。jax.xla_computation
已被棄用,並將在未來版本中移除。請使用 AOT API 來取得與jax.xla_computation
相同的功能。jax.xla_computation(fn)(*args, **kwargs)
可以替換為jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')
。您也可以使用
jax.stages.Lowered
的.out_info
屬性來取得輸出資訊 (例如樹狀結構、形狀和 dtype)。對於跨後端降低 (lowering),您可以將
jax.xla_computation(fn, backend='tpu')(*args, **kwargs)
替換為jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
。
jaxlib 0.4.30 (2024 年 6 月 18 日)#
已移除對 monolithic CUDA jaxlibs 的支援。您必須使用基於外掛程式的安裝 (
pip install jax[cuda12]
或pip install jax[cuda12_local]
)。
jax 0.4.29 (2024 年 6 月 10 日)#
變更
我們預期這將是 JAX 和 jaxlib 最後一個支援 monolithic CUDA jaxlib 的版本。未來版本將使用 CUDA 外掛程式 jaxlib (例如
pip install jax[cuda12]
)。JAX 現在需要 ml_dtypes 版本 0.4.0 或更新版本。
移除了對舊版
jax.experimental.export
API 用法的向後相容性支援。不再可能使用from jax.experimental.export import export
,您應該改用from jax.experimental import export
。移除的功能自 0.4.24 以來已被棄用。為
jax.tree.all()
和jax.tree_util.tree_all()
新增了is_leaf
參數。
棄用
jax.sharding.XLACompatibleSharding
已被棄用。請使用jax.sharding.Sharding
。jax.experimental.Exported.in_shardings
已重新命名為jax.experimental.Exported.in_shardings_hlo
。out_shardings
也做了相同的更名。舊名稱將在 3 個月後移除。移除了許多先前已棄用的 API
jax.numpy.linalg.matrix_rank()
的tol
參數已被棄用,並將很快移除。請改用rtol
。jax.numpy.linalg.pinv()
的rcond
參數已被棄用,並將很快移除。請改用rtol
。已移除已棄用的
jax.config
子模組。若要設定 JAX,請使用import jax
,然後透過jax.config
參考 config 物件。jax.random
API 不再接受批次處理的金鑰 (keys),儘管先前有些 API 無意中接受了。展望未來,我們建議在這種情況下明確使用jax.vmap()
。在
jax.scipy.special.beta()
中,為了與其他beta
API 一致,x
和y
參數已重新命名為a
和b
。
新功能
新增了
jax.experimental.Exported.in_shardings_jax()
,以從儲存在Exported
物件中的 HloShardings 建構可用於 JAX API 的分片 (shardings)。
jaxlib 0.4.29 (2024 年 6 月 10 日)#
錯誤修復
修正了一個錯誤,該錯誤導致 XLA 錯誤地分片了一些串聯 (concatenation) 操作,這表現為累積歸約 (cumulative reductions) 的輸出不正確 (#21403)。
修正了一個錯誤,該錯誤導致 XLA:CPU 錯誤編譯了某些矩陣乘法融合 (matmul fusions) (https://github.com/openxla/xla/pull/13301)。
修正了 GPU 上的編譯器崩潰問題 (https://github.com/jax-ml/jax/issues/21396)。
棄用
jax.tree.map(f, None, non-None)
現在會發出DeprecationWarning
,並將在未來版本的 jax 中引發錯誤。None
僅是其自身的樹狀結構前綴 (tree-prefix)。若要保留目前行為,您可以要求jax.tree.map
將None
視為葉值,方法是寫入:jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)
。
jax 0.4.28 (2024 年 5 月 9 日)#
錯誤修復
還原了對
make_jaxpr
的變更,該變更破壞了 Equinox (#21116)。
棄用與移除
現在已移除
jax.numpy.sort()
和jax.numpy.argsort()
的kind
參數。請改用stable=True
或stable=False
。從
jax.experimental.pallas.gpu
模組中移除了get_compute_capability
。請改用 GPU 裝置的compute_capability
屬性,該屬性由jax.devices()
或jax.local_devices()
回傳。正在棄用
jax.numpy.reshape()
的newshape
參數,並將很快移除。請改用shape
。
變更
此版本的最低 jaxlib 版本為 0.4.27。
jaxlib 0.4.28 (2024 年 5 月 9 日)#
錯誤修復
修正了 Python 3.10 或更早版本中 Array 和 JIT Python 物件的類型名稱中的記憶體損壞錯誤。
修正了 CUDA 12.4 下的警告
'+ptx84' is not a recognized feature for this target
。修正了 CPU 上編譯速度緩慢的問題。
變更
Windows 版本現在使用 Clang 而非 MSVC 建置。
jax 0.4.27 (2024 年 5 月 7 日)#
新功能
新增了
jax.numpy.unstack()
和jax.numpy.cumulative_sum()
,遵循它們在 array API 2023 標準中的新增,該標準即將被 NumPy 採用。新增了一個新的 config 選項
jax_cpu_collectives_implementation
,以選擇 CPU 後端使用的跨進程集體運算 (cross-process collective operations) 的實作方式。可用的選項包括'none'
(預設)、'gloo'
和'mpi'
(需要 jaxlib 0.4.26)。如果設定為'none'
,則會停用跨進程集體運算。
變更
現在
jax.pure_callback()
、jax.experimental.io_callback()
和jax.debug.callback()
使用jax.Array
而非np.ndarray
。您可以透過在將引數傳遞給 callback 之前,透過jax.tree.map(np.asarray, args)
轉換引數來恢復舊行為。現在
complex_arr.astype(bool)
遵循與 NumPy 相同的語義,當complex_arr
等於0 + 0j
時回傳 False,否則回傳 True。現在
core.Token
是一個非平凡類別,它包裹了jax.Array
。它可以被建立並在計算中傳入和傳出,以建立依賴關係。singleton 物件core.token
已被移除,使用者現在應該建立並使用新的core.Token
物件。在 GPU 上,Threefry PRNG 實作預設不再降低到核心呼叫 (kernel call)。此選擇可以改善執行時記憶體使用量,但會犧牲編譯時間成本。可以使用
jax.config.update('jax_threefry_gpu_kernel_lowering', True)
恢復產生核心呼叫的先前行為。如果新的預設設定導致問題,請提交錯誤報告。否則,我們打算在未來版本中移除此旗標。
棄用與移除
Pallas 現在專門使用 XLA 在 GPU 上編譯核心。透過 Triton Python API 的舊 lowering pass 已被移除,
JAX_TRITON_COMPILE_VIA_XLA
環境變數不再有任何作用。jax.numpy.clip()
有一個新的引數簽章:a
、a_min
和a_max
已被棄用,改用x
(僅限位置引數)、min
和max
(#20550)。JAX 陣列的
device()
方法已移除,自 JAX v0.4.21 以來已被棄用。請改用arr.devices()
。已棄用
jax.nn.softmax()
和jax.nn.log_softmax()
的initial
參數;現在支援 softmax 的空輸入,而無需設定此參數。在
jax.jit()
中,傳遞無效的static_argnums
或static_argnames
現在會導致錯誤,而不是警告。現在最低 jaxlib 版本為 0.4.23。
當傳遞複數值輸入給
jax.numpy.hypot()
函式時,現在會發出棄用警告。當棄用完成時,這將引發錯誤。現在
jax.numpy.nonzero()
、jax.numpy.where()
和相關函式的純量引數會引發錯誤,這與 NumPy 中的類似變更一致。config 選項
jax_cpu_enable_gloo_collectives
已被棄用。請改用jax.config.update('jax_cpu_collectives_implementation', 'gloo')
。在 JAX v0.4.22 中棄用後,已移除
jax.Array.device_buffer
和jax.Array.device_buffers
方法。請改用jax.Array.addressable_shards
和jax.Array.addressable_data()
。現在
jax.numpy.where
的condition
、x
和y
參數僅限於位置引數,這與 JAX v0.4.21 中棄用關鍵字參數一致。現在必須透過關鍵字指定
jax.lax.linalg
中函式的非陣列引數。先前,這會引發 DeprecationWarning。現在在幾個 :func:
jax.numpy
API 中需要類陣列引數,包括apply_along_axis()
、apply_over_axes()
、inner()
、outer()
、cross()
、kron()
和lexsort()
。
錯誤修復
當
copy=True
時,jax.numpy.astype()
現在始終會回傳副本。先前,當輸出陣列與輸入陣列具有相同的 dtype 時,不會建立副本。這可能會導致一些記憶體使用量增加。預設值設定為copy=False
,以保留向後相容性。
jaxlib 0.4.27 (2024 年 5 月 7 日)#
jax 0.4.26 (2024 年 4 月 3 日)#
新功能
新增了
jax.numpy.trapezoid()
,遵循 NumPy 2.0 中新增的此函式。
變更
現在複數值
jax.numpy.geomspace()
選擇與 NumPy 2.0 一致的對數螺旋分支 (logarithmic spiral branch)。lax.rng_bit_generator
的行為,進而'rbg'
和'unsafe_rbg'
PRNG 實作在jax.vmap
下的行為 已變更,因此對金鑰進行映射 (mapping) 只會從批次中的第一個金鑰產生隨機數。文件現在使用
jax.random.key
來建構 PRNG 金鑰陣列,而不是jax.random.PRNGKey
。
棄用與移除
jax.tree_map()
已被棄用;請改用jax.tree.map
,或為了與舊版 JAX 向後相容,請使用jax.tree_util.tree_map()
。jax.clear_backends()
已被棄用,因為它不一定會執行其名稱所暗示的操作,並且可能會導致意想不到的後果,例如,它不會銷毀現有的後端並釋放相應的擁有資源。如果您只想清理編譯快取,請使用jax.clear_caches()
。為了向後相容性,或者如果您真的需要切換/重新初始化預設後端,請使用jax.extend.backend.clear_backends()
。jax.experimental.maps
模組和jax.experimental.maps.xmap
已被棄用。請使用jax.experimental.shard_map
或jax.vmap
以及spmd_axis_name
引數來表達 SPMD 裝置平行計算。jax.experimental.host_callback
模組已被棄用。請改用新的 JAX 外部 callback。新增了JAX_HOST_CALLBACK_LEGACY
旗標,以協助轉換到新的 callback。請參閱 #20385 以進行討論。現在將無法轉換為 JAX 陣列的引數傳遞給
jax.numpy.array_equal()
和jax.numpy.array_equiv()
會導致例外。已移除已棄用的旗標
jax_parallel_functions_output_gda
。此旗標早已被棄用,且沒有任何作用;使用它是空操作 (no-op)。先前已棄用的匯入
jax.interpreters.ad.config
和jax.interpreters.ad.source_info_util
現在已被移除。請改用jax.config
和jax.extend.source_info_util
。JAX 匯出不再支援舊版的序列化版本。自 2023 年 10 月 27 日起已支援版本 9,並自 2024 年 2 月 1 日起成為預設版本。請參閱 版本說明。此變更可能會破壞設定低於 9 的特定 JAX 序列化版本的用戶端。
jaxlib 0.4.26 (2024 年 4 月 3 日)#
變更
JAX 現在僅支援 CUDA 12.1 或更新版本。已停止支援 CUDA 11.8。
JAX 現在支援 NumPy 2.0。
jax 0.4.25 (2024 年 2 月 26 日)#
新功能
新增了 CUDA Array Interface 匯入支援 (需要 jaxlib 0.4.24)。
JAX 陣列現在支援 NumPy 風格的純量布林索引,例如
x[True]
或x[False]
。新增了
jax.tree
模組,為參考jax.tree_util
中的函式提供了更方便的介面。jax.tree.transpose()
(即jax.tree_util.tree_transpose()
) 現在接受inner_treedef=None
,在這種情況下,將自動推斷內部 treedef。
變更
Pallas 現在使用 XLA 而非 Triton Python API 來編譯 Triton 核心。您可以將
JAX_TRITON_COMPILE_VIA_XLA
環境變數設定為"0"
,以還原為舊的行為。在 v0.4.24 版本中移除的
jax.interpreters.xla
中數個已棄用的 API,已在 v0.4.25 版本中重新加入,包括backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
和XLAOp
。這些 API 仍被視為已棄用,並將在未來有更好的替代方案時再次移除。請參閱 #19816 以了解更多討論。
棄用與移除
jax.numpy.linalg.solve()
現在針對b.ndim > 1
的批次 1D 求解顯示棄用警告。未來,這些將被視為批次 2D 求解。將非純量陣列轉換為 Python 純量現在會引發錯誤,無論陣列大小為何。先前,對於大小為 1 的非純量陣列,會引發棄用警告。這遵循 NumPy 中類似的棄用。
先前已棄用的組態 API 已按照標準的 3 個月棄用週期移除(請參閱 API 相容性)。這些包括
jax.config.config
物件,以及jax.config
的define_*_state
和DEFINE_*
方法。
透過
import jax.config
匯入jax.config
子模組已被棄用。若要設定 JAX,請使用import jax
,然後透過jax.config
參考組態物件。最低 jaxlib 版本現在為 0.4.20。
jaxlib 0.4.25 (2024 年 2 月 26 日)#
jax 0.4.24 (2024 年 2 月 6 日)#
變更
JAX 降級至 StableHLO 不再依賴實體裝置。如果您的基本運算包裝了自訂分割 (custom_partitioning) 或 JAX 回呼 (callbacks) 在降級規則中,即傳遞給
mlir.register_lowering
的rule
參數的函式,則將您的基本運算新增至jax._src.dispatch.prim_requires_devices_during_lowering
集合。這是必要的,因為自訂分割和 JAX 回呼需要實體裝置才能在降級期間建立Sharding
。這是一個暫時狀態,直到我們可以在沒有實體裝置的情況下建立Sharding
。jax.numpy.argsort()
和jax.numpy.sort()
現在支援stable
和descending
引數。形狀多型 (shape polymorphism) 處理的一些變更(用於
jax.experimental.jax2tf
和jax.experimental.export
中)更清晰的符號運算式美觀列印 (#19227)
新增了在維度變數上指定符號約束的功能。這使得形狀多型更具表現力,並提供了一種方法來解決關於不等式推理的限制。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。
隨著符號約束的加入 (#19235),我們現在認為來自不同範圍的維度變數是不同的,即使它們具有相同的名稱。來自不同範圍的符號運算式無法互動,例如在算術運算中。
jax.experimental.jax2tf.convert()
、jax.experimental.export.symbolic_shape()
、jax.experimental.export.symbolic_args_specs()
引入了範圍。符號運算式e
的範圍可以使用e.scope
讀取,並傳遞到上述函式中,以指示它們在給定範圍內建構符號運算式。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints。簡化且更快速的相等性比較,當我們考慮兩個符號維度的差的正規化形式簡化為 0 時,則認為它們相等 (#19231;請注意,這可能會導致使用者可見的行為變更)
改進了不確定不等式比較的錯誤訊息 (#19235)。
已棄用
core.non_negative_dim
API (最近引入),並引入了core.max_dim
和core.min_dim
(#18953) 以表示符號維度的max
和min
。您可以使用core.max_dim(d, 0)
來代替core.non_negative_dim(d)
。已棄用
shape_poly.is_poly_dim
,改用export.is_symbolic_dim
(#19282)。已棄用
export.args_specs
,改用export.symbolic_args_specs ({jax-issue}
#19283`)。已棄用
shape_poly.PolyShape
和jax2tf.PolyShape
,針對多型形狀規格使用字串 (#19284)。JAX 預設原生序列化版本現在為 9。這與
jax.experimental.jax2tf
和jax.experimental.export
相關。請參閱 版本號碼描述。
重構了
jax.experimental.export
的 API。現在您應該使用from jax.experimental import export
,而不是from jax.experimental.export import export
。舊的匯入方式在 3 個月的棄用期內將繼續運作。具有
return_inverse = True
的jax.numpy.unique()
會傳回重新塑形為輸入維度的反向索引,這遵循 NumPy 2.0 中numpy.unique()
的類似變更。jax.numpy.sign()
現在針對非零複數輸入傳回x / abs(x)
。這與 NumPy 2.0 版本中numpy.sign()
的行為一致。具有
return_sign=True
的jax.scipy.special.logsumexp()
現在針對複數符號使用 NumPy 2.0 慣例x / abs(x)
。這與 SciPy v1.13 中scipy.special.logsumexp()
的行為一致。JAX 現在支援匯入和匯出布林值 DLPack 類型。先前,布林值無法匯入,且會匯出為整數。
棄用與移除
許多先前已棄用的函式已按照標準的 3+ 個月棄用週期移除(請參閱 API 相容性)。這包括
來自
jax.core
的:TracerArrayConversionError
、TracerIntegerConversionError
、UnexpectedTracerError
、as_hashable_function
、collections
、dtypes
、lu
、map
、namedtuple
、partial
、pp
、ref
、safe_zip
、safe_map
、source_info_util
、total_ordering
、traceback_util
、tuple_delete
、tuple_insert
和zip
。來自
jax.lax
的:dtypes
、itertools
、naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
。jax.linear_util
子模組及其所有內容。jax.prng
子模組及其所有內容。來自
jax.random
的:PRNGKeyArray
、KeyArray
、default_prng_impl
、threefry_2x32
、threefry2x32_key
、threefry2x32_p
、rbg_key
和unsafe_rbg_key
。來自
jax.tree_util
的:register_keypaths
、AttributeKeyPathEntry
和GetItemKeyPathEntry
。來自
jax.interpreters.xla
的:backend_specific_translations
、translations
、register_translation
、xla_destructure
、TranslationRule
、TranslationContext
、axis_groups
、ShapedArray
、ConcreteArray
、AxisEnv
、backend_compile
和XLAOp
。來自
jax.numpy
的:NINF
、NZERO
、PZERO
、row_stack
、issubsctype
、trapz
和in1d
。來自
jax.scipy.linalg
的:tril
和triu
。
先前已棄用的方法
PRNGKeyArray.unsafe_raw_array
已移除。改為使用jax.random.key_data()
。bool(empty_array)
現在會引發錯誤,而不是傳回False
。這先前會引發棄用警告,並且遵循 NumPy 中類似的變更。已棄用對 mhlo MLIR 方言的支援。JAX 不再使用 mhlo 方言,而是改用 stablehlo。未來將移除引用 “mhlo” 的 API。請改用 “stablehlo” 方言。
jax.random
:直接將批次金鑰傳遞給隨機數產生函式,例如bits()
、gamma()
等等,已被棄用,並將發出FutureWarning
。針對明確的批次處理,請使用jax.vmap
。jax.lax.tie_in()
已棄用:自 JAX v0.2.0 以來,它已不再執行任何操作。
jaxlib 0.4.24 (2024 年 2 月 6 日)#
變更
JAX 現在支援 CUDA 12.3 和 CUDA 11.8。已移除對 CUDA 12.2 的支援。
cost_analysis
現在可與跨編譯的Compiled
物件搭配使用(即當使用具有拓撲物件的.lower().compile()
時,例如從非 TPU 電腦編譯以用於 Cloud TPU)。新增 CUDA 陣列介面 匯入支援 (需要 jax 0.4.25)。
jax 0.4.23 (2023 年 12 月 13 日)#
jaxlib 0.4.23 (2023 年 12 月 13 日)#
修正了在編譯期間導致 GPU 編譯器產生詳細記錄的錯誤。
jax 0.4.22 (2023 年 12 月 13 日)#
棄用
JAX 陣列的
device_buffer
和device_buffers
屬性已棄用。明確的緩衝區已被更彈性的陣列分片介面取代,但之前的輸出可以透過以下方式恢復arr.device_buffer
變為arr.addressable_data(0)
arr.device_buffers
變為[x.data for x in arr.addressable_shards]
jaxlib 0.4.22 (2023 年 12 月 13 日)#
jax 0.4.21 (2023 年 12 月 4 日)#
新功能
新增了
jax.nn.squareplus
。
變更
最低 jaxlib 版本現在為 0.4.19。
發布的 wheel 現在使用 clang 而非 gcc 建置。
強制在呼叫
jax.distributed.initialize()
之前,裝置後端尚未初始化。在 Cloud TPU 環境中自動化
jax.distributed.initialize()
的引數。
棄用
先前已棄用的
sym_pos
引數已從jax.scipy.linalg.solve()
中移除。改為使用assume_a='pos'
。將
None
傳遞給jax.array()
或jax.asarray()
,無論是直接傳遞還是清單或元組中傳遞,都已被棄用,現在會引發FutureWarning
。目前它會轉換為 NaN,未來將引發TypeError
。以關鍵字引數方式將
condition
、x
和y
參數傳遞給jax.numpy.where
已被棄用,以符合numpy.where
。將無法轉換為 JAX 陣列的引數傳遞給
jax.numpy.array_equal()
和jax.numpy.array_equiv()
已被棄用,現在會引發DeprecationWaning
。目前,函式傳回 False,未來將引發例外狀況。JAX 陣列的
device()
方法已棄用。根據上下文,它可能會被以下其中之一取代jax.Array.devices()
傳回陣列使用的所有裝置的集合。jax.Array.sharding
提供陣列使用的分片組態。
jaxlib 0.4.21 (2023 年 12 月 4 日)#
變更
為了準備新增分散式 CPU 支援,JAX 現在將 CPU 裝置視為與 GPU 和 TPU 裝置相同,也就是說
jax.devices()
包括分散式作業中的所有裝置,即使是那些非本機程序的裝置。jax.local_devices()
仍然只包括本機程序的裝置,因此如果對jax.devices()
的變更中斷了您的程式碼,您很可能想要改用jax.local_devices()
。CPU 裝置現在在分散式作業中收到全域唯一 ID 號碼;先前 CPU 裝置會收到程序本機 ID 號碼。
每個 CPU 裝置的
process_index
現在將與同一程序中的任何 GPU 或 TPU 裝置相符;先前 CPU 裝置的process_index
始終為 0。
在 NVIDIA GPU 上,JAX 現在針對最大 1024x1024 的矩陣優先使用 Jacobi SVD 求解器。Jacobi 求解器似乎比非 Jacobi 版本更快。
錯誤修復
修復了當具有非有限值的陣列傳遞給非對稱特徵值分解 (#18226) 時的錯誤/掛起。具有非有限值的陣列現在會產生充滿 NaN 的陣列作為輸出。
jax 0.4.20 (2023 年 11 月 2 日)#
jaxlib 0.4.20 (2023 年 11 月 2 日)#
錯誤修復
修復了 E4M3 和 E5M2 float8 類型之間的一些類型混淆。
jax 0.4.19 (2023 年 10 月 19 日)#
新功能
新增了
jax.typing.DTypeLike
,可用於註解可轉換為 JAX dtype 的物件。新增了
jax.numpy.fill_diagonal
。
變更
JAX 現在需要 SciPy 1.9 或更新版本。
錯誤修復
在多控制器分散式 JAX 程式中,只有程序 0 會寫入持久編譯快取項目。這修正了當快取放置在網路檔案系統 (例如 GCS) 上時的寫入競爭。
cusolver 和 cufft 的版本檢查不再考慮修補程式版本,以判斷已安裝的這些程式庫版本是否至少與建置 JAX 時使用的版本一樣新。
jaxlib 0.4.19 (2023 年 10 月 19 日)#
變更
jaxlib 現在將始終優先選擇 pip 安裝的 NVIDIA CUDA 程式庫 (nvidia-… 套件) 而不是任何其他 CUDA 安裝(如果已安裝),包括在
LD_LIBRARY_PATH
中命名的安裝。如果這導致問題且意圖是使用系統安裝的 CUDA,則修復方法是移除 pip 安裝的 CUDA 程式庫套件。
jax 0.4.18 (2023 年 10 月 6 日)#
jaxlib 0.4.18 (2023 年 10 月 6 日)#
變更
CUDA jaxlib 現在依賴使用者安裝相容的 NCCL 版本。如果使用建議的
cuda12_pip
安裝,則應自動安裝 NCCL。目前,需要 NCCL 2.16 或更新版本。我們現在提供 Linux aarch64 wheel,包含和不包含 NVIDIA GPU 支援。
jax.Array.item()
現在支援選用的索引引數。
棄用
已棄用
jax.lax
中的許多內部公用程式和意外匯出,並將在未來版本中移除。jax.lax.dtypes
:改為使用jax.dtypes
。jax.lax.itertools
:改為使用itertools
。naryop
、naryop_dtype_rule
、standard_abstract_eval
、standard_naryop
、standard_primitive
、standard_unop
、unop
和unop_dtype_rule
是內部公用程式,現在已棄用,沒有替代方案。
錯誤修復
修復了 Cloud TPU 迴歸,其中編譯會因 smem 而導致 OOM。
jax 0.4.17 (2023 年 10 月 3 日)#
新功能
新增了
jax.numpy.bitwise_count()
函式,與最近新增至 NumPy 的類似函式的 API 相符。
棄用
移除了已棄用的模組
jax.abstract_arrays
及其所有內容。jax.random
中的具名金鑰建構函式已棄用。改為將impl
引數傳遞給jax.random.PRNGKey()
或jax.random.key()
random.threefry2x32_key(seed)
變為random.PRNGKey(seed, impl='threefry2x32')
random.rbg_key(seed)
變為random.PRNGKey(seed, impl='rbg')
random.unsafe_rbg_key(seed)
變為random.PRNGKey(seed, impl='unsafe_rbg')
變更
CUDA:JAX 現在驗證其找到的 CUDA 程式庫是否至少與建置 JAX 時使用的 CUDA 程式庫一樣新。如果找到較舊的程式庫,JAX 會引發例外狀況,因為這比神秘的失敗和崩潰更可取。
移除了「找不到 GPU/TPU」警告。而是改為在 Linux 上,如果找到 NVIDIA GPU 或 Google TPU 但未使用,且未指定
--jax_platforms
時發出警告。jax.scipy.stats.mode()
現在,如果跨大小為 0 的軸取得模式,則會傳回 0 計數,這與 SciPy 1.11 中scipy.stats.mode
的行為相符。大多數
jax.numpy
函式和屬性現在都具有完整定義的類型存根。先前,這些函式和屬性中的許多都被靜態類型檢查器 (例如mypy
和pytype
) 視為Any
。
jaxlib 0.4.17 (Oct 3, 2023)#
變更
此版本新增了 Python 3.12 wheel。
CUDA 12 wheel 現在需要 CUDA 12.2 或更新版本,以及 cuDNN 8.9.4 或更新版本。
錯誤修復
修正了初始化 JAX CPU 後端時,ABSL 產生的過多日誌訊息。
jax 0.4.16 (Sept 18, 2023)#
變更
新增了
jax.numpy.ufunc
,以及jax.numpy.frompyfunc()
,它可以將任何純量值函數轉換為類似numpy.ufunc()
的物件,並具有outer()
、reduce()
、accumulate()
、at()
和reduceat()
等方法 (#17054)。當不在 IPython 下執行時:當引發例外時,JAX 現在會從回溯中過濾掉其所有內部框架。(不包含先前出現的「未過濾堆疊追蹤」。)這應該會產生更友善的回溯外觀。請參閱此處的範例。此行為可以透過設定
JAX_TRACEBACK_FILTERING=remove_frames
(用於兩個獨立的未過濾/已過濾回溯,這是舊行為)或JAX_TRACEBACK_FILTERING=off
(用於一個未過濾回溯)來變更。jax2tf 預設序列化版本現在為 7,它引入了新的形狀安全斷言。
傳遞至
jax.sharding.Mesh
的裝置應該是可雜湊的。這特別適用於模擬裝置或使用者建立的裝置。jax.devices()
已經是可雜湊的。
重大變更
jax2tf 現在預設使用原生序列化。請參閱 jax2tf 文件以瞭解詳細資訊和覆寫預設機制的資訊。
選項
--jax_coordination_service
已移除。現在始終為True
。jax.jaxpr_util
已從公開的 JAX 命名空間中移除。JAX_USE_PJRT_C_API_ON_TPU
不再起作用(即,始終預設為 true)。2021 年 12 月引入的回溯相容性旗標
--jax_host_callback_ad_transforms
已移除。
棄用
依照 NumPy NEP-52,已棄用數個
jax.numpy
API。jax.numpy.NINF
已棄用。請改用-jax.numpy.inf
。jax.numpy.PZERO
已棄用。請改用0.0
。jax.numpy.NZERO
已棄用。請改用-0.0
。jax.numpy.issubsctype(x, t)
已棄用。請改用jax.numpy.issubdtype(x.dtype, t)
。jax.numpy.row_stack
已棄用。請改用jax.numpy.vstack
。jax.numpy.in1d
已棄用。請改用jax.numpy.isin
。jax.numpy.trapz
已棄用。請改用jax.scipy.integrate.trapezoid
。
jax.scipy.linalg.tril
和jax.scipy.linalg.triu
已依照 SciPy 棄用。請改用jax.numpy.tril
和jax.numpy.triu
。jax.lax.prod
在 JAX v0.4.11 中棄用後已移除。請改用內建的math.prod
。與定義自訂 JAX 原始運算的 HLO 降低規則相關的
jax.interpreters.xla
中的一些匯出已棄用。自訂原始運算應改為使用jax.interpreters.mlir
中的 StableHLO 降低公用程式來定義。以下先前已棄用的函式在三個月的棄用期後已移除
jax.abstract_arrays.ShapedArray
:請使用jax.core.ShapedArray
。jax.abstract_arrays.raise_to_shaped
:請使用jax.core.raise_to_shaped
。jax.numpy.alltrue
:請使用jax.numpy.all
。jax.numpy.sometrue
:請使用jax.numpy.any
。jax.numpy.product
:請使用jax.numpy.prod
。jax.numpy.cumproduct
:請使用jax.numpy.cumprod
。
已棄用/移除
內部子模組
jax.prng
現在已棄用。其內容可在jax.extend.random
中找到。內部子模組路徑
jax.linear_util
已棄用。請改用jax.extend.linear_util
(jax.extend:擴充模組的一部分)jax.random.PRNGKeyArray
和jax.random.KeyArray
已棄用。請使用jax.Array
進行類型註釋,並使用jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)
進行類型化 prng 金鑰的執行階段偵測。方法
PRNGKeyArray.unsafe_raw_array
已棄用。請改用jax.random.key_data()
。jax.experimental.pjit.with_sharding_constraint
已棄用。請改用jax.lax.with_sharding_constraint
。內部公用程式
jax.core.is_opaque_dtype
和jax.core.has_opaque_dtype
已移除。不透明 dtype 已重新命名為擴充 dtype;請改用jnp.issubdtype(dtype, jax.dtypes.extended)
(自 jax v0.4.14 起可用)。公用程式
jax.interpreters.xla.register_collective_primitive
已移除。此公用程式在最近的 JAX 版本中沒有任何作用,可以安全地移除對它的呼叫。內部子模組路徑
jax.linear_util
已棄用。請改用jax.extend.linear_util
(jax.extend:擴充模組的一部分)
jaxlib 0.4.16 (Sept 18, 2023)#
變更
透過實驗性 jax sparse API 進行的稀疏 CSR 矩陣乘法,在 NVIDIA GPU 上不再使用確定性演算法。進行此變更是為了提高與 CUDA 12.2.1 的相容性。
錯誤修復
修正了 Windows 上由於與錯序區段和 IMAGE_REL_AMD64_ADDR32NB 重定位相關的嚴重 LLVM 錯誤而導致的當機問題 (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4)。
jax 0.4.14 (July 27, 2023)#
變更
jax.jit
接受donate_argnames
作為引數。它的語意與static_argnames
類似。如果未提供 donate_argnums 和 donate_argnames,則不會捐贈任何引數。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 會使用inspect.signature(fun)
來尋找與 donate_argnames 對應的任何位置引數(反之亦然)。如果同時提供了 donate_argnums 和 donate_argnames,則不會使用 inspect.signature,並且只會捐贈 donate_argnums 或 donate_argnames 中列出的實際參數。jax.random.gamma()
已重新設計為更有效率的演算法,具有更穩健的端點行為 (#16779)。這表示對於給定的key
,JAX v0.4.13 和 v0.4.14 之間gamma
和相關取樣器(包括jax.random.ball()
、jax.random.beta()
、jax.random.chisquare()
、jax.random.dirichlet()
、jax.random.generalized_normal()
、jax.random.loggamma()
、jax.random.t()
)傳回的值序列將會變更。
刪除
in_axis_resources
和out_axis_resources
已從 pjit 中刪除,因為它們已棄用超過 3 個月。請使用in_shardings
和out_shardings
作為替代。這是安全且微不足道的名稱替換。它不會變更任何目前的 pjit 語意,也不會破壞任何程式碼。您仍然可以將PartitionSpecs
傳遞至 in_shardings 和 out_shardings。
棄用
已依照 https://jax.dev.org.tw/en/latest/deprecation.html 停止支援 Python 3.8
JAX 現在依照 https://jax.dev.org.tw/en/latest/deprecation.html 需要 NumPy 1.22 或更新版本
不再支援透過位置將選用引數傳遞至
jax.numpy.ndarray.at()
,此功能已在 JAX 版本 0.4.7 中棄用。例如,請使用x.at[i].get(indices_are_sorted=True)
而非x.at[i].get(True)
以下
jax.Array
方法已在 JAX v0.4.5 中棄用後移除jax.Array.broadcast
:請改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:請改用jax.lax.broadcast_in_dim()
。jax.Array.split
:請改用jax.numpy.split()
。
以下 API 在先前棄用後已移除
jax.ad
:請使用jax.interpreters.ad
。jax.curry
:請使用curry = lambda f: partial(partial, f)
。jax.partial_eval
:請使用jax.interpreters.partial_eval
。jax.pxla
:請使用jax.interpreters.pxla
。jax.xla
:請使用jax.interpreters.xla
。jax.ShapedArray
:請使用jax.core.ShapedArray
。jax.interpreters.pxla.device_put
:請使用jax.device_put()
。jax.interpreters.pxla.make_sharded_device_array
:請使用jax.make_array_from_single_device_arrays()
。jax.interpreters.pxla.ShardedDeviceArray
:請使用jax.Array
。jax.numpy.DeviceArray
:請使用jax.Array
。jax.stages.Compiled.compiler_ir
:請使用jax.stages.Compiled.as_text()
。
重大變更
JAX 現在需要 ml_dtypes 版本 0.2.0 或更新版本。
為了修正邊角案例,如果第二個和第三個引數是可呼叫的,即使其他運算元也是可呼叫的,則對具有五個引數的
jax.lax.cond()
的呼叫將始終解析為「通用運算元」cond
行為(如文件中所述)。請參閱 #16413。已移除已棄用的組態選項
jax_array
和jax_jit_pjit_api_merge
,它們沒有任何作用。這些選項在許多版本中預設為 true。
新功能
JAX 現在支援組態旗標 –jax_serialization_version 和 JAX_SERIALIZATION_VERSION 環境變數,以控制序列化版本 (#16746)。
如果序列化版本至少為 7,則 jax2tf 在存在形狀多型性的情況下,現在會產生檢查特定形狀約束的程式碼。請參閱 https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism。
jaxlib 0.4.14 (July 27, 2023)#
棄用
已依照 https://jax.dev.org.tw/en/latest/deprecation.html 停止支援 Python 3.8
jax 0.4.13 (June 22, 2023)#
變更
jax.jit
現在允許將None
傳遞至in_shardings
和out_shardings
。語意如下對於 in_shardings,JAX 會將其標記為已複製,但此行為在未來可能會變更。
對於 out_shardings,我們將依賴 XLA GSPMD 分割器來判斷輸出分片。
jax.experimental.pjit.pjit
也允許將None
傳遞至in_shardings
和out_shardings
。語意如下如果未提供網格上下文管理器,則 JAX 可以自由選擇它想要的任何分片。
對於 in_shardings,JAX 會將其標記為已複製,但此行為在未來可能會變更。
對於 out_shardings,我們將依賴 XLA GSPMD 分割器來判斷輸出分片。
如果提供了網格上下文管理器,則 None 將表示值將在網格的所有裝置上複製。
Executable.cost_analysis() 在 Cloud TPU 上運作
如果使用未列入允許清單的
jaxlib
外掛程式,則新增警告。新增了
jax.tree_util.tree_leaves_with_path
。None
不是jax.experimental.multihost_utils.host_local_array_to_global_array
或jax.experimental.multihost_utils.global_array_to_host_local_array
的有效輸入。如果您想要複製輸入,請使用jax.sharding.PartitionSpec()
。
錯誤修復
修正了 CUDA 12 版本中不正確的 wheel 名稱 (#16362);正確的 wheel 名稱是
cudnn89
而不是cudnn88
。
棄用
jax.experimental.jax2tf.convert()
的native_serialization_strict_checks
參數已棄用,改為新的native_serializaation_disabled_checks
(#16347)。
jaxlib 0.4.13 (June 22, 2023)#
變更
將 Windows 僅 CPU wheel 新增至
jaxlib
Pypi 版本。
錯誤修復
__cuda_array_interface__
在先前的 jaxlib 版本中已損壞,現在已修正 (#16440)。現在預設在 NVIDIA GPU 上啟用並行 CUDA 核心追蹤。
jax 0.4.12 (June 8, 2023)#
變更
棄用
jax.abstract_arrays
及其內容現在已棄用。請參閱 :mod:jax.core
中的相關功能。jax.numpy.alltrue
:請使用jax.numpy.all
。這遵循 NumPy 版本 1.25.0 中numpy.alltrue
的棄用。jax.numpy.sometrue
:請使用jax.numpy.any
。這遵循 NumPy 版本 1.25.0 中numpy.sometrue
的棄用。jax.numpy.product
:請使用jax.numpy.prod
。這遵循 NumPy 版本 1.25.0 中numpy.product
的棄用。jax.numpy.cumproduct
:請使用jax.numpy.cumprod
。這遵循 NumPy 版本 1.25.0 中numpy.cumproduct
的棄用。jax.sharding.OpShardingSharding
已移除,因為它已棄用 3 個月。
jaxlib 0.4.12 (June 8, 2023)#
變更
包含適用於 Hopper (SM 版本 9.0+) GPU 的 PTX/SASS。舊版 jaxlib 應可在 Hopper 上運作,但在第一次執行 JAX 運算時,JIT 編譯延遲時間會很長。
錯誤修復
修正了 Python 3.11 下 JAX 產生的 Python 回溯中不正確的原始程式碼行資訊。
修正了在 JAX 產生的 Python 回溯中列印框架的區域變數時發生的當機問題 (#16027)。
jax 0.4.11 (May 31, 2023)#
棄用
以下 API 已依照 API 相容性 政策,在 3 個月的棄用期後移除
jax.experimental.PartitionSpec
:請使用jax.sharding.PartitionSpec
。jax.experimental.maps.Mesh
:請使用jax.sharding.Mesh
jax.experimental.pjit.NamedSharding
:請使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:請使用jax.sharding.PartitionSpec
。jax.experimental.pjit.FROM_GDA
。請改為傳遞分片的jax.Array
物件作為輸入,並移除pjit
的選用in_shardings
引數。jax.interpreters.pxla.PartitionSpec
:請使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:請使用jax.sharding.Mesh
jax.interpreters.xla.Buffer
:請使用jax.Array
。jax.interpreters.xla.Device
:請使用jax.Device
。jax.interpreters.xla.DeviceArray
:請使用jax.Array
。jax.interpreters.xla.device_put
:請使用jax.device_put
。jax.interpreters.xla.xla_call_p
:請使用jax.experimental.pjit.pjit_p
。已移除
with_sharding_constraint
的axis_resources
引數。請改用shardings
。
jaxlib 0.4.11 (May 31, 2023)#
變更
為
Device
新增了memory_stats()
方法。如果支援,這會傳回字串統計名稱的 dict 和 int 值,例如"bytes_in_use"
,如果平台不支援記憶體統計資訊,則傳回 None。傳回的確切統計資訊可能因平台而異。目前僅在 Cloud TPU 上實作。重新新增了對 CPU 裝置上 Python 緩衝區協定 (
memoryview
) 的支援。
jax 0.4.10 (May 11, 2023)#
jaxlib 0.4.10 (May 11, 2023)#
變更
修正了
'apple-m1' is not a recognized processor for this target (ignoring processor)
問題,該問題阻止了先前版本在 Mac M1 上執行。
jax 0.4.9 (May 9, 2023)#
變更
已移除旗標 experimental_cpp_jit、experimental_cpp_pjit 和 experimental_cpp_pmap。它們現在始終處於開啟狀態。
已改善 TPU 上奇異值分解 (SVD) 的準確性 (需要 jaxlib 0.4.9)。
棄用
jax.experimental.gda_serialization
已棄用,並已重新命名為jax.experimental.array_serialization
。請變更您的匯入以使用jax.experimental.array_serialization
。已棄用 pjit 的
in_axis_resources
和out_axis_resources
引數。請分別使用in_shardings
和out_shardings
。函式
jax.numpy.msort
已移除。它自 JAX v0.4.1 起已棄用。請改用jnp.sort(a, axis=0)
。已從
jax.xla_computation
中移除in_parts
和out_parts
引數,因為它們僅用於 sharded_jit,而 sharded_jit 早已消失。已從
jax.xla_computation
中移除instantiate_const_outputs
引數,因為它已閒置很長時間。
jaxlib 0.4.9 (May 9, 2023)#
jax 0.4.8 (March 29, 2023)#
重大變更
Cloud TPU 執行階段的主要元件已升級。這在 Cloud TPU 上啟用了以下新功能
jax.debug.print()
、jax.debug.callback()
和jax.debug.breakpoint()
現在可在 Cloud TPU 上運作自動 TPU 記憶體重組
jax.experimental.host_callback()
在具有新執行階段元件的 Cloud TPU 上不再受支援。如果新的jax.debug
API 不足以滿足您的使用案例,請在 JAX 問題追蹤器上提交問題。舊的執行階段元件將在至少未來三個月內透過設定環境變數
JAX_USE_PJRT_C_API_ON_TPU=false
提供。如果您發現需要因任何原因停用新執行階段,請在 JAX 問題追蹤器上告知我們。
變更
最低 jaxlib 版本已從 0.4.6 提升至 0.4.7。
棄用
已停止支援 CUDA 11.4。JAX GPU wheel 僅支援 CUDA 11.8 和 CUDA 12。如果 jaxlib 是從原始碼建置的,則舊版 CUDA 可能會運作。
pmap 的
global_arg_shapes
引數僅適用於 sharded_jit,並且已從 pmap 中移除。請移轉至 pjit 並從 pmap 中移除 global_arg_shapes。
jax 0.4.7 (March 27, 2023)#
變更
依照 https://jax.dev.org.tw/en/latest/jax_array_migration.html#jax-array-migration
jax.config.jax_array
無法再停用。jax.config.jax_jit_pjit_api_merge
無法再停用。jax.experimental.jax2tf.convert()
現在支援native_serialization
參數,以使用 JAX 的原生降低至 StableHLO,來取得整個 JAX 函式的 StableHLO 模組,而不是將每個 JAX 原始運算降低為 TensorFlow 運算。這簡化了內部結構,並提高了您序列化的內容與 JAX 原生語意相符的信心。請參閱 文件。作為此變更的一部分,組態旗標--jax2tf_default_experimental_native_lowering
已重新命名為--jax2tf_native_serialization
。JAX 現在依賴
ml_dtypes
,其中包含 NumPy 類型(例如 bfloat16)的定義。這些定義先前是 JAX 的內部定義,但已拆分為個別套件,以方便與其他專案共用。JAX 現在需要 NumPy 1.21 或更新版本以及 SciPy 1.7 或更新版本。
棄用
類型
jax.numpy.DeviceArray
已棄用。請改用jax.Array
,它是它的別名。類型
jax.interpreters.pxla.ShardedDeviceArray
已棄用。請改用jax.Array
。不再建議依位置將其他引數傳遞至
jax.numpy.ndarray.at()
。例如,請使用x.at[i].get(indices_are_sorted=True)
而非x.at[i].get(True)
jax.interpreters.xla.device_put
已棄用。請使用jax.device_put
。jax.interpreters.pxla.device_put
已棄用。請使用jax.device_put
。jax.experimental.pjit.FROM_GDA
已棄用。請傳入分片的 jax.Arrays 作為輸入,並移除 pjit 的in_shardings
引數,因為它是選用的。
jaxlib 0.4.7 (March 27, 2023)#
變更
jaxlib 現在依賴
ml_dtypes
,其中包含 NumPy 類型(例如 bfloat16)的定義。這些定義先前是 JAX 的內部定義,但已拆分為個別套件,以方便與其他專案共用。
jax 0.4.6 (Mar 9, 2023)#
變更
jax.tree_util
現在包含一組 API,允許使用者為其自訂 pytree 節點定義鍵。這包括:tree_flatten_with_path
,其會扁平化樹狀結構,且不僅傳回每個葉節點,還會傳回其鍵路徑。tree_map_with_path
,其可以映射一個將鍵路徑作為參數的函式。register_pytree_with_keys
,用於註冊自訂 pytree 節點中的鍵路徑和葉節點應呈現的樣貌。keystr
,用於美觀地列印鍵路徑。
jax2tf.call_tf()
有一個新的參數output_shape_dtype
(預設為None
),可用於宣告結果的輸出形狀和資料型別。這使得jax2tf.call_tf()
能夠在形狀多態的情況下運作。( #14734 )。
棄用
jax.tree_util
中的舊版鍵路徑 API 已被棄用,並將在 2023 年 3 月 10 日起 3 個月後移除。register_keypaths
:請改用jax.tree_util.register_pytree_with_keys()
。AttributeKeyPathEntry
:請改用GetAttrKey
。GetitemKeyPathEntry
:請改用SequenceKey
或DictKey
。
jaxlib 0.4.6 (2023 年 3 月 9 日)#
jax 0.4.5 (2023 年 3 月 2 日)#
棄用
jax.sharding.OpShardingSharding
已重新命名為jax.sharding.GSPMDSharding
。jax.sharding.OpShardingSharding
將在 2023 年 2 月 17 日起 3 個月後移除。以下
jax.Array
方法已被棄用,並將在 2023 年 2 月 23 日起 3 個月後移除:jax.Array.broadcast
:請改用jax.lax.broadcast()
。jax.Array.broadcast_in_dim
:請改用jax.lax.broadcast_in_dim()
。jax.Array.split
:請改用jax.numpy.split()
。
jax 0.4.4 (2023 年 2 月 16 日)#
變更
jit
和pjit
的實作已合併。合併 jit 和 pjit 變更了 JAX 的內部結構,但不影響 JAX 的公開 API。先前,jit
是一種最終樣式基本運算。最終樣式表示 jaxpr 的建立會盡可能延遲,且轉換會堆疊在彼此之上。透過jit
-pjit
實作合併,jit
變成一種初始樣式基本運算,這表示我們會盡可能提早追蹤至 jaxpr。如需更多資訊,請參閱 autodidax 中的此章節。移至初始樣式應可簡化 JAX 的內部結構,並使動態形狀等功能的開發更加容易。您只能透過環境變數停用它,例如os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
。必須透過環境變數停用合併,因為它會在匯入時影響 JAX,因此需要在匯入 jax 之前停用。with_sharding_constraint
的axis_resources
引數已被棄用。請改用shardings
。如果您將axis_resources
作為 arg 使用,則無需變更。如果您將其作為 kwarg 使用,請改用shardings
。axis_resources
將在 2023 年 2 月 13 日起 3 個月後移除。新增了
jax.typing
模組,其中包含 JAX 函式型別註解的工具。以下名稱已被棄用:
jax.xla.Device
和jax.interpreters.xla.Device
:請使用jax.Device
。jax.experimental.maps.Mesh
。請改用jax.sharding.Mesh
。jax.experimental.pjit.NamedSharding
:請使用jax.sharding.NamedSharding
。jax.experimental.pjit.PartitionSpec
:請使用jax.sharding.PartitionSpec
。jax.interpreters.pxla.Mesh
:請使用jax.sharding.Mesh
。jax.interpreters.pxla.PartitionSpec
:請使用jax.sharding.PartitionSpec
。
重大變更
reduction 函式 (如 :func:
jax.numpy.sum
) 的initial
引數現在必須是純量,這與對應的 NumPy API 一致。先前針對非純量initial
值廣播輸出的行為是一種非預期的實作細節 ( #14446 )。
jaxlib 0.4.4 (2023 年 2 月 16 日)#
重大變更
預設
jaxlib
組建已移除對 NVIDIA Kepler 系列 GPU 的支援。如果需要 Kepler 支援,仍然可以從來源組建具有 Kepler 支援的jaxlib
(透過build.py
的--cuda_compute_capabilities=sm_35
選項),但請注意 CUDA 12 已完全停止對 Kepler GPU 的支援。
jax 0.4.3 (2023 年 2 月 8 日)#
重大變更
已刪除
jax.scipy.linalg.polar_unitary()
,這是已棄用的 JAX 擴充功能至 scipy API。請改用jax.scipy.linalg.polar()
。
變更
jaxlib 0.4.3 (2023 年 2 月 8 日)#
jax.Array
現在具有非封鎖is_ready()
方法,如果陣列已就緒,則傳回True
(另請參閱jax.block_until_ready()
)。
jax 0.4.2 (2023 年 1 月 24 日)#
重大變更
變更
jaxlib 0.4.2 (2023 年 1 月 24 日)#
變更
設定 JAX_USE_PJRT_C_API_ON_TPU=1 以啟用新的 Cloud TPU 執行階段,其具有自動裝置記憶體重組功能。
jax 0.4.1 (2022 年 12 月 13 日)#
變更
已停止支援 Python 3.7,這符合 JAX 的 Python 和 NumPy 版本支援政策。
我們推出了
jax.Array
,這是一種統一的陣列型別,包含 JAX 中的DeviceArray
、ShardedDeviceArray
和GlobalDeviceArray
型別。jax.Array
型別有助於使平行處理成為 JAX 的核心功能、簡化和統一 JAX 內部結構,並讓我們能夠統一jit
和pjit
。jax.Array
已在 JAX 0.4 中預設啟用,並對pjit
API 進行了一些重大變更。jax.Array 移轉指南 可以協助您將程式碼庫移轉至jax.Array
。您也可以查看 分散式陣列和自動平行處理 教學課程,以瞭解新的概念。PartitionSpec
和Mesh
現在已脫離實驗階段。新的 API 端點為jax.sharding.PartitionSpec
和jax.sharding.Mesh
。jax.experimental.maps.Mesh
和jax.experimental.PartitionSpec
已被棄用,並將在 3 個月後移除。with_sharding_constraint
的新公開端點為jax.lax.with_sharding_constraint
。如果將 ABSL 旗標與
jax.config
一起使用,則在最初從 ABSL 旗標填入 JAX 組態選項後,將不再讀取或寫入 ABSL 旗標值。此變更改善了讀取jax.config
選項的效能,這些選項在 JAX 中被廣泛使用。jax2tf.call_tf 函式現在針對 TF 降低使用與嵌入 JAX 計算所用平台相同的平台的第一個 TF 裝置。先前,它是針對 JAX 預設後端使用第 0 個裝置。
許多
jax.numpy
函式現在已將其引數標記為僅限位置引數,與 NumPy 相符。jnp.msort
現在已被棄用,原因是 numpy 1.24 中已棄用np.msort
。它將在未來的版本中移除,這符合 API 相容性 政策。它可以使用jnp.sort(a, axis=0)
替代。
jaxlib 0.4.1 (2022 年 12 月 13 日)#
變更
已停止支援 Python 3.7,這符合 JAX 的 Python 和 NumPy 版本支援政策。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
的行為已變更為配置 XX% 的 GPU 總記憶體,而不是先前使用目前可用的 GPU 記憶體來計算預先配置的行為。如需更多詳細資訊,請參閱 GPU 記憶體分配。已移除已棄用的方法
.block_host_until_ready()
。請改用.block_until_ready()
。
jax 0.4.0 (2022 年 12 月 12 日)#
此版本已撤回。
jaxlib 0.4.0 (2022 年 12 月 12 日)#
此版本已撤回。
jax 0.3.25 (2022 年 11 月 15 日)#
變更
jax.numpy.linalg.pinv()
現在支援hermitian
選項。jax.scipy.linalg.hessenberg()
現在僅在 CPU 上支援。需要 jaxlib > 0.3.24。新增了函式
jax.lax.linalg.hessenberg()
、jax.lax.linalg.tridiagonal()
和jax.lax.linalg.householder_product()
。Householder 歸約目前僅限 CPU,而三對角歸約僅在 CPU 和 GPU 上支援。現在針對非平方矩陣,更經濟地計算
svd
和jax.numpy.linalg.pinv
的梯度。
重大變更
已刪除
jax_experimental_name_stack
組態選項。將字串
axis_names
引數轉換為jax.experimental.maps.Mesh
建構函式,轉換為單例元組,而不是將字串解壓縮為字元軸名稱序列。
jaxlib 0.3.25 (2022 年 11 月 15 日)#
變更
新增了對 CPU 和 GPU 上三對角歸約的支援。
新增了對 CPU 上上黑森堡歸約的支援。
錯誤修正
修正了一個錯誤,該錯誤表示 JAX 擷取的追溯中的框架會錯誤地映射到 Python 3.10+ 下的原始碼行
jax 0.3.24 (2022 年 11 月 4 日)#
變更
JAX 的匯入速度應該更快。我們現在延遲匯入 scipy,這佔 JAX 匯入時間的很大一部分。
設定環境變數
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N
可用於限制寫入持續性快取的快取項目數量。依預設,編譯時間超過 1 秒的計算將會快取。如果未指定順序,則
pmap
在 TPU 上使用的預設裝置順序現在與單一程序作業的jax.devices()
相符。先前,這兩個順序不同,這可能會導致不必要的複製或記憶體不足錯誤。要求順序一致可簡化問題。
重大變更
jax.numpy.gradient()
現在的行為與jax.numpy
中的大多數其他函式類似,且禁止傳遞清單或元組來代替陣列 ( #12958 )jax.numpy.linalg
和jax.numpy.fft
中的函式現在統一要求輸入為類陣列:即清單和元組不能用來代替陣列。部分 #7737。
棄用
jax.sharding.MeshPspecSharding
已重新命名為jax.sharding.NamedSharding
。jax.sharding.MeshPspecSharding
名稱將在 3 個月後移除。
jaxlib 0.3.24 (2022 年 11 月 4 日)#
變更
緩衝區捐贈現在可在 CPU 上運作。這可能會破壞在 CPU 上標記緩衝區以進行捐贈,但依賴捐贈未實作的程式碼。
jax 0.3.23 (2022 年 10 月 12 日)#
變更
更新 Colab TPU 驅動程式版本以用於新的 jaxlib 版本。
jax 0.3.22 (2022 年 10 月 11 日)#
變更
在 TPU 初始化中新增
JAX_PLATFORMS=tpu,cpu
作為預設設定,因此如果無法初始化 TPU,JAX 將引發錯誤,而不是回復為 CPU。設定JAX_PLATFORMS=''
以覆寫此行為並自動選擇可用的後端 (原始預設值),或設定JAX_PLATFORMS=cpu
以始終使用 CPU,無論 TPU 是否可用。
棄用
JAX v0.3.8 中棄用的數個測試公用程式現在已從
jax.test_util
中移除。
jaxlib 0.3.22 (2022 年 10 月 11 日)#
jax 0.3.21 (2022 年 9 月 30 日)#
jax 0.3.20 (2022 年 9 月 28 日)#
jaxlib 0.3.20 (2022 年 9 月 28 日)#
jax 0.3.19 (2022 年 9 月 27 日)#
修正了必要的 jaxlib 版本。
jax 0.3.18 (2022 年 9 月 26 日)#
變更
預先 (AOT) 降低和編譯功能 (在 #7733 中追蹤) 是穩定且公開的。請參閱 概觀 和
jax.stages
的 API 文件。推出了
jax.Array
,旨在用於 JAX 中陣列型別的isinstance
檢查和型別註解。請注意,這包含對jax.numpy.ndarray
的isinstance
運作方式的一些細微變更,因為jax.numpy.ndarray
現在是jax.Array
的簡單別名。
重大變更
jax._src
不再匯入到公開jax
命名空間中。這可能會破壞正在使用 JAX 內部結構的使用者。jax.soft_pmap
已刪除。請改用pjit
或xmap
。jax.soft_pmap
未記載。如果已記載,則會提供棄用期。
jax 0.3.17 (2022 年 8 月 31 日)#
錯誤修正
修正了指數為零時
lax.pow
梯度的邊角案例問題 ( #12041 )
重大變更
jax.checkpoint()
(也稱為jax.remat()
) 不再支援concrete
選項,這是繼先前版本的棄用之後;請參閱 JEP 11830。
變更
新增了
jax.pure_callback()
,可從已編譯的函式 (例如,以jax.jit
或jax.pmap
裝飾的函式) 回呼至純 Python 函式。
棄用
已移除已棄用的
DeviceArray.tile()
方法。請使用jax.numpy.tile()
( #11944 )。DeviceArray.to_py()
已被棄用。請改用np.asarray(x)
。
jax 0.3.16#
重大變更
根據 棄用政策,已停止支援 NumPy 1.19。請升級至 NumPy 1.20 或更新版本。
變更
新增了
jax.debug
,其中包含執行階段數值偵錯的公用程式,例如jax.debug.print()
和jax.debug.breakpoint()
。為 執行階段數值偵錯 新增了新的文件
棄用
jax.mask()
jax.shapecheck()
API 已移除。請參閱 #11557。jax.experimental.loops
已移除。如需替代 API,請參閱 #10278。jax.tree_util.tree_multimap()
已移除。它自 JAX 版本 0.3.5 起已被棄用,而jax.tree_util.tree_map()
是直接替代品。已移除
jax.experimental.stax
;長期以來,它一直是jax.example_libraries.stax
的已棄用別名。已移除
jax.experimental.optimizers
;長期以來,它一直是jax.example_libraries.optimizers
的已棄用別名。jax.checkpoint()
(也稱為jax.remat()
) 有一個預設為開啟的新實作,這表示舊實作已被棄用;請參閱 JEP 11830。
jax 0.3.15 (2022 年 7 月 22 日)#
變更
JaxTestCase
和JaxTestLoader
已從jax.test_util
中移除。這些類別自 v0.3.1 起已被棄用 ( #11248 )。新增了
jax.scipy.gaussian_kde
( #11237 )。JAX 陣列與內建集合 (
dict
、list
、set
、tuple
) 之間的二元運算現在在所有情況下都會引發TypeError
。先前,某些情況 (尤其是相等和不相等) 會傳回與 NumPy 中類似運算不一致的布林純量 ( #11234 )。數個以頂層 JAX 套件匯入方式存取的
jax.tree_util
常式現在已被棄用,並將在未來 JAX 版本中根據 API 相容性 政策移除。jax.treedef_is_leaf()
已被棄用,建議改用jax.tree_util.treedef_is_leaf()
。jax.tree_flatten()
已被棄用,建議改用jax.tree_util.tree_flatten()
。jax.tree_leaves()
已被棄用,建議改用jax.tree_util.tree_leaves()
。jax.tree_structure()
已被棄用,建議改用jax.tree_util.tree_structure()
。jax.tree_transpose()
已被棄用,建議改用jax.tree_util.tree_transpose()
。jax.tree_unflatten()
已被棄用,建議改用jax.tree_util.tree_unflatten()
。
jax.scipy.linalg.solve()
的sym_pos
引數已被棄用,建議改用assume_a='pos'
,此變更遵循scipy.linalg.solve()
中類似的棄用。
jaxlib 0.3.15 (2022 年 7 月 22 日)#
jax 0.3.14 (2022 年 6 月 27 日)#
重大變更
jax.experimental.compilation_cache.initialize_cache()
不再支援max_cache_size_ bytes
,且不會再將其作為輸入。當平台初始化失敗時,
JAX_PLATFORMS
現在會引發例外。
變更
修正了與 NumPy 1.23 的相容性問題。
jax.numpy.linalg.slogdet()
現在接受可選的method
引數,允許在基於 LU 分解的實作和基於 QR 分解的實作之間進行選擇。jax.numpy.linalg.qr()
現在支援mode="raw"
。pickle
、copy.copy
和copy.deepcopy
在用於 jax 陣列時(#10659)現在有更完整的支援。特別是:先前,當對
DeviceArray
使用pickle
和deepcopy
時,會傳回np.ndarray
物件;現在會傳回DeviceArray
物件。對於deepcopy
,複製的陣列與原始陣列位於相同的裝置上。對於pickle
,還原序列化的陣列將位於預設裝置上。在函式轉換(即追蹤程式碼)中,
deepcopy
和copy
先前為無操作。現在它們使用與DeviceArray.copy()
相同的機制。在追蹤陣列上呼叫
pickle
現在會導致明確的ConcretizationTypeError
。
奇異值分解 (SVD) 和對稱/埃爾米特特徵值分解的實作在 TPU 上應顯著加快,尤其是對於 1000x1000 或更大的矩陣。兩者現在都使用頻譜分治演算法進行特徵值分解 (QDWH-eig)。
jax.numpy.ldexp()
不再靜默地將所有輸入提升為 float64,而是對於 int32 或更小尺寸的整數輸入,它會提升為 float32 (#10921)。為
jax.profiler.start_trace()
和jax.profiler.start_trace()
新增create_perfetto_link
選項。使用時,效能分析器將產生 Perfetto UI 的連結以檢視追蹤。變更了
jax.profiler.start_server(...)()
的語意,將 keepalive 全域儲存,而不是要求使用者保留對它的參考。新增
python -m jax.collect_profile
指令碼,以手動擷取程式追蹤,作為 TensorBoard UI 的替代方案。新增
jax.named_scope
環境管理器,將效能分析器中繼資料新增至 Python 程式(類似於jax.named_call
)。在分散更新操作(即 :attr:
jax.numpy.ndarray.at
)中,不安全的隱式 dtype 轉換已被棄用,現在會導致FutureWarning
。在未來的版本中,這將變成錯誤。不安全隱式轉換的一個範例是jnp.zeros(4, dtype=int).at[0].set(1.5)
,其中1.5
先前會被靜默截斷為1
。jax.experimental.compilation_cache.initialize_cache()
現在支援 gcs 儲存桶路徑作為輸入。當係數具有前導零時,
strip_zeros=False
的jax.numpy.roots()
現在表現更好 (#11215)。
jaxlib 0.3.14 (2022 年 6 月 27 日)#
-
x86-64 Mac wheels 現在需要 Mac OS 10.14 (Mojave) 或更新版本。Mac OS 10.14 於 2018 年發布,因此這應該不是一個非常繁瑣的要求。
捆綁的 NCCL 版本已更新至 2.12.12,修復了一些死鎖。
Python flatbuffers 套件不再是 jaxlib 的依賴項。
jax 0.3.13 (2022 年 5 月 16 日)#
jax 0.3.12 (2022 年 5 月 15 日)#
jax 0.3.11 (2022 年 5 月 15 日)#
變更
jax.lax.eigh()
現在接受可選的sort_eigenvalues
引數,允許使用者選擇在 TPU 上停用特徵值排序。
棄用
jax.lax.linalg
中函式的非陣列引數現在標記為僅限關鍵字。作為向後相容性步驟,以位置方式傳遞僅限關鍵字的引數會產生警告,但在未來的 JAX 版本中,以位置方式傳遞僅限關鍵字的引數將會失敗。但是,大多數使用者應優先使用jax.numpy.linalg
。作為 scipy API 的 JAX 擴充功能,
jax.scipy.linalg.polar_unitary()
已被棄用。請改用jax.scipy.linalg.polar()
。
jax 0.3.10 (2022 年 5 月 3 日)#
jaxlib 0.3.10 (2022 年 5 月 3 日)#
jax 0.3.9 (2022 年 5 月 2 日)#
變更
新增了對 GlobalDeviceArray 的完全非同步檢查點的支援。
jax 0.3.8 (2022 年 4 月 29 日)#
變更
TPU 上的
jax.numpy.linalg.svd()
使用 qdwh-svd 求解器。TPU 上的
jax.numpy.linalg.cond()
現在接受複數輸入。TPU 上的
jax.numpy.linalg.pinv()
現在接受複數輸入。TPU 上的
jax.numpy.linalg.matrix_rank()
現在接受複數輸入。jax.experimental.maps.mesh
已被刪除。請使用jax.experimental.maps.Mesh
。請參閱 https://jax.dev.org.tw/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh 以取得更多資訊。當
mode='r'
時,jax.scipy.linalg.qr()
現在傳回長度為 1 的元組,而不是原始陣列,以便與scipy.linalg.qr
的行為相符 (#10452)jax.numpy.take_along_axis()
現在接受可選的mode
參數,用於指定超出邊界索引的行為。預設情況下,超出邊界索引將傳回無效值(例如,NaN)。在先前的 JAX 版本中,無效索引會被鉗制在範圍內。先前的行為可以透過傳遞mode="clip"
來恢復。jax.numpy.take()
現在預設為mode="fill"
,這會為超出邊界索引傳回無效值(例如,NaN)。分散操作(例如
x.at[...].set(...)
)現在具有"drop"
語意。這對分散操作本身沒有影響,但這表示當微分分散時,超出邊界索引的梯度將產生零餘切。先前,超出邊界索引在梯度中被鉗制在範圍內,這在數學上是不正確的。如果
jax.numpy.take_along_axis()
的索引不是整數類型,則現在會引發TypeError
,與numpy.take_along_axis()
的行為相符。先前,非整數索引會被靜默轉換為整數。如果
jax.numpy.ravel_multi_index()
的dims
引數不是整數類型,則現在會引發TypeError
,與numpy.ravel_multi_index()
的行為相符。先前,非整數dims
會被靜默轉換為整數。如果
jax.numpy.split()
的axis
引數不是整數類型,則現在會引發TypeError
,與numpy.split()
的行為相符。先前,非整數axis
會被靜默轉換為整數。如果
jax.numpy.indices()
的維度不是整數類型,則現在會引發TypeError
,與numpy.indices()
的行為相符。先前,非整數維度會被靜默轉換為整數。如果
jax.numpy.diag()
的k
引數不是整數類型,則現在會引發TypeError
,與numpy.diag()
的行為相符。先前,非整數k
會被靜默轉換為整數。
棄用
jax.test_util
中提供的許多函式和物件現在已被棄用,並且在匯入時會引發警告。這包括cases_from_list
、check_close
、check_eq
、device_under_test
、format_shape_dtype_string
、rand_uniform
、skip_on_devices
、with_config
、xla_bridge
和_default_tolerance
(#10389)。這些以及先前已棄用的JaxTestCase
、JaxTestLoader
和BufferDonationTestCase
將在未來的 JAX 版本中移除。這些公用程式中的大多數都可以透過呼叫標準 python 和 numpy 測試公用程式來取代,例如在unittest
、absl.testing
、numpy.testing
等中找到。JAX 特定功能(例如裝置檢查)可以透過使用公用 API(例如jax.devices()
)來取代。許多已棄用的公用程式仍將存在於jax._src.test_util
中,但這些不是公用 API,因此可能會在未來的版本中更改或移除,恕不另行通知。
jax 0.3.7 (2022 年 4 月 15 日)#
變更
修復了如果傳遞給
jax.numpy.take_along_axis()
的索引被廣播時的效能問題 (#10281)。jax.scipy.special.expit()
和jax.scipy.special.logit()
現在要求其引數為純量或 JAX 陣列。它們現在也會將整數引數提升為浮點數。DeviceArray.tile()
方法已被棄用,因為 numpy 陣列沒有tile()
方法。作為替代方案,請使用jax.numpy.tile()
(#10266)。
jaxlib 0.3.7 (2022 年 4 月 15 日)#
變更
Linux wheels 現在根據
manylinux2014
標準而不是manylinux2010
標準建置。
jax 0.3.6 (2022 年 4 月 12 日)#
jax 0.3.5 (2022 年 4 月 7 日)#
變更
新增了
jax.random.loggamma()
,並改進了小參數值的jax.random.beta()
和jax.random.dirichlet()
的行為 (#9906)。私有
lax_numpy
子模組不再在jax.numpy
命名空間中公開 (#10029)。新增了陣列建立常式
jax.numpy.frombuffer()
、jax.numpy.fromfunction()
和jax.numpy.fromstring()
(#10049)。DeviceArray.copy()
現在傳回DeviceArray
而不是np.ndarray
(#10069)jax.experimental.sharded_jit
已被棄用,並將很快移除。
棄用
jax.nn.normalize()
正在被棄用。請改用jax.nn.standardize()
(#9899)。jax.tree_util.tree_multimap()
已被棄用。請改用jax.tree_util.tree_map()
(#5746)。jax.experimental.sharded_jit
已被棄用。請改用pjit
。
jaxlib 0.3.5 (2022 年 4 月 7 日)#
jax 0.3.4 (2022 年 3 月 18 日)#
jax 0.3.3 (2022 年 3 月 17 日)#
jax 0.3.2 (2022 年 3 月 16 日)#
變更
在 0.2.22 中已棄用的函式
jax.ops.index_update
、jax.ops.index_add
已被移除。請改用 JAX 陣列上的.at
屬性,例如x.at[idx].set(y)
。將
jax.experimental.ann.approx_*_k
移至jax.lax
中。這些函式是jax.lax.top_k
的最佳化替代方案。jax.numpy.broadcast_arrays()
和jax.numpy.broadcast_to()
現在需要純量或類陣列輸入,如果傳遞列表,則會失敗(#7737 的一部分)。標準 jax[tpu] 安裝現在可用於 Cloud TPU v4 VM。
pjit
現在可在 CPU 上運作(除了先前的 TPU 和 GPU 支援)。
jaxlib 0.3.2 (2022 年 3 月 16 日)#
變更
XlaComputation.as_hlo_text()
現在支援透過傳遞布林標誌print_large_constants=True
來列印大型常數。
棄用
JAX 陣列上的
.block_host_until_ready()
方法已被棄用。請改用.block_until_ready()
。
jax 0.3.1 (2022 年 2 月 18 日)#
變更
jax.test_util.JaxTestCase
和jax.test_util.JaxTestLoader
現在已被棄用。建議的替代方案是直接使用parametrized.TestCase
。對於依賴自訂斷言(例如JaxTestCase.assertAllClose()
)的測試,建議的替代方案是使用標準 numpy 測試公用程式(例如numpy.testing.assert_allclose()
),這些公用程式可直接與 JAX 陣列搭配使用 (#9620)。jax.test_util.JaxTestCase
現在預設設定jax_numpy_rank_promotion='raise'
(#9562)。若要恢復先前的行為,請使用新的jax.test_util.with_config
裝飾器。@jtu.with_config(jax_numpy_rank_promotion='allow') class MyTestCase(jtu.JaxTestCase): ...
新增了
jax.scipy.linalg.schur()
、jax.scipy.linalg.sqrtm()
、jax.scipy.signal.csd()
、jax.scipy.signal.stft()
、jax.scipy.signal.welch()
。
jax 0.3.0 (2022 年 2 月 10 日)#
jaxlib 0.3.0 (2022 年 2 月 10 日)#
變更
現在需要 Bazel 5.0.0 才能建置 jaxlib。
jaxlib 版本已升級至 0.3.0。請參閱 設計文件 以取得說明。
jax 0.2.28 (2022 年 2 月 1 日)#
-
如果未傳遞
dialect=
,jax.jit(f).lower(...).compiler_ir()
現在預設為 MHLO 方言。jax.jit(f).lower(...).compiler_ir(dialect='mhlo')
現在傳回 MLIRir.Module
物件,而不是其字串表示形式。
jaxlib 0.1.76 (2022 年 1 月 27 日)#
新功能
包含適用於 NVidia 計算能力 8.0 GPU (例如 A100) 的預編譯 SASS。移除了適用於計算能力 6.1 的預編譯 SASS,以免增加計算能力的數量:計算能力為 6.1 的 GPU 可以使用 6.0 SASS。
使用 jaxlib 0.1.76,JAX 預設使用 MHLO MLIR 方言作為其主要目標編譯器 IR。
重大變更
已停止支援 NumPy 1.18,根據廢棄政策。請升級至支援的 NumPy 版本。
錯誤修復
修正了由不同路徑建構,但外觀上相同的 pytreedef 物件,無法比較為相等的錯誤 (#9066)。
JAX jit 快取需要兩個靜態引數具有相同的類型,才能命中快取 (#9311)。
jax 0.2.27 (2022 年 1 月 18 日)#
重大變更
已停止支援 NumPy 1.18,根據廢棄政策。請升級至支援的 NumPy 版本。
已簡化 host_callback 原始運算,以移除針對 hcb.id_tap 和 id_print 的特殊自動微分處理。從現在起,僅會點擊(tapped)原始值。可以透過設定
JAX_HOST_CALLBACK_AD_TRANSFORMS
環境變數,或--jax_host_callback_ad_transforms
旗標(flag),來取得舊的行為(在有限的時間內)。此外,新增了關於如何使用 JAX 自訂 AD API 實作舊行為的文件 (#8678)。現在排序行為與 NumPy 針對
0.0
和NaN
的行為一致,無論其位元表示法為何。特別是,0.0
和-0.0
現在被視為相等,而先前-0.0
被視為小於0.0
。此外,所有NaN
表示法現在都被視為相等,並排序至陣列的末尾。先前,負NaN
值會排序至陣列的前端,且具有不同內部位元表示法的NaN
值不會被視為相等,而是根據這些位元模式進行排序 (#9178)。jax.numpy.unique()
現在以與 NumPy 版本 1.21 及更新版本中的np.unique
相同的方式處理NaN
值:最多只有一個NaN
值會出現在唯一化輸出中 (#9184)。
錯誤修復
host_callback 現在支援 ad_checkpoint.checkpoint (#8907)。
新功能
新增
jax.block_until_ready
({jax-issue}`#8941)新增了一個新的偵錯旗標/環境變數
JAX_DUMP_IR_TO=/path
。如果設定,JAX 會將其為每個計算產生的 MHLO/HLO IR 傾印到給定路徑下的檔案中。將
jax.ensure_compile_time_eval
新增至公開 API (#7987)。jax2tf 現在支援旗標 jax2tf_associative_scan_reductions,以變更關聯性歸約的降低(lowering),例如 jnp.cumsum,使其在 CPU 和 GPU 上的行為類似於 JAX(使用關聯性掃描)。請參閱 jax2tf README 以取得更多詳細資訊 (#9189)。
jaxlib 0.1.75 (2021 年 12 月 8 日)#
新功能
支援 python 3.10。
jax 0.2.26 (2021 年 12 月 8 日)#
jaxlib 0.1.74 (2021 年 11 月 17 日)#
啟用了 GPU 之間的點對點複製。先前,GPU 複製會透過主機彈回,通常速度較慢。
新增了實驗性的 MLIR Python 綁定,供 JAX 使用。
jax 0.2.25 (2021 年 11 月 10 日)#
jax 0.2.24 (2021 年 10 月 19 日)#
jaxlib 0.1.73 (2021 年 10 月 18 日)#
jaxlib GPU
cuda11
輪子(wheels)現在支援多個 cuDNN 版本。cuDNN 8.2 或更新版本。如果您的 cuDNN 安裝夠新,我們建議使用 cuDNN 8.2 輪子,因為它支援額外功能。
cuDNN 8.0.5 或更新版本。
重大變更
GPU jaxlib 的安裝指令如下
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer. pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer. pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
jax 0.2.22 (2021 年 10 月 12 日)#
重大變更
jax.pmap
的靜態引數現在必須是可雜湊的(hashable)。長期以來,
jax.jit
上一直不允許使用不可雜湊的靜態引數,但jax.pmap
仍然允許使用;jax.pmap
使用物件識別來比較不可雜湊的靜態引數。這種行為是一個隱藏的陷阱,因為使用物件識別來比較引數,每次物件識別變更時都會導致重新編譯。相反地,我們現在禁止不可雜湊的引數:如果
jax.pmap
的使用者想要透過物件識別來比較靜態引數,他們可以在物件上定義__hash__
和__eq__
方法來執行此操作,或者將他們的物件包裝在具有使用物件識別語意的這些操作的物件中。另一個選項是使用functools.partial
將不可雜湊的靜態引數封裝到函數物件中。jax.util.partial
是一個意外匯出的項目,現在已移除。請改用 Python 標準函式庫中的functools.partial
。
棄用
函數
jax.ops.index_update
、jax.ops.index_add
等已棄用,並將在未來的 JAX 版本中移除。請改用JAX 陣列上的.at
屬性,例如,x.at[idx].set(y)
。目前,這些函數會產生DeprecationWarning
。
新功能
當使用 jaxlib 0.1.72 或更新版本時,改善
pmap
的調度時間的優化 C++ 程式碼路徑現在是預設值。可以使用--experimental_cpp_pmap
旗標(或JAX_CPP_PMAP
環境變數)停用此功能。jax.numpy.unique
現在支援可選的fill_value
引數 (#8121)
jaxlib 0.1.72 (2021 年 10 月 12 日)#
重大變更
已停止支援 CUDA 10.2 和 CUDA 10.1。Jaxlib 現在支援 CUDA 11.1+。
錯誤修復
修正了 https://github.com/jax-ml/jax/issues/7461,該問題由於 XLA 編譯器內部的緩衝區別名錯誤,導致在所有平台上產生錯誤輸出。
jax 0.2.21 (2021 年 9 月 23 日)#
重大變更
jax.api
已移除。作為jax.api.*
提供的函數是jax.*
中函數的別名;請改用jax.*
中的函數。jax.partial
和jax.lax.partial
是意外匯出的項目,現在已移除。請改用 Python 標準函式庫中的functools.partial
。布林純量索引現在會引發
TypeError
;先前這會靜默地回傳錯誤結果 (#7925)。更多
jax.numpy
函數現在需要類陣列輸入,如果傳入列表(list)將會發生錯誤 (#7747 #7802 #7907)。請參閱 #7737 以了解此變更背後的理由。當在諸如
jax.jit
之類的轉換內部時,jax.numpy.array
總是將其產生的陣列分段到追蹤計算中。先前,即使在jax.jit
裝飾器下,jax.numpy.array
有時也會產生裝置上的陣列。此變更可能會破壞使用 JAX 陣列執行形狀或索引計算的程式碼,這些計算必須是靜態已知的;解決方法是改用經典的 NumPy 陣列執行這些計算。jnp.ndarray
現在是 JAX 陣列的真正基底類別。特別是,這表示對於標準 numpy 陣列x
,isinstance(x, jnp.ndarray)
現在將會回傳False
(#7927)。
新功能
新增
jax.numpy.insert()
實作 (#7936)。
jax 0.2.20 (2021 年 9 月 2 日)#
jaxlib 0.1.71 (2021 年 9 月 1 日)#
重大變更
已停止支援 CUDA 11.0 和 CUDA 10.1。Jaxlib 現在支援 CUDA 10.2 和 CUDA 11.1+。
jax 0.2.19 (2021 年 8 月 12 日)#
重大變更
已停止支援 NumPy 1.17,根據廢棄政策。請升級至支援的 NumPy 版本。
已在 JAX 陣列上的一些運算子的實作周圍新增了
jit
裝飾器。這加快了常見運算子(例如+
)的調度時間。此變更對大多數使用者來說應該基本上是透明的。但是,有一個已知的行為變更,那就是大型整數常數在直接傳遞給 JAX 運算子時(例如,
x + 2**40
)現在可能會產生錯誤。解決方法是將常數轉換為明確的類型(例如,np.float64(2**40)
)。
新功能
改善了 jax2tf 中對形狀多型性的支援,以用於需要在陣列計算中使用維度大小的操作,例如
jnp.mean
。 (#7317)
錯誤修復
先前版本中的一些洩漏追蹤錯誤 (#7613)
jaxlib 0.1.70 (2021 年 8 月 9 日)#
jax 0.2.18 (2021 年 7 月 21 日)#
重大變更
已停止支援 Python 3.6,根據廢棄政策。請升級至支援的 Python 版本。
最低 jaxlib 版本現在為 0.1.69。
已移除
jax.dlpack.from_dlpack()
的backend
引數。
新功能
新增了極分解 (
jax.scipy.linalg.polar()
)。
錯誤修復
收緊了對 lax.argmin 和 lax.argmax 的檢查,以確保它們不會與無效的
axis
值或空的歸約維度一起使用。 (#7196)
jaxlib 0.1.69 (2021 年 7 月 9 日)#
修正了 TFRT CPU 後端中導致結果不正確的錯誤。
jax 0.2.17 (2021 年 7 月 9 日)#
錯誤修復
對於 jaxlib <= 0.1.68,預設為較舊的 “stream_executor” CPU 執行階段,以解決 #7229,該問題由於並行問題導致 CPU 上產生錯誤輸出。
新功能
新的 SciPy 函數
jax.scipy.special.sph_harm()
。反向模式自動微分函數 (
jax.grad()
、jax.value_and_grad()
、jax.vjp()
和jax.linear_transpose()
) 支援一個參數,該參數指示如果前向傳遞中廣播(broadcasted)的具名軸,在反向傳遞中應對哪些具名軸求和。這使得這些 API 能夠以非逐個範例的方式在映射(maps)內使用(最初僅限jax.experimental.maps.xmap()
)(#6950)。
jax 0.2.16 (2021 年 6 月 23 日)#
jax 0.2.15 (2021 年 6 月 23 日)#
jaxlib 0.1.68 (2021 年 6 月 23 日)#
錯誤修復
修正了 TFRT CPU 後端中將 TPU 緩衝區傳輸到 CPU 時會出現 nan 的錯誤。
jax 0.2.14 (2021 年 6 月 10 日)#
新功能
jax2tf.convert()
現在支援pjit
和sharded_jit
。新的組態選項 JAX_TRACEBACK_FILTERING 控制 JAX 如何篩選回溯(tracebacks)。
在足夠新版本的 IPython 中,預設情況下已啟用使用
__tracebackhide__
的新回溯篩選模式。jax2tf.convert()
即使在算術運算中使用未知維度,也支援形狀多型性,例如jnp.reshape(-1)
(#6827)。jax2tf.convert()
在 TF 運算中產生具有位置資訊的自訂屬性。jax2tf 之後 XLA 產生的程式碼與 JAX/XLA 具有相同的位置資訊。新的 SciPy 函數
jax.scipy.special.lpmn()
。
錯誤修復
jaxlib 0.1.67 (2021 年 5 月 17 日)#
jaxlib 0.1.66 (2021 年 5 月 11 日)#
新功能
CUDA 11.1 輪子現在在所有 CUDA 11 版本 11.1 或更高版本上都受到支援。
NVidia 現在承諾 CUDA 次要版本之間的相容性,從 CUDA 11.1 開始。這表示 JAX 可以發布單個 CUDA 11.1 輪子,該輪子與 CUDA 11.2 和 11.3 相容。
不再有針對 CUDA 11.2(或更高版本)的單獨 jaxlib 版本;對於這些版本,請使用 CUDA 11.1 輪子 (cuda111)。
Jaxlib 現在在 CUDA 輪子中捆綁了
libdevice.10.bc
。應該不再需要指示 JAX CUDA 安裝位置來尋找此檔案。為
jit()
實作新增了對靜態關鍵字引數的自動支援。新增了對預轉換例外追蹤的支援。
初步支援從
jit()
轉換的計算中修剪未使用的引數。修剪仍在進行中。改善了
PyTreeDef
物件的字串表示形式。新增了對 XLA 的可變參數 ReduceWindow 的支援。
錯誤修復
修正了遠端雲端 TPU 支援中,將大量引數傳遞給計算時的錯誤。
修正了一個錯誤,該錯誤表示 JAX 垃圾回收未由
jit()
轉換的函數觸發。
jax 0.2.13 (2021 年 5 月 3 日)#
新功能
當與 jaxlib 0.1.66 結合使用時,
jax.jit()
現在支援靜態關鍵字引數。新增了一個新的static_argnames
選項,以將關鍵字引數指定為靜態。jax.nonzero()
有一個新的可選size
引數,使其可以在jit
內使用 (#6501)jax.numpy.unique()
現在支援axis
引數 (#6532)。jax.experimental.host_callback.call()
現在支援pjit.pjit
(#6569)。新增了
jax.scipy.linalg.eigh_tridiagonal()
,用於計算三對角矩陣的特徵值。目前僅支援特徵值。例外狀況中經過篩選和未經篩選的堆疊追蹤的順序已變更。附加到從 JAX 轉換的程式碼擲回的例外狀況的回溯現在已篩選,其中
UnfilteredStackTrace
例外狀況包含原始追蹤作為篩選例外狀況的__cause__
。篩選的堆疊追蹤現在也適用於 Python 3.6。如果由反向模式自動微分轉換的程式碼擲回例外狀況,JAX 現在會嘗試將
JaxStackTraceBeforeTransformation
物件作為例外狀況的__cause__
附加,該物件包含在前向傳遞中建立原始運算的堆疊追蹤。需要 jaxlib 0.1.66。
重大變更
以下函數名稱已變更。仍然有別名,因此這不應破壞現有程式碼,但別名最終將被移除,因此請變更您的程式碼。
host_id
–>process_index()
host_count
–>process_count()
host_ids
–>range(jax.process_count())
同樣地,
local_devices()
的引數已從host_id
重新命名為process_index
。除了函數之外,
jax.jit()
的引數現在標記為僅限關鍵字。此變更是為了防止在將引數新增至jit
時意外中斷。
錯誤修復
jaxlib 0.1.65 (2021 年 4 月 7 日)#
jax 0.2.12 (2021 年 4 月 1 日)#
新功能
新的分析 API:
jax.profiler.start_trace()
、jax.profiler.stop_trace()
和jax.profiler.trace()
jax.lax.reduce()
現在是可微分的。
重大變更
最低 jaxlib 版本現在為 0.1.64。
一些分析器 API 名稱已變更。仍然有別名,因此這不應破壞現有程式碼,但別名最終將被移除,因此請變更您的程式碼。
TraceContext
–>TraceAnnotation()
StepTraceContext
–>StepTraceAnnotation()
trace_function
–>annotate_function()
Omnistaging 功能已無法停用。請參閱 omnistaging 以了解更多資訊。
Python 整數若大於
int64
的最大值,在所有情況下都會導致溢位 (overflow),而不會像過去在某些情況下靜默轉換為uint64
(#6047)。在 X64 模式之外,Python 整數若超出
int32
可表示的範圍,現在會導致OverflowError
錯誤,而不是靜默截斷其值。
錯誤修復
host_callback
現在支援在引數和結果中使用空陣列 (#6262)。jax.random.randint()
對超出範圍的限制值進行裁剪 (clip) 而非環繞 (wrap),並且現在可以產生指定 dtype 完整範圍內的整數 (#5868)。
jax 0.2.11 (2021 年 3 月 23 日)#
新功能
錯誤修復
重大變更
最低 jaxlib 版本現在為 0.1.62。
jaxlib 0.1.64 (2021 年 3 月 18 日)#
jaxlib 0.1.63 (2021 年 3 月 17 日)#
jax 0.2.10 (2021 年 3 月 5 日)#
新功能
jax.scipy.stats.chi2()
現在可以作為具有 logpdf 和 pdf 方法的分布 (distribution) 使用。jax.scipy.stats.betabinom()
現在可以作為具有 logpmf 和 pmf 方法的分布 (distribution) 使用。新增了
jax.experimental.jax2tf.call_tf()
,以便從 JAX 呼叫 TensorFlow 函數 (#5627) 和 README)。擴展了
lax.pad
的批次處理規則,以支援填充值 (padding values) 的批次處理。
錯誤修復
jax.numpy.take()
正確處理負索引 (#5768)。
重大變更
JAX 的型別提升規則已調整,使提升更一致且不受 JIT 影響。特別是,二元運算現在可以在適當的情況下產生弱型別的值。此變更主要在使用者可見的影響是,某些運算會產生與之前不同精度的輸出;例如,運算式
jnp.bfloat16(1) + 0.1 * jnp.arange(10)
先前傳回float64
陣列,現在則傳回bfloat16
陣列。JAX 的型別提升行為請參閱 型別提升語意。jax.numpy.linspace()
現在計算整數值的下限值,即朝 -inf 而非 0 捨入。此變更旨在與 NumPy 1.20.0 相符。jax.numpy.i0()
不再接受複數。先前,此函數會計算複數引數的絕對值。此變更旨在與 NumPy 1.20.0 的語意相符。數個
jax.numpy
函數不再接受以元組或列表取代陣列引數:jax.numpy.pad()
、:funcjax.numpy.ravel
、jax.numpy.repeat()
、jax.numpy.reshape()
。一般而言,jax.numpy
函數應搭配純量或陣列引數使用。
jaxlib 0.1.62 (2021 年 3 月 9 日)#
新功能
jaxlib wheels 現在預設建置為在 x86-64 機台上需要 AVX 指令。如果您想在不支援 AVX 的機器上使用 JAX,可以使用
build.py
的--target_cpu_features
旗標從原始碼建置 jaxlib。--target_cpu_features
也取代了--enable_march_native
。
jaxlib 0.1.61 (2021 年 2 月 12 日)#
jaxlib 0.1.60 (2021 年 2 月 3 日)#
錯誤修復
修復了將 CPU DeviceArrays 轉換為 NumPy 陣列時的記憶體洩漏問題。記憶體洩漏存在於 jaxlib 版本 0.1.58 和 0.1.59 中。
bool
、int8
和uint8
現在被視為可以安全地轉換為bfloat16
NumPy 擴充型別。
jax 0.2.9 (2021 年 1 月 26 日)#
新功能
擴展了
jax.experimental.loops
模組,使其支援 pytrees。改進了錯誤檢查和錯誤訊息。新增了
jax.experimental.enable_x64()
和jax.experimental.disable_x64()
。這些是上下文管理器,允許在工作階段中暫時啟用/停用 X64 模式。
重大變更
jax.ops.segment_sum()
現在會捨棄超出範圍的區段 ID,而不是將其環繞到區段 ID 空間中。這樣做是為了效能考量。
jaxlib 0.1.59 (2021 年 1 月 15 日)#
jax 0.2.8 (2021 年 1 月 12 日)#
新功能
新增了
jax.closure_convert()
,用於高階自訂導數函數。(<#5244>)新增了
jax.experimental.host_callback.call()
,以便在主機上呼叫自訂 Python 函數,並將結果傳回裝置運算。(<#5243>)
錯誤修復
重大變更
jax.numpy.pad
現在接受關鍵字引數。位置引數constant_values
已移除。此外,傳遞不支援的關鍵字引數會引發錯誤。jax.experimental.host_callback.id_tap()
的變更 (#5243)移除了
jax.experimental.host_callback.id_tap()
的kwargs
支援。(此支援已棄用數個月。)變更了
jax.experimental.host_callback.id_print()
的元組列印方式,從 ‘[‘ 改為使用 ‘(‘。變更了 JVP 存在時的
jax.experimental.host_callback.id_print()
,以列印原始值和切線值組。先前,原始值和切線值有兩個個別的列印運算。host_callback.outfeed_receiver
已移除 (它不是必要的,且數個月前已棄用)。
新功能
用於偵錯
inf
的新旗標,類似於NaN
的旗標 (#5224)。
jax 0.2.7 (2020 年 12 月 4 日)#
新功能
新增了
jax.device_put_replicated
將多主機支援新增至
jax.experimental.sharded_jit
新增了對
jax.numpy.linalg.eig
計算出的特徵值進行微分的支援新增了在 Windows 平台上建置的支援
在
jax.pmap
中新增了對一般 in_axes 和 out_axes 的支援針對
jax.numpy.linalg.slogdet
新增了複數支援
錯誤修復
修復了
jax.numpy.sinc
在零點的高於二階的導數修復了轉置規則中一些難以命中的符號零錯誤
重大變更
jax.experimental.optix
已刪除,改用獨立的optax
Python 套件。現在使用非元組序列對 JAX 陣列進行索引會引發
TypeError
錯誤。自 Numpy v1.16 起已棄用此類型的索引,自 JAX v0.2.4 起亦同。請參閱 #4564。
jax 0.2.6 (2020 年 11 月 18 日)#
新功能
為 jax.experimental.jax2tf 轉換器新增了形狀多型追蹤的支援。請參閱 README.md。
重大變更清理
針對 jax.jit 和 xla_computation 的不可雜湊靜態引數引發錯誤。請參閱 cb48f42。
改進了型別提升行為的一致性 (#4744)
將複數 Python 純量新增至 JAX 浮點數時,會遵守 JAX 浮點數的精確度。例如,
jnp.float32(1) + 1j
現在會傳回complex64
,先前則傳回complex128
。涉及 uint64、帶正負號整數和第三種類型的 3 個或更多項目的型別提升結果,現在與引數的順序無關。例如:
jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
和jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
都會傳回float16
,先前前者傳回float64
,後者則傳回float16
。
未記載的
jax.lax_linalg
線性代數模組的內容現在公開為jax.lax.linalg
。jax.random.PRNGKey
現在在 JIT 編譯內外產生相同的結果 (#4877)。這需要在少數特定情況下變更給定種子的結果使用
jax_enable_x64=False
時,作為 Python 整數傳遞的負種子現在會在 JIT 模式外傳回不同的結果。例如,jax.random.PRNGKey(-1)
先前傳回[4294967295, 4294967295]
,現在則傳回[0, 4294967295]
。這與 JIT 中的行為相符。超出 JIT 外
int64
可表示範圍的種子,現在會導致OverflowError
錯誤,而不是TypeError
錯誤。這與 JIT 中的行為相符。
若要還原先前針對 JIT 外使用
jax_enable_x64=False
的負整數傳回的金鑰,您可以使用key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
DeviceArray 現在會在嘗試存取已刪除的值時,引發
RuntimeError
錯誤,而不是ValueError
錯誤。
jaxlib 0.1.58 (約 2021 年 1 月 12 日)#
修復了一個錯誤,該錯誤表示 JAX 有時會傳回平台特定的型別 (例如
np.cint
),而不是標準型別 (例如np.int32
)。(#4903)修復了常數摺疊某些 int16 運算時發生的當機問題。(#4971)
為
pytree.flatten()
新增了is_leaf
述詞。
jaxlib 0.1.57 (2020 年 11 月 12 日)#
修復了 GPU wheels 中的 manylinux2010 相容性問題。
將 CPU FFT 實作從 Eigen 切換為 PocketFFT。
修復了 bfloat16 值的雜湊未正確初始化且可能會變更的錯誤 (#4651)。
新增了在將陣列傳遞至 DLPack 時保留擁有權的支援 (#4636)。
修復了批次三角解算大小大於 128 但不是 128 倍數時的錯誤。
修復了在多個 GPU 上執行並行 FFT 時的錯誤 (#3518)。
修復了工具遺失的分析器錯誤 (#4427)。
已停止支援 CUDA 10.0。
jax 0.2.5 (2020 年 10 月 27 日)#
改進
確保
check_jaxpr
不執行 FLOPS。請參閱 #4650。擴展了 jax2tf 轉換的 JAX 原始項目集。請參閱 primitives_with_limited_support.md。
jax 0.2.4 (2020 年 10 月 19 日)#
jaxlib 0.1.56 (2020 年 10 月 14 日)#
jax 0.2.3 (2020 年 10 月 14 日)#
這麼快再次發布的原因是,我們需要暫時回溯新的 jit 快速路徑,同時調查效能降低問題
jax 0.2.2 (2020 年 10 月 13 日)#
jax 0.2.1 (2020 年 10 月 6 日)#
改進
作為 omnistaging 的優點,即使
jax.experimental.host_callback.id_print()
/jax.experimental.host_callback.id_tap()
的結果未在運算中使用,host_callback 函數仍會 (依程式順序) 執行。
jax (0.2.0) (2020 年 9 月 23 日)#
改進
預設啟用 Omnistaging。請參閱 #3370 和 omnistaging
jax (0.1.77) (2020 年 9 月 15 日)#
重大變更
jax.experimental.host_callback.id_tap()
的新簡化介面 (#4101)
jaxlib 0.1.55 (2020 年 9 月 8 日)#
更新 XLA
修復 DLPackManagedTensorToBuffer 中的錯誤 (#4196)
jax 0.1.76 (2020 年 9 月 8 日)#
jax 0.1.75 (2020 年 7 月 30 日)#
錯誤修正
讓 jnp.abs() 適用於未帶正負號的輸入 (#3914)
改進
在旗標後方新增 “Omnistaging” 行為,預設為停用 (#3370)
jax 0.1.74 (2020 年 7 月 29 日)#
新功能
BFGS (#3101)
TPU 支援半精度算術 (#3878)
錯誤修正
防止一些意外的 dtype 警告 (#3874)
修復自訂導數中的多執行緒錯誤 (#3845, #3869)
改進
更快速的 searchsorted 實作 (#3873)
改善 jax.numpy 排序演算法的測試涵蓋率 (#3836)
jaxlib 0.1.52 (2020 年 7 月 22 日)#
更新 XLA。
jax 0.1.73 (2020 年 7 月 22 日)#
最低 jaxlib 版本現在為 0.1.51。
新功能
jax.image.resize。 (#3703)
hfft 和 ihfft (#3664)
jax.numpy.intersect1d (#3726)
jax.numpy.lexsort (#3812)
lax.scan
和scan
原始項目支援unroll
參數,以便在降低至 XLA 時進行迴圈展開 (#3738)。
錯誤修正
修復縮減重複軸錯誤 (#3618)
修復 lax.pad 針對大小為 0 的輸入維度的形狀規則。(#3608)
讓 psum 轉置處理零餘切 (#3653)
修復在大小為 0 軸上取得 reduce-prod 的 JVP 時的形狀錯誤。(#3729)
支援透過 jax.lax.all_to_all 進行微分 (#3733)
解決 jax.scipy.special.zeta 中的 nan 問題 (#3777)
改進
對 jax2tf 進行了許多改進
使用單遍可變引數縮減重新實作 argmin/argmax。 (#3611)
預設啟用 XLA SPMD 分割。 (#3151)
新增了對 0d 轉置捲積的支援 (#3643)
讓 LU 梯度適用於低秩矩陣 (#3610)
支援 jet 中的 multiple_results 和自訂 JVP (#3657)
將 reduce-window 填充廣泛化,以支援 (lo, hi) 配對。 (#3728)
在 CPU 和 GPU 上實作複數捲積。 (#3735)
讓 jnp.take 適用於空陣列的空切片。 (#3751)
放寬 dot_general 的維度排序規則。 (#3778)
為 GPU 啟用緩衝區捐贈。 (#3800)
為 reduce window 運算新增了對基本擴張和視窗擴張的支援… (#3803)
jaxlib 0.1.51 (2020 年 7 月 2 日)#
更新 XLA。
為 host_callback 新增了新的執行階段支援。
jax 0.1.72 (2020 年 6 月 28 日)#
jax 0.1.71 (2020 年 6 月 25 日)#
jaxlib 0.1.50 (2020 年 6 月 25 日)#
新增了 CUDA 11.0 的支援。
停止支援 CUDA 9.2 (我們僅維護對最新四個 CUDA 版本的支援。)
更新 XLA。
jaxlib 0.1.49 (2020 年 6 月 19 日)#
錯誤修復
修復了可能導致編譯速度緩慢的建置問題 (tensorflow/tensorflow)
jaxlib 0.1.48 (2020 年 6 月 12 日)#
新功能
新增了對快速追蹤收集的支援。
新增了對裝置上堆積分析的初步支援。
為
bfloat16
型別實作np.nextafter
。CPU 和 GPU 上 FFT 的 Complex128 支援。
錯誤修復
改進了 GPU 上 float64
tanh
的準確性。GPU 上的 float64 散佈速度更快。
CPU 上的複數矩陣乘法應該更快。
CPU 上的穩定排序現在應該確實穩定。
CPU 後端的並行錯誤修正。
jax 0.1.70 (2020 年 6 月 8 日)#
jax 0.1.69 (2020 年 6 月 3 日)#
jax 0.1.68 (2020 年 5 月 21 日)#
jax 0.1.67 (2020 年 5 月 12 日)#
新功能
使用
axis_index_groups
支援在 pmapped 軸的子集上進行縮減 #2382。從編譯碼進行列印和呼叫主機端 Python 函數的實驗性支援。請參閱 id_print 和 id_tap (#3006)。
值得注意的變更
已收緊從
jax.numpy
匯出的名稱的能見度。這可能會中斷使用先前意外匯出的名稱的程式碼。
jaxlib 0.1.47 (2020 年 5 月 8 日)#
修復了 outfeed 的當機問題。
jax 0.1.66 (2020 年 5 月 5 日)#
jaxlib 0.1.46 (2020 年 5 月 5 日)#
修復了 Mac OS X 上線性代數函數的當機問題 (#432)。
修復了在作業系統或 Hypervisor 停用 AVX512 指令時,因使用 AVX512 指令而導致的非法指令當機問題 (#2906)。
jax 0.1.65 (2020 年 4 月 30 日)#
jaxlib 0.1.45 (2020 年 4 月 21 日)#
修復了區段錯誤:#2755
將 Sort HLO 上的 is_stable 選項貫穿至 Python。
jax 0.1.64 (2020 年 4 月 21 日)#
新功能
為功能索引更新新增了語法糖 #2684。
新增了
jax.numpy.unique()
#2760。新增了
jax.numpy.rint()
#2724。新增了
jax.numpy.rint()
#2724。為
jax.experimental.jet()
新增更多原始規則。
錯誤修復
更好的錯誤訊息
改進了
lax.while_loop()
反向模式微分的錯誤訊息 #2129。
jaxlib 0.1.44 (2020 年 4 月 16 日)#
修復了一個錯誤,該錯誤導致當存在多個不同型號的 GPU 時,JAX 只會編譯適用於第一個 GPU 的程式。
修復了
batch_group_count
卷積的錯誤。為更多 GPU 版本新增了預編譯的 SASS,以避免啟動 PTX 編譯卡頓。
jax 0.1.63 (2020 年 4 月 12 日)#
從 #2026 新增了
jax.custom_jvp
和jax.custom_vjp
,請參閱教學筆記本。已棄用jax.custom_transforms
並從文件中移除(但它仍然有效)。新增
scipy.sparse.linalg.cg
#2566。變更了 Tracers 的列印方式,以顯示更多有用的除錯資訊 #2591。
使
jax.numpy.isclose
正確處理nan
和inf
#2501。為
jax.experimental.jet
新增了幾個新規則 #2537。修復了當未提供
scale
/center
時的jax.experimental.stax.BatchNorm
。修復了
jax.numpy.einsum
中一些遺失的廣播案例 #2512。實作了以平行前綴掃描表示的
jax.numpy.cumsum
和jax.numpy.cumprod
#2596,並使reduce_prod
可微分至任意階 #2597。將
batch_group_count
新增至conv_general_dilated
#2635。為
test_util.check_grads
新增了文件字串 #2656。新增
callback_transform
#2665。實作了
rollaxis
、convolve
/correlate
1d & 2d、copysign
、trunc
、roots
,以及quantile
/percentile
插值選項。
jaxlib 0.1.43 (2020 年 3 月 31 日)#
修復了 GPU 上 Resnet-50 的效能衰退問題。
jax 0.1.62 (2020 年 3 月 21 日)#
JAX 已停止支援 Python 3.5。請升級至 Python 3.6 或更新版本。
移除了內部函數
lax._safe_mul
,它實作了0. * nan == 0.
的慣例。此變更意味著某些程式在微分時會產生 nan 值,而之前它們會產生正確的值,儘管它確保為其他程式產生 nan 值而不是靜默地產生不正確的結果。詳情請參閱 #2447 和 #1052。新增了
all_gather
平行便利函數。核心程式碼中新增了更多類型註釋。
jaxlib 0.1.42 (2020 年 3 月 19 日)#
jaxlib 0.1.41 由於 API 不相容而中斷了雲端 TPU 支援。此版本再次修復了它。
JAX 已停止支援 Python 3.5。請升級至 Python 3.6 或更新版本。
jax 0.1.61 (2020 年 3 月 17 日)#
修復了 Python 3.5 支援。這將是最後一個支援 Python 3.5 的 JAX 或 jaxlib 版本。
jax 0.1.60 (2020 年 3 月 17 日)#
新功能
jax.pmap()
具有static_broadcast_argnums
參數,允許使用者指定應視為編譯時期常數並應廣播到所有裝置的參數。它的運作方式與jax.jit()
中的static_argnums
類似。改進了當追蹤器錯誤地儲存在全域狀態時的錯誤訊息。
新增了
jax.nn.one_hot()
實用函數。新增了用於指數級加速高階自動微分的
jax.experimental.jet
。為
jax.lax.broadcast_in_dim()
的引數新增了更多正確性檢查。
最低 jaxlib 版本現在為 0.1.41。
jaxlib 0.1.40 (2020 年 3 月 4 日)#
在 Jaxlib 中新增了對 TensorFlow profiler 的實驗性支援,允許從 TensorBoard 追蹤 CPU 和 GPU 計算。
包含透過 NCCL 通訊的多主機 GPU 計算的原型支援。
提高了 GPU 上 NCCL 集體運算的效能。
新增了 TopK、CustomCallWithoutLayout、CustomCallWithLayout、IGammaGradA 和 RandomGamma 實作。
支援在 XLA 編譯時已知的裝置分配。
jax 0.1.59 (2020 年 2 月 11 日)#
重大變更
最低 jaxlib 版本現在為 0.1.38。
透過移除
Jaxpr.freevars
和Jaxpr.bound_subjaxprs
簡化了Jaxpr
。呼叫基本運算 (xla_call
、xla_pmap
、sharded_call
和remat_call
) 取得了一個新的參數call_jaxpr
,其中包含完全封閉的(無constvars
)jaxpr。此外,為基本運算新增了一個新的欄位call_primitive
。
新功能
lax.cond
的反向模式自動微分 (例如grad
),使其現在在兩種模式下都可微分 (#2091)JAX 現在支援 DLPack,它允許以零複製方式與其他程式庫(例如 PyTorch)共用 CPU 和 GPU 陣列。
JAX GPU DeviceArrays 現在支援
__cuda_array_interface__
,這是另一個用於與其他程式庫(例如 CuPy 和 Numba)共用 GPU 陣列的零複製協定。JAX CPU 裝置緩衝區現在實作了 Python 緩衝區協定,允許在 JAX 和 NumPy 之間進行零複製緩衝區共用。
新增了 JAX_SKIP_SLOW_TESTS 環境變數,以跳過已知速度較慢的測試。
jaxlib 0.1.39 (2020 年 2 月 11 日)#
更新 XLA。
jaxlib 0.1.38 (2020 年 1 月 29 日)#
不再支援 CUDA 9.0。
CUDA 10.2 wheels 現在預設為建置。
jax 0.1.58 (2020 年 1 月 28 日)#
重大變更
JAX 已停止支援 Python 2,因為 Python 2 已於 2020 年 1 月 1 日終止生命週期。請更新至 Python 3.5 或更新版本。
新功能
while 迴圈的前向模式自動微分 (
jvp
) (#1980)
新的 NumPy 和 SciPy 函數
GPU 上的批次 Cholesky 分解現在使用更有效率的批次核心。
值得注意的錯誤修復#
隨著 Python 3 的升級,JAX 不再依賴
fastcache
,這應有助於安裝。