Pallas 更新日誌#
這是 jax.experimental.pallas
特有的變更列表。如需完整的 JAX 變更日誌,請參閱這裡。
隨 jax 0.5.0 版本發布#
新功能
在 TPU 上為
jax.experimental.pallas.debug_print()
新增向量支援。
隨 jax 0.4.37 版本發布#
新功能
為 Triton 後端的
dot
lowering 新增DotAlgorithmPreset
精度引數的支援。
隨 jax 0.4.36 版本發布 (2024 年 12 月 6 日)#
隨 jax 0.4.35 版本發布 (2024 年 10 月 22 日)#
移除
移除先前已棄用的別名
jax.experimental.pallas.tpu.CostEstimate
和jax.experimental.tpu.run_scoped()
。兩者現在都可在jax.experimental.pallas
中使用。
新功能
新增成本估算工具
pl.estimate_cost()
,用於從 JAX 參考函式自動建構核心成本估算。
隨 jax 0.4.34 版本發布 (2024 年 10 月 4 日)#
變更
jax.experimental.pallas.debug_print()
不再要求所有引數都必須是純量。引數的限制取決於後端:非純量引數目前僅在使用 Triton 時在 GPU 上受支援。jax.experimental.pallas.BlockSpec
不再支援先前已棄用的引數順序,其中index_map
在block_shape
之前。
棄用
jax.experimental.pallas.gpu
子模組已棄用,以避免與jax.experimental.pallas.mosaic_gpu
混淆。若要使用 Triton 後端,請匯入jax.experimental.pallas.triton
。
新功能
jax.experimental.pallas.pallas_call()
現在接受scratch_shapes
,這是一個 PyTree,用於指定核心所需的後端特定暫時物件,例如緩衝區、同步基本元件等。checkify.check()
現在可用於在呼叫 pallas_call 時插入執行階段斷言,並搭配pltpu.enable_runtime_assert(True)
內容管理器。
隨 jax 0.4.33 版本發布 (2024 年 9 月 16 日)#
隨 jax 0.4.32 版本發布 (2024 年 9 月 11 日)#
變更
核心函式不允許關閉常數。相反地,所有需要的陣列都必須作為輸入傳遞,並具有適當的區塊規格 (#22746)。
新功能
改進索引映射函式簽名中錯誤的錯誤訊息,以包含索引映射的名稱和來源位置。
隨 jax 0.4.31 版本發布 (2024 年 7 月 29 日)#
變更
jax.experimental.pallas.BlockSpec
現在預期block_shape
在index_map
之前 傳遞。舊的引數順序已棄用,並將在未來版本中移除。jax.experimental.pallas.GridSpec
不再具有in_specs_tree
和out_specs_tree
欄位,且in_specs
和out_specs
樹狀結構現在將值儲存為 BlockSpec 的 pytree。先前,in_specs
和out_specs
已扁平化 (#22552)。jax.experimental.pallas.GridSpec
的compute_index
方法已被移除,因為它是私有的。同樣地,get_grid_mapping
和unzip_dynamic_bounds
已從BlockSpec
中移除 (#22593)。修正了解譯模式,使其可與涉及填充的 BlockSpec 搭配使用 (#22275)。解譯模式中的填充將使用 NaN,以協助除錯超出範圍的錯誤,但此行為在自訂核心模式中不存在,且不應依賴。
先前可以匯入許多旨在設為私有的 API,例如
jax.experimental.pallas.pallas
。現在已不可能。
新功能
新增 BlockSpec 的文件:網格和 BlockSpecs。
改進
jax.experimental.pallas.pallas_call()
API 的錯誤訊息。為 TPU 新增
lax.shift_right_arithmetic
(#22279) 和lax.erf_inv
(#22310) 的 lowering 規則。為 Pallas TPU 自訂核心新增形狀多型性的初始支援
(#22084).為 checkify 新增 TPU 支援。( #22480)
當區塊大小與 TPU 需求不符時,新增更清晰的錯誤訊息。先前,錯誤來自 Mosaic 後端,且沒有實用的 Python 堆疊追蹤。
新增具有 1D 區塊的 TPU lowering 支援,並放寬對至少 2 個維度的區塊大小的要求:最後 2 個維度必須分別可被 8 和 128 整除,除非它們跨越整個對應的陣列維度。先前,僅當最後兩個維度中的區塊維度分別小於 8 和 128 時,才允許跨越整個陣列的區塊維度。
隨 JAX 0.4.30 版本發布 (2024 年 6 月 18 日)#
新功能
在解譯模式中為
jax.experimental.pallas.pallas_call()
新增 checkify 支援 (#21862)。改進 TPU 核心的 PRNG 金鑰支援 (#21773)。