jax.numpy.linspace#

jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)[原始碼]#

在間隔內傳回均勻間隔的數字。

numpy.linspace() 的 JAX 實作。

參數:
  • start (ArrayLike) – 起始值的純量或陣列。

  • stop (ArrayLike) – 停止值的純量或陣列。

  • num (int) – 要產生的值數量。預設值:50。

  • endpoint (bool) – 如果為 True (預設值),則在結果中包含 stop 值。如果為 False,則排除 stop 值。

  • retstep (bool) – 如果為 True,則傳回 (result, step) tuple,其中 stepresult 中相鄰值之間的間隔。

  • axis (int) – 沿著其產生 linspace 的整數軸。預設為零。

  • device (xc.Device | Sharding | None | None) – 可選的 DeviceSharding,建立的陣列將提交到此裝置或分片。

  • dtype (DTypeLike | None | None)

傳回:

  • values 是從 startstop 的均勻間隔值陣列

  • step 是相鄰值之間的間隔。

傳回類型:

陣列 values,或 tuple (values, step) 如果 retstep 為 True,其中

參見

範例

0 到 10 之間的 5 個值列表

>>> jnp.linspace(0, 10, 5)
Array([ 0. ,  2.5,  5. ,  7.5, 10. ], dtype=float32)

0 到 10 之間的 8 個值列表,排除端點

>>> jnp.linspace(0, 10, 8, endpoint=False)
Array([0.  , 1.25, 2.5 , 3.75, 5.  , 6.25, 7.5 , 8.75], dtype=float32)

值列表以及它們之間的步長

>>> vals, step = jnp.linspace(0, 10, 9, retstep=True)
>>> vals
Array([ 0.  ,  1.25,  2.5 ,  3.75,  5.  ,  6.25,  7.5 ,  8.75, 10.  ],      dtype=float32)
>>> step
Array(1.25, dtype=float32)

多維 linspace

>>> start = jnp.array([0, 5])
>>> stop = jnp.array([5, 10])
>>> jnp.linspace(start, stop, 5)
Array([[ 0.  ,  5.  ],
       [ 1.25,  6.25],
       [ 2.5 ,  7.5 ],
       [ 3.75,  8.75],
       [ 5.  , 10.  ]], dtype=float32)