jax.numpy.fill_diagonal#

jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)[原始碼]#

傳回對角線被覆寫的陣列副本。

JAX 實作的 numpy.fill_diagonal()

numpy.fill_diagonal() 的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本傳回輸入的修改副本,並新增了 inplace 參數,使用者必須將其設定為 False`,以提醒使用者此 API 差異。

參數:
  • a (ArrayLike) – 輸入陣列。必須具有 a.ndim >= 2。如果 a.ndim >= 3,則所有維度的大小必須相同。

  • val (ArrayLike) – 用於填滿對角線的純量或陣列。如果是陣列,它將被展平並重複以填滿對角線條目。

  • inplace (bool) – 必須設定為 False,以指示輸入不會就地修改,而是傳回修改後的副本。

  • wrap (bool)

傳回:

將對角線設定為 vala 副本。

傳回型別:

Array

範例

>>> x = jnp.zeros((3, 3), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False)
Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

numpy.fill_diagonal() 不同,輸入 x 不會被修改。

如果對角線值條目過多,將會被截斷

>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False)
Array([[100,   0,   0],
       [  0, 101,   0],
       [  0,   0, 102]], dtype=int32)

如果對角線條目過少,將會被重複

>>> x = jnp.zeros((4, 4), dtype=int)
>>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False)
Array([[3, 0, 0, 0],
       [0, 4, 0, 0],
       [0, 0, 3, 0],
       [0, 0, 0, 4]], dtype=int32)

對於非方形陣列,將填滿前導方形切片的對角線

>>> x = jnp.zeros((3, 5), dtype=int)
>>> jnp.fill_diagonal(x, 1, inplace=False)
Array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0]], dtype=int32)

對於方形 N 維陣列,將填滿 N 維對角線

>>> y = jnp.zeros((2, 2, 2))
>>> jnp.fill_diagonal(y, 1, inplace=False)
Array([[[1., 0.],
        [0., 0.]],

       [[0., 0.],
        [0., 1.]]], dtype=float32)