jax.numpy.linspace#
- jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, *, device=None)[原始碼]#
在間隔內傳回均勻間隔的數字。
numpy.linspace()
的 JAX 實作。- 參數:
start (ArrayLike) – 起始值的純量或陣列。
stop (ArrayLike) – 停止值的純量或陣列。
num (int) – 要產生的值數量。預設值:50。
endpoint (bool) – 如果為 True (預設值),則在結果中包含
stop
值。如果為 False,則排除stop
值。retstep (bool) – 如果為 True,則傳回
(result, step)
tuple,其中step
是result
中相鄰值之間的間隔。axis (int) – 沿著其產生 linspace 的整數軸。預設為零。
device (xc.Device | Sharding | None | None) – 可選的
Device
或Sharding
,建立的陣列將提交到此裝置或分片。dtype (DTypeLike | None | None)
- 傳回:
values
是從start
到stop
的均勻間隔值陣列step
是相鄰值之間的間隔。
- 傳回類型:
陣列
values
,或 tuple(values, step)
如果retstep
為 True,其中
參見
jax.numpy.arange()
:給定起點和步長,產生N
個均勻間隔的值jax.numpy.logspace()
:產生對數間隔的值。jax.numpy.geomspace()
:產生幾何間隔的值。
範例
0 到 10 之間的 5 個值列表
>>> jnp.linspace(0, 10, 5) Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32)
0 到 10 之間的 8 個值列表,排除端點
>>> jnp.linspace(0, 10, 8, endpoint=False) Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32)
值列表以及它們之間的步長
>>> vals, step = jnp.linspace(0, 10, 9, retstep=True) >>> vals Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) >>> step Array(1.25, dtype=float32)
多維 linspace
>>> start = jnp.array([0, 5]) >>> stop = jnp.array([5, 10]) >>> jnp.linspace(start, stop, 5) Array([[ 0. , 5. ], [ 1.25, 6.25], [ 2.5 , 7.5 ], [ 3.75, 8.75], [ 5. , 10. ]], dtype=float32)