jax.numpy.fill_diagonal#
- jax.numpy.fill_diagonal(a, val, wrap=False, *, inplace=True)[原始碼]#
傳回對角線被覆寫的陣列副本。
JAX 實作的
numpy.fill_diagonal()
。numpy.fill_diagonal()
的語意是就地修改陣列,這對於 JAX 的不可變陣列來說是不可能的。JAX 版本傳回輸入的修改副本,並新增了inplace
參數,使用者必須將其設定為 False`,以提醒使用者此 API 差異。- 參數:
- 傳回:
將對角線設定為
val
的a
副本。- 傳回型別:
範例
>>> x = jnp.zeros((3, 3), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([1, 2, 3]), inplace=False) Array([[1, 0, 0], [0, 2, 0], [0, 0, 3]], dtype=int32)
與
numpy.fill_diagonal()
不同,輸入x
不會被修改。如果對角線值條目過多,將會被截斷
>>> jnp.fill_diagonal(x, jnp.arange(100, 200), inplace=False) Array([[100, 0, 0], [ 0, 101, 0], [ 0, 0, 102]], dtype=int32)
如果對角線條目過少,將會被重複
>>> x = jnp.zeros((4, 4), dtype=int) >>> jnp.fill_diagonal(x, jnp.array([3, 4]), inplace=False) Array([[3, 0, 0, 0], [0, 4, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], dtype=int32)
對於非方形陣列,將填滿前導方形切片的對角線
>>> x = jnp.zeros((3, 5), dtype=int) >>> jnp.fill_diagonal(x, 1, inplace=False) Array([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]], dtype=int32)
對於方形 N 維陣列,將填滿 N 維對角線
>>> y = jnp.zeros((2, 2, 2)) >>> jnp.fill_diagonal(y, 1, inplace=False) Array([[[1., 0.], [0., 0.]], [[0., 0.], [0., 1.]]], dtype=float32)