jax.numpy.ufunc#
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#
通用函數,對陣列逐元素執行運算。
NumPy
numpy.ufunc
的 JAX 實作。這是 NumPy ufunc API 的 JAX 後端實作類別。大多數使用者永遠不需要實例化
ufunc
,而是使用jax.numpy
中預先定義的 ufunc。如需建構您自己的 ufunc,請參閱
jax.numpy.frompyfunc()
。範例
通用函數是逐元素應用於廣播陣列的函數,但它們也帶有一些額外的屬性和方法。
例如,考慮函數
jax.numpy.add
。此物件作為一個函數,將加法逐元素應用於廣播陣列>>> x = jnp.array([1, 2, 3, 4, 5]) >>> jnp.add(x, 1) Array([2, 3, 4, 5, 6], dtype=int32)
每個
ufunc
物件都包含許多描述其行為的屬性>>> jnp.add.nin # number of inputs 2 >>> jnp.add.nout # number of outputs 1 >>> jnp.add.identity # identity value, or None if no identity exists 0
二元 ufunc,例如
jax.numpy.add
,包含許多將函數以不同方式應用於陣列的方法。outer()
方法將函數應用於輸入陣列值的成對外積>>> jnp.add.outer(x, x) Array([[ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32)
ufunc.reduce()
方法對陣列執行歸約。例如,jnp.add.reduce()
等價於jnp.sum
>>> jnp.add.reduce(x) Array(15, dtype=int32)
ufunc.accumulate()
方法對陣列執行累積歸約。例如,jnp.add.accumulate()
等價於jax.numpy.cumulative_sum()
>>> jnp.add.accumulate(x) Array([ 1, 3, 6, 10, 15], dtype=int32)
ufunc.at()
方法在陣列中的特定索引處應用函數;對於jnp.add
,計算類似於jax.lax.scatter_add()
>>> jnp.add.at(x, 0, 100, inplace=False) Array([101, 2, 3, 4, 5], dtype=int32)
而
ufunc.reduceat()
方法在陣列的指定索引之間執行多個reduce
運算;對於jnp.add
,運算類似於jax.ops.segment_sum()
>>> jnp.add.reduceat(x, jnp.array([0, 2])) Array([ 3, 12], dtype=int32)
在此範例中,第一個元素是
x[0:2].sum()
,第二個元素是x[2:].sum()
。- 參數:
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[source]#
- 參數:
func (Callable[..., Any])
nin (int)
nout (int)
name (str | None | None)
nargs (int | None | None)
identity (Any | None)
call (Callable[..., Any] | None | None)
reduce (Callable[..., Any] | None | None)
accumulate (Callable[..., Any] | None | None)
at (Callable[..., Any] | None | None)
reduceat (Callable[..., Any] | None | None)
方法
__init__
(func, /, nin, nout, *[, name, ...])accumulate
(a[, axis, dtype, out])從二元 ufunc 衍生的累積運算。
at
(a, indices[, b, inplace])透過指定的單元或二元 ufunc 更新陣列的元素。
outer
(A, B, /)將函數應用於
A
和B
中的所有值對。reduce
(a[, axis, dtype, out, keepdims, ...])從二元函數衍生的歸約運算。
reduceat
(a, indices[, axis, dtype, out])透過二元 ufunc 歸約指定索引之間的陣列。
屬性
identity
nargs
nin
nout