jax.numpy.frompyfunc#

jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[原始碼]#

從任意與 JAX 相容的純量函式建立 JAX ufunc。

參數:
  • func (Callable[..., Any]) – 一個可調用物件,接受 nin 個純量引數並傳回 nout 個輸出。

  • nin (int) – 指定純量輸入數量的整數

  • nout (int) – 指定純量輸出數量的整數

  • identity (Any | None) – (選用) 指定運算恆等元素的純量(如果有的話)。

傳回:

func 的 jax.numpy.ufunc 包裝器。

傳回類型:

wrapped

範例

以下範例展示如何建立類似 jax.numpy.add 的 ufunc

>>> import operator
>>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)

現在所有標準的 jax.numpy.ufunc 方法都可用了

>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)