jax.numpy.ndarray.at#
- abstract property ndarray.at[原始碼]#
用於索引更新功能的輔助屬性。
at
屬性提供功能純粹的等效原地陣列修改。特別是
替代語法
等效原地表達式
x = x.at[idx].set(y)
x[idx] = y
x = x.at[idx].add(y)
x[idx] += y
x = x.at[idx].subtract(y)
x[idx] -= y
x = x.at[idx].multiply(y)
x[idx] *= y
x = x.at[idx].divide(y)
x[idx] /= y
x = x.at[idx].power(y)
x[idx] **= y
x = x.at[idx].min(y)
x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)
x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)
ufunc.at(x, idx)
x = x.at[idx].get()
x = x[idx]
沒有任何
x.at
表達式會修改原始x
;相反地,它們會傳回x
的修改副本。然而,在jit()
編譯函數內部,諸如x = x.at[idx].set(y)
之類的表達式保證會原地套用。與 NumPy 原地運算(例如
x[idx] += y
)不同,如果多個索引指向相同位置,則所有更新都會套用(NumPy 只會套用最後一個更新,而不是套用所有更新)。衝突更新的套用順序是實作定義的,並且可能是不確定的(例如,由於某些硬體平台上的並行性)。預設情況下,JAX 假設所有索引都在範圍內。可以透過
mode
參數(見下文)指定替代的超出範圍索引語義。- 參數:
mode (str) –
指定超出範圍的索引模式。選項為
"promise_in_bounds"
:(預設)使用者保證索引在範圍內。不會執行額外檢查。實際上,這表示get()
中的超出範圍索引將被裁剪,而set()
、add()
等中的超出範圍索引將被捨棄。"clip"
:將超出範圍的索引箝制到有效範圍內。"drop"
:忽略超出範圍的索引。"fill"
:"drop"
的別名。對於 get(),選用的fill_value
引數指定將傳回的值。有關更多詳細資訊,請參閱
jax.lax.GatherScatterMode
。
indices_are_sorted (bool) – 如果為 True,則實作會假設傳遞給
at[]
的索引依升序排序,這可以在某些後端上實現更有效率的執行。unique_indices (bool) – 如果為 True,則實作會假設傳遞給
at[]
的索引是唯一的,這可能會在某些後端上實現更有效率的執行。fill_value (Any) – 僅適用於
get()
方法:當 mode 為'fill'
時,針對超出範圍的切片傳回的填充值。否則會忽略。預設值為非精確類型的NaN
、有號類型的最大負值、無號類型的最大正值,以及布林值的True
。
範例
>>> x = jnp.arange(5.0) >>> x Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[2].add(10) Array([ 0., 1., 12., 3., 4.], dtype=float32) >>> x.at[10].add(10) # out-of-bounds indices are ignored Array([0., 1., 2., 3., 4.], dtype=float32) >>> x.at[20].add(10, mode='clip') Array([ 0., 1., 2., 3., 14.], dtype=float32) >>> x.at[2].get() Array(2., dtype=float32) >>> x.at[20].get() # out-of-bounds indices clipped Array(4., dtype=float32) >>> x.at[20].get(mode='fill') # out-of-bounds indices filled with NaN Array(nan, dtype=float32) >>> x.at[20].get(mode='fill', fill_value=-1) # custom fill value Array(-1., dtype=float32)