jax.nn.initializers.ones#

jax.nn.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)[原始碼]#

傳回充滿 1 的常數陣列的初始化器。

key 引數會被忽略。

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
參數:
  • key (Array)

  • shape (core.Shape)

  • dtype (DTypeLikeInexact)

傳回類型:

Array