Pallas 更新日誌#

這是 jax.experimental.pallas 特有的變更列表。如需完整的 JAX 變更日誌,請參閱這裡

隨 jax 0.5.0 版本發布#

隨 jax 0.4.37 版本發布#

  • 新功能

    • 為 Triton 後端的 dot lowering 新增 DotAlgorithmPreset 精度引數的支援。

隨 jax 0.4.36 版本發布 (2024 年 12 月 6 日)#

隨 jax 0.4.35 版本發布 (2024 年 10 月 22 日)#

  • 移除

    • 移除先前已棄用的別名 jax.experimental.pallas.tpu.CostEstimatejax.experimental.tpu.run_scoped()。兩者現在都可在 jax.experimental.pallas 中使用。

  • 新功能

    • 新增成本估算工具 pl.estimate_cost(),用於從 JAX 參考函式自動建構核心成本估算。

隨 jax 0.4.34 版本發布 (2024 年 10 月 4 日)#

隨 jax 0.4.33 版本發布 (2024 年 9 月 16 日)#

隨 jax 0.4.32 版本發布 (2024 年 9 月 11 日)#

  • 變更

    • 核心函式不允許關閉常數。相反地,所有需要的陣列都必須作為輸入傳遞,並具有適當的區塊規格 (#22746)。

  • 新功能

    • 改進索引映射函式簽名中錯誤的錯誤訊息,以包含索引映射的名稱和來源位置。

隨 jax 0.4.31 版本發布 (2024 年 7 月 29 日)#

  • 變更

    • jax.experimental.pallas.BlockSpec 現在預期 block_shapeindex_map 之前 傳遞。舊的引數順序已棄用,並將在未來版本中移除。

    • jax.experimental.pallas.GridSpec 不再具有 in_specs_treeout_specs_tree 欄位,且 in_specsout_specs 樹狀結構現在將值儲存為 BlockSpec 的 pytree。先前,in_specsout_specs 已扁平化 (#22552)。

    • jax.experimental.pallas.GridSpeccompute_index 方法已被移除,因為它是私有的。同樣地,get_grid_mappingunzip_dynamic_bounds 已從 BlockSpec 中移除 (#22593)。

    • 修正了解譯模式,使其可與涉及填充的 BlockSpec 搭配使用 (#22275)。解譯模式中的填充將使用 NaN,以協助除錯超出範圍的錯誤,但此行為在自訂核心模式中不存在,且不應依賴。

    • 先前可以匯入許多旨在設為私有的 API,例如 jax.experimental.pallas.pallas。現在已不可能。

  • 新功能

    • 新增 BlockSpec 的文件:網格和 BlockSpecs

    • 改進 jax.experimental.pallas.pallas_call() API 的錯誤訊息。

    • 為 TPU 新增 lax.shift_right_arithmetic (#22279) 和 lax.erf_inv (#22310) 的 lowering 規則。

    • 為 Pallas TPU 自訂核心新增形狀多型性的初始支援
      (#22084).

    • 為 checkify 新增 TPU 支援。( #22480)

    • 當區塊大小與 TPU 需求不符時,新增更清晰的錯誤訊息。先前,錯誤來自 Mosaic 後端,且沒有實用的 Python 堆疊追蹤。

    • 新增具有 1D 區塊的 TPU lowering 支援,並放寬對至少 2 個維度的區塊大小的要求:最後 2 個維度必須分別可被 8 和 128 整除,除非它們跨越整個對應的陣列維度。先前,僅當最後兩個維度中的區塊維度分別小於 8 和 128 時,才允許跨越整個陣列的區塊維度。

隨 JAX 0.4.30 版本發布 (2024 年 6 月 18 日)#