jax.numpy.arange#
- jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)[原始碼]#
建立均勻間隔值的陣列。
numpy.arange()
的 JAX 實作,以jax.lax.iota()
實作。類似於 Python 的
range()
函數,這可以使用一些不同的位置簽名來呼叫jnp.arange(stop)
:產生從 0 到stop
的值,步進為 1。jnp.arange(start, stop)
:產生從start
到stop
的值,步進為 1。jnp.arange(start, stop, step)
:產生從start
到stop
的值,步進為step
。
與 Python 的
range()
函數一樣,起始值是包含的,而停止值是排除的。- 參數:
start (ArrayLike | DimSize) – 間隔的開始,包含。
stop (ArrayLike | DimSize | None | None) – 間隔的可選結束,排除。如果未指定,則
(start, stop) = (0, start)
step (ArrayLike | None | None) – 間隔的可選步進大小。預設值 = 1。
dtype (DTypeLike | None | None) – 返回陣列的可選 dtype;如果未指定,將透過 start、stop 和 step 的類型提升來決定。
device (xc.Device | Sharding | None | None) – (可選)
Device
或Sharding
,將在該裝置或分片上建立陣列。
- 返回:
從
start
到stop
的均勻間隔值陣列,以step
分隔。- 返回類型:
注意
使用帶有浮點
step
參數的arange
可能會由於浮點錯誤的累積而導致意外結果,尤其是在使用較低精度資料類型(如float8_*
和bfloat16
)時。為了避免精度錯誤,請考慮產生整數範圍,並將其縮放到所需的範圍。例如,取代這種方式jnp.arange(-1, 1, 0.01, dtype='bfloat16')
更準確的方式是產生整數序列,並縮放它們
(jnp.arange(-100, 100) * 0.01).astype('bfloat16')
範例
單一參數版本僅指定
stop
值>>> jnp.arange(4) Array([0, 1, 2, 3], dtype=int32)
傳遞浮點
stop
值會產生浮點結果>>> jnp.arange(4.0) Array([0., 1., 2., 3.], dtype=float32)
雙參數版本指定
start
和stop
,其中step=1
>>> jnp.arange(1, 6) Array([1, 2, 3, 4, 5], dtype=int32)
三參數版本指定
start
、stop
和step
>>> jnp.arange(0, 2, 0.5) Array([0. , 0.5, 1. , 1.5], dtype=float32)
另請參閱
jax.numpy.linspace()
:產生固定數量的均勻間隔值。jax.lax.iota()
:直接在 XLA 中產生整數序列。