jax.Array.at#
- abstract property Array.at[原始碼]#
用於索引更新功能的輔助屬性。
at
屬性提供功能上純粹的等效於就地陣列修改的方式。特別是
替代語法
等效的就地表達式
x = x.at[idx].set(y)
x[idx] = y
x = x.at[idx].add(y)
x[idx] += y
x = x.at[idx].subtract(y)
x[idx] -= y
x = x.at[idx].multiply(y)
x[idx] *= y
x = x.at[idx].divide(y)
x[idx] /= y
x = x.at[idx].power(y)
x[idx] **= y
x = x.at[idx].min(y)
x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)
x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)
ufunc.at(x, idx)
x = x.at[idx].get()
x = x[idx]
沒有任何
x.at
表達式會修改原始的x
;相反地,它們會傳回x
的修改後副本。然而,在jit()
編譯函數內部,像是x = x.at[idx].set(y)
這樣的表達式保證會就地應用。與 NumPy 就地操作(例如
x[idx] += y
)不同,如果多個索引指向相同的位置,所有更新都會被應用(NumPy 只會應用最後一個更新,而不是應用所有更新)。衝突更新的應用順序是實作定義的,並且可能是不確定的(例如,由於某些硬體平台上的並發)。預設情況下,JAX 假設所有索引都在範圍內。可以使用
mode
參數(見下文)指定替代的越界索引語意。- 參數:
mode (str) –
指定越界索引模式。選項為
"promise_in_bounds"
:(預設)使用者承諾索引在範圍內。不會執行額外的檢查。實際上,這表示get()
中的越界索引將被剪裁,而set()
、add()
等中的越界索引將被捨棄。"clip"
:將越界索引箝制到有效範圍內。"drop"
:忽略越界索引。"fill"
:"drop"
的別名。對於 get(),可選的fill_value
參數指定將傳回的值。有關更多詳細資訊,請參閱
jax.lax.GatherScatterMode
。
indices_are_sorted (bool) – 如果為 True,實作將假設傳遞給
at[]
的索引以升序排序,這可以在某些後端上實現更有效率的執行。unique_indices (bool) – 如果為 True,實作將假設傳遞給
at[]
的索引是唯一的,這可以在某些後端上實現更有效率的執行。fill_value (Any) – 僅適用於
get()
方法:當 mode 為'fill'
時,要為越界切片傳回的填充值。否則將被忽略。預設值為非精確型別的NaN
、有號型別的最大負值、無號型別的最大正值,以及布林值的True
。
範例
>>> x = jnp.arange(5.0) >>> x Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[2].add(10) Array([ 0., 1., 12., 3., 4.], dtype=float32) >>> x.at[10].add(10) # out-of-bounds indices are ignored Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[20].add(10, mode='clip') Array([ 0., 1., 2., 3., 14.], dtype=float32) >>> x.at[2].get() Array(2., dtype=float32) >>> x.at[20].get() # out-of-bounds indices clipped Array(4., dtype=float32) >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN Array(nan, dtype=float32) >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value Array(-1., dtype=float32)