jax.numpy.frompyfunc#
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[原始碼]#
從任意與 JAX 相容的純量函式建立 JAX ufunc。
- 參數:
- 傳回:
func 的 jax.numpy.ufunc 包裝器。
- 傳回類型:
wrapped
範例
以下範例展示如何建立類似
jax.numpy.add
的 ufunc>>> import operator >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)
現在所有標準的
jax.numpy.ufunc
方法都可用了>>> x = jnp.arange(4) >>> add(x, 10) Array([10, 11, 12, 13], dtype=int32) >>> add.outer(x, x) Array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]], dtype=int32) >>> add.reduce(x) Array(6, dtype=int32) >>> add.accumulate(x) Array([0, 1, 3, 6], dtype=int32) >>> add.at(x, 1, 10, inplace=False) Array([ 0, 11, 2, 3], dtype=int32)