jax.numpy.diagonal#
- jax.numpy.diagonal(a, offset=0, axis1=0, axis2=1)[原始碼]#
傳回陣列的指定對角線。
JAX 版本的
numpy.diagonal()
實作。JAX 版本總是傳回輸入的副本,儘管如果在 JIT 編譯中使用,編譯器可能會避免複製。
- 參數:
- 傳回類型:
範例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6], ... [7, 8, 9]]) >>> jnp.diagonal(x) Array([1, 5, 9], dtype=int32) >>> jnp.diagonal(x, offset=1) Array([2, 6], dtype=int32) >>> jnp.diagonal(x, offset=-1) Array([4, 8], dtype=int32)