jax.numpy.arange#

jax.numpy.arange(start, stop=None, step=None, dtype=None, *, device=None)[原始碼]#

建立均勻間隔值的陣列。

numpy.arange() 的 JAX 實作,以 jax.lax.iota() 實作。

類似於 Python 的 range() 函數,這可以使用一些不同的位置簽名來呼叫

  • jnp.arange(stop):產生從 0 到 stop 的值,步進為 1。

  • jnp.arange(start, stop):產生從 startstop 的值,步進為 1。

  • jnp.arange(start, stop, step):產生從 startstop 的值,步進為 step

與 Python 的 range() 函數一樣,起始值是包含的,而停止值是排除的。

參數:
  • start (ArrayLike | DimSize) – 間隔的開始,包含。

  • stop (ArrayLike | DimSize | None | None) – 間隔的可選結束,排除。如果未指定,則 (start, stop) = (0, start)

  • step (ArrayLike | None | None) – 間隔的可選步進大小。預設值 = 1。

  • dtype (DTypeLike | None | None) – 返回陣列的可選 dtype;如果未指定,將透過 startstopstep 的類型提升來決定。

  • device (xc.Device | Sharding | None | None) – (可選) DeviceSharding,將在該裝置或分片上建立陣列。

返回:

startstop 的均勻間隔值陣列,以 step 分隔。

返回類型:

Array

注意

使用帶有浮點 step 參數的 arange 可能會由於浮點錯誤的累積而導致意外結果,尤其是在使用較低精度資料類型(如 float8_*bfloat16)時。為了避免精度錯誤,請考慮產生整數範圍,並將其縮放到所需的範圍。例如,取代這種方式

jnp.arange(-1, 1, 0.01, dtype='bfloat16')

更準確的方式是產生整數序列,並縮放它們

(jnp.arange(-100, 100) * 0.01).astype('bfloat16')

範例

單一參數版本僅指定 stop

>>> jnp.arange(4)
Array([0, 1, 2, 3], dtype=int32)

傳遞浮點 stop 值會產生浮點結果

>>> jnp.arange(4.0)
Array([0., 1., 2., 3.], dtype=float32)

雙參數版本指定 startstop,其中 step=1

>>> jnp.arange(1, 6)
Array([1, 2, 3, 4, 5], dtype=int32)

三參數版本指定 startstopstep

>>> jnp.arange(0, 2, 0.5)
Array([0. , 0.5, 1. , 1.5], dtype=float32)

另請參閱