jax.numpy.block#

jax.numpy.block(arrays)[原始碼]#

從區塊列表建立陣列。

numpy.block() 的 JAX 實作。

參數:

arrays (ArrayLike | list[ArrayLike]) – 一個陣列,或陣列的巢狀列表,它們將串連在一起以形成最終陣列。

返回:

從輸入建構的單一陣列。

返回類型:

Array

參見

範例

考慮這些區塊

>>> zeros = jnp.zeros((2, 2))
>>> ones = jnp.ones((2, 2))
>>> twos = jnp.full((2, 2), 2)
>>> threes = jnp.full((2, 2), 3)

將單一陣列傳遞給 block() 會返回該陣列

>>> jnp.block(zeros)
Array([[0., 0.],
       [0., 0.]], dtype=float32)

傳遞簡單的陣列列表會沿最後一個軸串連它們

>>> jnp.block([zeros, ones])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.]], dtype=float32)

傳遞雙層巢狀的陣列列表會沿最後一個軸串連內部列表,並沿倒數第二個軸串連外部列表

>>> jnp.block([[zeros, ones],
...            [twos, threes]])
Array([[0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [2., 2., 3., 3.],
       [2., 2., 3., 3.]], dtype=float32)

請注意,區塊不需要在所有維度上對齊,儘管沿串連軸的大小必須匹配。例如,這是有效的,因為在內部的水平串連之後,產生的區塊具有適用於外部垂直串連的有效形狀。

>>> a = jnp.zeros((2, 1))
>>> b = jnp.ones((2, 3))
>>> c = jnp.full((1, 2), 2)
>>> d = jnp.full((1, 2), 3)
>>> jnp.block([[a, b], [c, d]])
Array([[0., 1., 1., 1.],
       [0., 1., 1., 1.],
       [2., 2., 3., 3.]], dtype=float32)

另請注意,此邏輯可推廣到 3 個或更多維度的區塊。以下是一個 3 維區塊式陣列

>>> x = jnp.arange(6).reshape((1, 2, 3))
>>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)]
>>> jnp.block(blocks).shape
(5, 8, 9)