jax.random.orthogonal#
- jax.random.orthogonal(key, n, shape=(), dtype=<class 'float'>)[source]#
從正交群 O(n) 中均勻取樣。
如果 dtype 是複數,則從酉群 U(n) 中均勻取樣。
- 參數:
key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。
n (int) – 指示結果維度的整數。
shape (Shape) – 選項,結果的批次維度。預設值為 ()。
dtype (DTypeLikeFloat) – 選項,返回值的浮點數 dtype(如果 jax_enable_x64 為 true,則預設值為 float64,否則為 float32)。
- 返回:
形狀為 (*shape, n, n) 和指定 dtype 的隨機陣列。
- 返回型別:
參考文獻