jax.numpy.apply_along_axis#

jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)[原始碼]#

沿著軸將函數應用於 1D 陣列切片。

JAX 實作的 numpy.apply_along_axis()。雖然 NumPy 以迭代方式實作此功能,但 JAX 透過 jax.vmap() 實作此功能,因此 func1d 必須與 vmap 相容。

參數:
  • func1d (Callable) – 可調用函數,簽名為 func1d(arr, /, *args, **kwargs),其中 *args**kwargs 是傳遞給 apply_along_axis() 的額外位置和關鍵字參數。

  • axis (int) – 應用函數的整數軸。

  • arr (ArrayLike) – 要在其上應用函數的陣列。

  • args – 額外的位置和關鍵字參數會傳遞給 func1d

  • kwargs – 額外的位置和關鍵字參數會傳遞給 func1d

返回:

沿指定軸應用 func1d 的結果。

返回類型:

Array

另請參閱

範例

二維的簡單範例,其中函數以逐行或逐列方式應用

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...   return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)

對於 2D 輸入,這可以使用 jax.vmap() 等效地表示,但請注意 vmap 指定的是映射軸而不是應用軸

>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)

對於 3D 輸入,apply_along_axis() 等效於跨兩個維度進行映射

>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)

應用函數也可以接受任意位置或關鍵字參數,這些參數應直接作為額外參數傳遞給 apply_along_axis()

>>> def func1d(x, exponent):
...   return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)