jax.numpy.block#
- jax.numpy.block(arrays)[原始碼]#
從區塊列表建立陣列。
numpy.block()
的 JAX 實作。- 參數:
arrays (ArrayLike | list[ArrayLike]) – 一個陣列,或陣列的巢狀列表,它們將串連在一起以形成最終陣列。
- 返回:
從輸入建構的單一陣列。
- 返回類型:
範例
考慮這些區塊
>>> zeros = jnp.zeros((2, 2)) >>> ones = jnp.ones((2, 2)) >>> twos = jnp.full((2, 2), 2) >>> threes = jnp.full((2, 2), 3)
將單一陣列傳遞給
block()
會返回該陣列>>> jnp.block(zeros) Array([[0., 0.], [0., 0.]], dtype=float32)
傳遞簡單的陣列列表會沿最後一個軸串連它們
>>> jnp.block([zeros, ones]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.]], dtype=float32)
傳遞雙層巢狀的陣列列表會沿最後一個軸串連內部列表,並沿倒數第二個軸串連外部列表
>>> jnp.block([[zeros, ones], ... [twos, threes]]) Array([[0., 0., 1., 1.], [0., 0., 1., 1.], [2., 2., 3., 3.], [2., 2., 3., 3.]], dtype=float32)
請注意,區塊不需要在所有維度上對齊,儘管沿串連軸的大小必須匹配。例如,這是有效的,因為在內部的水平串連之後,產生的區塊具有適用於外部垂直串連的有效形狀。
>>> a = jnp.zeros((2, 1)) >>> b = jnp.ones((2, 3)) >>> c = jnp.full((1, 2), 2) >>> d = jnp.full((1, 2), 3) >>> jnp.block([[a, b], [c, d]]) Array([[0., 1., 1., 1.], [0., 1., 1., 1.], [2., 2., 3., 3.]], dtype=float32)
另請注意,此邏輯可推廣到 3 個或更多維度的區塊。以下是一個 3 維區塊式陣列
>>> x = jnp.arange(6).reshape((1, 2, 3)) >>> blocks = [[[x for i in range(3)] for j in range(4)] for k in range(5)] >>> jnp.block(blocks).shape (5, 8, 9)