jax.nn.initializers.orthogonal#

jax.nn.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)[source]#

建構一個初始化器,其會傳回均勻分佈的正交矩陣。

如果形狀不是正方形,則矩陣將具有正交的行或列,取決於哪一側較小。

參數:
  • scale (RealNumeric) – 均勻分佈的上限。

  • column_axis (int) – 包含應為正交的列的軸。

  • dtype (DTypeLikeInexact) – 權重的預設 dtype。

傳回:

正交初始化器。

傳回型別:

Initializer

範例

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],            dtype=float32)