jax.random.multivariate_normal#
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')[原始碼]#
採樣具有給定平均值和共變異數的多元常態隨機值。
這些值根據機率密度函數傳回
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]其中 \(k\) 是維度,\(\mu\) 是平均值 (由
mean
給定),而 \(\Sigma\) 是共變異數矩陣 (由cov
給定)。- 參數:
key (ArrayLike) – 作為隨機金鑰使用的 PRNG 金鑰。
mean (RealArray) – 形狀為
(..., n)
的平均值向量。cov (RealArray) – 形狀為
(..., n, n)
的正定共變異數矩陣。批次形狀...
必須與mean
的批次形狀廣播相容。shape (Shape | None | None) – 選用,指定結果批次形狀的非負整數元組;也就是說,結果形狀的前綴,不包括最後一個軸。必須與
mean.shape[:-1]
和cov.shape[:-2]
廣播相容。預設值 (None) 會透過一起廣播mean
和cov
的批次形狀來產生結果批次形狀。dtype (DTypeLikeFloat | None | None) – 選用,用於傳回值的浮點 dtype (如果 jax_enable_x64 為 true,則預設為 float64,否則為 float32)。
method (str) – 選用,計算
cov
因子的方法。必須是 ‘svd’、‘eigh’ 和 ‘cholesky’ 之一。預設為 ‘cholesky’。對於奇異共變異數矩陣,請使用 ‘svd’ 或 ‘eigh’。
- 傳回:
具有指定 dtype 和形狀的隨機陣列,如果
shape
不是 None,則形狀由shape + mean.shape[-1:]
給定,否則由broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]
給定。- 傳回類型: