jax.Array.at#

abstract property Array.at[原始碼]#

用於索引更新功能的輔助屬性。

at 屬性提供功能上純粹的等效於就地陣列修改的方式。

特別是

替代語法

等效的就地表達式

x = x.at[idx].set(y)

x[idx] = y

x = x.at[idx].add(y)

x[idx] += y

x = x.at[idx].subtract(y)

x[idx] -= y

x = x.at[idx].multiply(y)

x[idx] *= y

x = x.at[idx].divide(y)

x[idx] /= y

x = x.at[idx].power(y)

x[idx] **= y

x = x.at[idx].min(y)

x[idx] = minimum(x[idx], y)

x = x.at[idx].max(y)

x[idx] = maximum(x[idx], y)

x = x.at[idx].apply(ufunc)

ufunc.at(x, idx)

x = x.at[idx].get()

x = x[idx]

沒有任何 x.at 表達式會修改原始的 x;相反地,它們會傳回 x 的修改後副本。然而,在 jit() 編譯函數內部,像是 x = x.at[idx].set(y) 這樣的表達式保證會就地應用。

與 NumPy 就地操作(例如 x[idx] += y)不同,如果多個索引指向相同的位置,所有更新都會被應用(NumPy 只會應用最後一個更新,而不是應用所有更新)。衝突更新的應用順序是實作定義的,並且可能是不確定的(例如,由於某些硬體平台上的並發)。

預設情況下,JAX 假設所有索引都在範圍內。可以使用 mode 參數(見下文)指定替代的越界索引語意。

參數:
  • mode (str) –

    指定越界索引模式。選項為

    • "promise_in_bounds":(預設)使用者承諾索引在範圍內。不會執行額外的檢查。實際上,這表示 get() 中的越界索引將被剪裁,而 set()add() 等中的越界索引將被捨棄。

    • "clip":將越界索引箝制到有效範圍內。

    • "drop":忽略越界索引。

    • "fill""drop" 的別名。對於 get(),可選的 fill_value 參數指定將傳回的值。

      有關更多詳細資訊,請參閱 jax.lax.GatherScatterMode

  • indices_are_sorted (bool) – 如果為 True,實作將假設傳遞給 at[] 的索引以升序排序,這可以在某些後端上實現更有效率的執行。

  • unique_indices (bool) – 如果為 True,實作將假設傳遞給 at[] 的索引是唯一的,這可以在某些後端上實現更有效率的執行。

  • fill_value (Any) – 僅適用於 get() 方法:當 mode'fill' 時,要為越界切片傳回的填充值。否則將被忽略。預設值為非精確型別的 NaN、有號型別的最大負值、無號型別的最大正值,以及布林值的 True

範例

>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)
>>> x.at[10].add(10)  # out-of-bounds indices are ignored
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)