jax.random.multivariate_normal#

jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')[原始碼]#

採樣具有給定平均值和共變異數的多元常態隨機值。

這些值根據機率密度函數傳回

\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]

其中 \(k\) 是維度,\(\mu\) 是平均值 (由 mean 給定),而 \(\Sigma\) 是共變異數矩陣 (由 cov 給定)。

參數:
  • key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。

  • mean (RealArray) – 形狀為 (..., n) 的平均值向量。

  • cov (RealArray) – 形狀為 (..., n, n) 的正定共變異數矩陣。批次形狀 ... 必須與 mean 的批次形狀廣播相容。

  • shape (Shape | None | None) – 選用,指定結果批次形狀的非負整數元組;也就是說,結果形狀的前綴,不包括最後一個軸。必須與 mean.shape[:-1]cov.shape[:-2] 廣播相容。預設值 (None) 會透過一起廣播 meancov 的批次形狀來產生結果批次形狀。

  • dtype (DTypeLikeFloat | None | None) – 選用,用於傳回值的浮點 dtype (如果 jax_enable_x64 為 true,則預設為 float64,否則為 float32)。

  • method (str) – 選用,計算 cov 因子的方法。必須是 ‘svd’、‘eigh’ 和 ‘cholesky’ 之一。預設為 ‘cholesky’。對於奇異共變異數矩陣,請使用 ‘svd’ 或 ‘eigh’。

傳回:

具有指定 dtype 和形狀的隨機陣列,如果 shape 不是 None,則形狀由 shape + mean.shape[-1:] 給定,否則由 broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:] 給定。

傳回類型:

Array