jax.numpy.ufunc#

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#

通用函數,對陣列逐元素執行運算。

NumPy numpy.ufunc 的 JAX 實作。

這是 NumPy ufunc API 的 JAX 後端實作類別。大多數使用者永遠不需要實例化 ufunc,而是使用 jax.numpy 中預先定義的 ufunc。

如需建構您自己的 ufunc,請參閱 jax.numpy.frompyfunc()

範例

通用函數是逐元素應用於廣播陣列的函數,但它們也帶有一些額外的屬性和方法。

例如,考慮函數 jax.numpy.add。此物件作為一個函數,將加法逐元素應用於廣播陣列

>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)

每個 ufunc 物件都包含許多描述其行為的屬性

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

二元 ufunc,例如 jax.numpy.add,包含許多將函數以不同方式應用於陣列的方法。

outer() 方法將函數應用於輸入陣列值的成對外積

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

ufunc.reduce() 方法對陣列執行歸約。例如,jnp.add.reduce() 等價於 jnp.sum

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

ufunc.accumulate() 方法對陣列執行累積歸約。例如,jnp.add.accumulate() 等價於 jax.numpy.cumulative_sum()

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

ufunc.at() 方法在陣列中的特定索引處應用函數;對於 jnp.add,計算類似於 jax.lax.scatter_add()

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

ufunc.reduceat() 方法在陣列的指定索引之間執行多個 reduce 運算;對於 jnp.add,運算類似於 jax.ops.segment_sum()

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

在此範例中,第一個元素是 x[0:2].sum(),第二個元素是 x[2:].sum()

參數:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[source]#
參數:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None | None)

  • nargs (int | None | None)

  • identity (Any | None)

  • call (Callable[..., Any] | None | None)

  • reduce (Callable[..., Any] | None | None)

  • accumulate (Callable[..., Any] | None | None)

  • at (Callable[..., Any] | None | None)

  • reduceat (Callable[..., Any] | None | None)

方法

__init__(func, /, nin, nout, *[, name, ...])

accumulate(a[, axis, dtype, out])

從二元 ufunc 衍生的累積運算。

at(a, indices[, b, inplace])

透過指定的單元或二元 ufunc 更新陣列的元素。

outer(A, B, /)

將函數應用於 AB 中的所有值對。

reduce(a[, axis, dtype, out, keepdims, ...])

從二元函數衍生的歸約運算。

reduceat(a, indices[, axis, dtype, out])

透過二元 ufunc 歸約指定索引之間的陣列。

屬性

identity

nargs

nin

nout