jax.scipy.linalg.block_diag#

jax.scipy.linalg.block_diag(*arrs)[原始碼]#

從輸入陣列建立區塊對角矩陣。

scipy.linalg.block_diag() 的 JAX 實作。

參數:

*arrs (ArrayLike) – 最多二維的陣列

傳回:

透過將輸入陣列沿對角線放置而建構的 2D 區塊對角陣列。

傳回類型:

Array

範例

>>> A = jnp.ones((1, 1))
>>> B = jnp.ones((2, 2))
>>> C = jnp.ones((3, 3))
>>> jax.scipy.linalg.block_diag(A, B, C)
Array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1., 1.]], dtype=float32)