jax.experimental.sparse 模組#

注意

jax.experimental.sparse 中的方法是實驗性的參考實作,不建議用於效能關鍵的應用程式。

jax.experimental.sparse 模組包含 JAX 中稀疏矩陣運算的實驗性支援。它正在積極開發中,API 可能會變更。主要提供的介面是 BCOO 稀疏陣列類型,以及 sparsify() 轉換。

批次座標 (BCOO) 稀疏矩陣#

目前 JAX 中可用的主要高階稀疏物件是 BCOO,或批次座標稀疏陣列,它提供與 JAX 轉換相容的壓縮儲存格式,特別是 JIT (例如 jax.jit())、批次處理 (例如 jax.vmap()) 和自動微分 (例如 jax.grad())。

以下是從密集陣列建立稀疏陣列的範例

>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
...                [3., 0., 0., 0.],
...                [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)

使用 todense() 方法轉換回密集陣列

>>> M_sp.todense()
Array([[0., 1., 0., 2.],
       [3., 0., 0., 0.],
       [0., 0., 4., 0.]], dtype=float32)

BCOO 格式是標準 COO 格式的修改版本,密集表示可以在 dataindices 屬性中看到

>>> M_sp.data  # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
       [0, 3],
       [1, 0],
       [2, 2]], dtype=int32)

BCOO 物件具有熟悉的陣列式屬性,以及稀疏專用屬性

>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse  # "number of specified elements"
4

BCOO 物件也實作了許多陣列式方法,讓您可以在 jax 程式中直接使用它們。例如,這裡我們計算轉置矩陣向量乘積

>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18.,  3., 20.,  6.], dtype=float32)
>>> M.T @ y  # Compare to dense version
Array([18.,  3., 20.,  6.], dtype=float32)

BCOO 物件設計為與 JAX 轉換相容,包括 jax.jit()jax.vmap()jax.grad() 等。例如

>>> from jax import grad, jit
>>> def f(y):
...   return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)

但是請注意,在正常情況下,jax.numpyjax.lax 函數不知道如何處理稀疏矩陣,因此嘗試計算諸如 jnp.dot(M_sp.T, y) 之類的東西會導致錯誤(但是,請參閱下一節)。

稀疏化轉換#

JAX 稀疏實作的首要目標是提供一種從密集計算無縫切換到稀疏計算的方法,而無需修改密集實作。此稀疏實驗透過 sparsify() 轉換來實現此目的。

考慮這個函數,它從矩陣和向量輸入計算更複雜的結果

>>> def f(M, v):
...   return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

如果我們直接將稀疏矩陣傳遞給它,則會導致錯誤,因為 jnp 函數無法識別稀疏輸入。但是,透過 sparsify(),我們獲得了這個函數的一個版本,它確實接受稀疏矩陣

>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532,  5.158883, 17.09438 ,  7.591674], dtype=float32)

sparsify() 的支援包括大量最常見的基本運算,包括

  • 廣義(批次處理)矩陣乘積和愛因斯坦求和 (dot_general_p)

  • 零保留元素二元運算(例如 add_pmul_p 等)

  • 零保留元素一元運算(例如 abs_pjax.lax.neg_p 等)

  • 求和縮減 (reduce_sum_p)

  • 一般索引運算 (slice_plax.dynamic_slice_plax.gather_p)

  • 串聯和堆疊 (concatenate_p)

  • 轉置和重塑 ((transpose_preshape_psqueeze_pbroadcast_in_dim_p)

  • 一些高階函數 (cond_pwhile_pscan_p)

  • 一些簡單的 1D 卷積 (conv_general_dilated_p)

幾乎任何降低到這些受支援基本運算的 jax.numpy 函數都可以在 sparsify 轉換中使用,以對稀疏陣列進行運算。這組基本運算足以實現相對複雜的稀疏工作流程,如下一節所示。

範例:稀疏邏輯迴歸#

作為更複雜稀疏工作流程的範例,讓我們考慮在 JAX 中實作的簡單邏輯迴歸。請注意,以下實作沒有參考稀疏性

>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
...   return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
...   return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
...   y_hat = y_model(params, X)
...   return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
...   params = jnp.zeros(X.shape[1] + 1)
...   result = optimize.minimize(functools.partial(loss, X=X, y=y),
...                              x0=params, method='BFGS')
...   return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)  
[-0.7298445   0.29893667  1.0248291  -0.44436368  0.8785025  -0.7724008
 -0.62893456  0.2934014   0.82974285  0.16838408 -0.39774987 -0.5071844
  0.2028872   0.5227761  -0.3739224  -0.7104083   2.4212713   0.6310087
 -0.67060554  0.03139788 -0.05359547]

這會傳回密集邏輯迴歸問題的最佳擬合參數。若要在稀疏資料上擬合相同的模型,我們可以套用 sparsify() 轉換

>>> Xsp = sparse.BCOO.fromdense(X)  # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg)  # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)  
[-0.72971725  0.29878938  1.0246326  -0.44430563  0.8784217  -0.77225566
 -0.6288222   0.29335397  0.8293481   0.16820715 -0.39764675 -0.5069753
  0.202579    0.522672   -0.3740134  -0.7102678   2.4209507   0.6310593
 -0.670236    0.03132951 -0.05356663]

稀疏 API 參考#

sparsify(f[, use_tracer])

實驗性稀疏化轉換。

grad(fun[, argnums, has_aux])

jax.grad() 的稀疏感知版本

value_and_grad(fun[, argnums, has_aux])

jax.value_and_grad() 的稀疏感知版本

empty(shape[, dtype, index_dtype, sparse_format])

建立空的稀疏陣列。

eye(N[, M, k, dtype, index_dtype, sparse_format])

建立 2D 稀疏單位矩陣。

todense(arr)

將輸入轉換為密集矩陣。

random_bcoo(key, shape, *[, dtype, ...])

產生隨機 BCOO 矩陣。

JAXSparse(args, *, shape)

高階 JAX 稀疏物件的基底類別。

BCOO 資料結構#

BCOO批次 COO 格式,是 jax.experimental.sparse 中實作的主要稀疏資料結構。其運算與 JAX 的核心轉換相容,包括批次處理 (例如 jax.vmap()) 和自動微分 (例如 jax.grad())。

BCOO(args, *, shape[, indices_sorted, ...])

在 JAX 中實作的實驗性批次 COO 矩陣

bcoo_broadcast_in_dim(mat, *, shape, ...[, ...])

透過複製資料來擴展 BCOO 陣列的大小和秩。

bcoo_concatenate(operands, *, dimension)

jax.lax.concatenate() 的稀疏實作

bcoo_dot_general(lhs, rhs, *, dimension_numbers)

一般收縮運算。

bcoo_dot_general_sampled(A, B, indices, *, ...)

輸出在給定稀疏索引處計算的收縮運算。

bcoo_dynamic_slice(mat, start_indices, ...)

{func}`jax.lax.dynamic_slice` 的稀疏實作。

bcoo_extract(sparr, arr, *[, assume_unique])

根據稀疏陣列的索引,從密集陣列中提取值。

bcoo_fromdense(mat, *[, nse, n_batch, ...])

從密集矩陣建立 BCOO 格式稀疏矩陣。

bcoo_gather(operand, start_indices, ...[, ...])

lax.gather 的 BCOO 版本。

bcoo_multiply_dense(sp_mat, v)

稀疏陣列和密集陣列之間的元素乘法。

bcoo_multiply_sparse(lhs, rhs)

兩個稀疏陣列的元素乘法。

bcoo_update_layout(mat, *[, n_batch, ...])

更新 BCOO 矩陣的儲存佈局(即 n_batch & n_dense)。

bcoo_reduce_sum(mat, *, axes)

在給定軸上對陣列元素求和。

bcoo_reshape(mat, *, new_sizes[, ...])

{func}`jax.lax.reshape` 的稀疏實作。

bcoo_slice(mat, *, start_indices, limit_indices)

{func}`jax.lax.slice` 的稀疏實作。

bcoo_sort_indices(mat)

排序 BCOO 陣列的索引。

bcoo_squeeze(arr, *, dimensions)

{func}`jax.lax.squeeze` 的稀疏實作。

bcoo_sum_duplicates(mat[, nse])

對 BCOO 陣列中的重複索引求和,傳回具有排序索引的陣列。

bcoo_todense(mat)

將批次稀疏矩陣轉換為密集矩陣。

bcoo_transpose(mat, *, permutation)

轉置 BCOO 格式陣列。

BCSR 資料結構#

BCSR批次壓縮稀疏列格式,正在開發中。其運算與 JAX 的核心轉換相容,包括批次處理 (例如 jax.vmap()) 和自動微分 (例如 jax.grad())。

BCSR(args, *, shape[, indices_sorted, ...])

在 JAX 中實作的實驗性批次 CSR 矩陣。

bcsr_dot_general(lhs, rhs, *, dimension_numbers)

一般收縮運算。

bcsr_extract(indices, indptr, mat)

從給定的 BCSR (indices, indptr) 從稠密矩陣中提取數值。

bcsr_fromdense(mat, *[, nse, n_batch, ...])

從稠密矩陣創建 BCSR 格式的稀疏矩陣。

bcsr_todense(mat)

將批次稀疏矩陣轉換為密集矩陣。

其他稀疏資料結構#

其他稀疏資料結構包括 COOCSRCSC。 這些是用於簡單稀疏結構的參考實作,其中實作了一些核心操作。 它們的操作通常與自動微分轉換(例如 jax.grad())相容,但不與批次轉換(例如 jax.vmap())相容。

COO(args, *, shape[, rows_sorted, cols_sorted])

在 JAX 中實作的實驗性 COO 矩陣。

CSC(args, *, shape)

在 JAX 中實作的實驗性 CSC 矩陣;API 可能會變更。

CSR(args, *, shape)

在 JAX 中實作的實驗性 CSR 矩陣。

coo_fromdense(mat, *[, nse, index_dtype])

從稠密矩陣創建 COO 格式的稀疏矩陣。

coo_matmat(mat, B, *[, transpose])

COO 稀疏矩陣和稠密矩陣的乘積。

coo_matvec(mat, v[, transpose])

COO 稀疏矩陣和稠密向量的乘積。

coo_todense(mat)

將 COO 格式的稀疏矩陣轉換為稠密矩陣。

csr_fromdense(mat, *[, nse, index_dtype])

從稠密矩陣創建 CSR 格式的稀疏矩陣。

csr_matmat(mat, B, *[, transpose])

CSR 稀疏矩陣和稠密矩陣的乘積。

csr_matvec(mat, v[, transpose])

CSR 稀疏矩陣和稠密向量的乘積。

csr_todense(mat)

將 CSR 格式的稀疏矩陣轉換為稠密矩陣。

jax.experimental.sparse.linalg#

稀疏線性代數常式。

spsolve(data, indices, indptr, b[, tol, reorder])

使用 QR 分解的稀疏直接求解器。

lobpcg_standard(A, X[, m, tol])

使用 LOBPCG 常式計算前 k 個標準特徵值。