jax.random.orthogonal#

jax.random.orthogonal(key, n, shape=(), dtype=<class 'float'>)[source]#

從正交群 O(n) 中均勻取樣。

如果 dtype 是複數,則從酉群 U(n) 中均勻取樣。

參數:
  • key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。

  • n (int) – 指示結果維度的整數。

  • shape (Shape) – 選項,結果的批次維度。預設值為 ()。

  • dtype (DTypeLikeFloat) – 選項,返回值的浮點數 dtype(如果 jax_enable_x64 為 true,則預設值為 float64,否則為 float32)。

返回:

形狀為 (*shape, n, n) 和指定 dtype 的隨機陣列。

返回型別:

Array

參考文獻