jax.numpy.apply_along_axis#
- jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[原始碼]#
沿著軸將函數應用於 1D 陣列切片。
JAX 實作的
numpy.apply_along_axis()
。雖然 NumPy 以迭代方式實作此功能,但 JAX 透過jax.vmap()
實作此功能,因此func1d
必須與vmap
相容。- 參數:
func1d (Callable) – 可調用函數,簽名為
func1d(arr, /, *args, **kwargs)
,其中*args
和**kwargs
是傳遞給apply_along_axis()
的額外位置和關鍵字參數。axis (int) – 應用函數的整數軸。
arr (ArrayLike) – 要在其上應用函數的陣列。
args – 額外的位置和關鍵字參數會傳遞給
func1d
。kwargs – 額外的位置和關鍵字參數會傳遞給
func1d
。
- 返回:
沿指定軸應用
func1d
的結果。- 返回類型:
另請參閱
jax.vmap()
:更直接地建立函數的向量化版本的方法。jax.numpy.apply_over_axes()
:重複跨多個軸應用函數。jax.numpy.vectorize()
:建立函數的向量化版本。
範例
二維的簡單範例,其中函數以逐行或逐列方式應用
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> def func1d(x): ... return jnp.sum(x ** 2) >>> jnp.apply_along_axis(func1d, 0, x) Array([17, 29, 45], dtype=int32) >>> jnp.apply_along_axis(func1d, 1, x) Array([14, 77], dtype=int32)
對於 2D 輸入,這可以使用
jax.vmap()
等效地表示,但請注意 vmap 指定的是映射軸而不是應用軸>>> jax.vmap(func1d, in_axes=1)(x) # same as applying along axis 0 Array([17, 29, 45], dtype=int32) >>> jax.vmap(func1d, in_axes=0)(x) # same as applying along axis 1 Array([14, 77], dtype=int32)
對於 3D 輸入,
apply_along_axis()
等效於跨兩個維度進行映射>>> x_3d = jnp.arange(24).reshape(2, 3, 4) >>> jnp.apply_along_axis(func1d, 2, x_3d) Array([[ 14, 126, 366], [ 734, 1230, 1854]], dtype=int32) >>> jax.vmap(jax.vmap(func1d))(x_3d) Array([[ 14, 126, 366], [ 734, 1230, 1854]], dtype=int32)
應用函數也可以接受任意位置或關鍵字參數,這些參數應直接作為額外參數傳遞給
apply_along_axis()
>>> def func1d(x, exponent): ... return jnp.sum(x ** exponent) >>> jnp.apply_along_axis(func1d, 0, x, exponent=3) Array([ 65, 133, 243], dtype=int32)