jax.numpy.fft.rfftn#

jax.numpy.fft.rfftn(a, s=None, axes=None, norm=None)[原始碼]#

計算實值陣列的多維離散傅立葉轉換。

JAX 實作的 numpy.fft.rfftn()

參數:
  • a (ArrayLike) – 實值輸入陣列。

  • s (Shape | None | None) – 整數的可選序列。控制沿每個指定軸的輸入有效大小。如果未指定,則預設為輸入沿 axes 的維度。

  • axes (Sequence[int] | None | None) – 整數的可選序列,預設值=None。指定計算轉換的軸。如果未指定,則沿最後 len(s) 個軸計算轉換。如果 axess 均未指定,則沿所有軸計算轉換。

  • norm (str | None | None) – 字串,預設值=”backward”。標準化模式。“backward”、“ortho” 和 “forward” 受到支援。

傳回:

一個陣列,包含 a 的多維離散傅立葉轉換,其大小在軸 axes 上指定為 s,但沿軸 axes[-1] 除外。沿軸 axes[-1] 的輸出大小為 s[-1]//2+1

傳回類型:

Array

另請參閱

範例

>>> x = jnp.array([[[1, 3, 5],
...                 [2, 4, 6]],
...                [[7, 9, 11],
...                 [8, 10, 12]]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x)
Array([[[ 78.+0.j  , -12.+6.93j],
        [ -6.+0.j  ,   0.+0.j  ]],

       [[-36.+0.j  ,   0.+0.j  ],
        [  0.+0.j  ,   0.+0.j  ]]], dtype=complex64)

s=[3, 3, 4] 時,沿 axes (-3, -2) 的轉換大小將為 (3, 3),而沿 axis -1 的轉換大小將為 4//2+1 = 3,而沿其他軸的大小將與輸入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x, s=[3, 3, 4])
Array([[[ 78.   +0.j  , -16.  -26.j  ,  26.   +0.j  ],
        [ 15.  -36.37j, -16.12 +1.93j,   5.  -12.12j],
        [ 15.  +36.37j,   8.12-11.93j,   5.  +12.12j]],

       [[ -7.5 -49.36j, -20.45 +9.43j,  -2.5 -16.45j],
        [-25.5  -7.79j,  -0.6 +11.96j,  -8.5  -2.6j ],
        [ 19.5 -12.99j,  -8.33 -6.5j ,   6.5  -4.33j]],

       [[ -7.5 +49.36j,  12.45 -4.43j,  -2.5 +16.45j],
        [ 19.5 +12.99j,   0.33 -6.5j ,   6.5  +4.33j],
        [-25.5  +7.79j,   4.6  +5.04j,  -8.5  +2.6j ]]], dtype=complex64)

s=[3, 5]axes=(0, 1) 時,沿 axis 0 的轉換大小將為 3,沿 axis 1 的轉換大小將為 5//2+1 = 3,而沿其他軸的維度將與輸入相同。

>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1])
Array([[[ 18.   +0.j  ,  26.   +0.j  ,  34.   +0.j  ],
        [ 11.09 -9.51j,  16.33-13.31j,  21.56-17.12j],
        [ -0.09 -5.88j,   0.67 -8.23j,   1.44-10.58j]],

       [[ -4.5 -12.99j,  -2.5 -16.45j,  -0.5 -19.92j],
        [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j],
        [ -4.95 +0.72j,  -5.78 -0.2j ,  -6.61 -1.12j]],

       [[ -4.5 +12.99j,  -2.5 +16.45j,  -0.5 +19.92j],
        [  3.47+10.11j,   6.43+11.42j,   9.38+12.74j],
        [  3.19 +1.63j,   4.4  +1.38j,   5.61 +1.12j]]], dtype=complex64)

對於 1-D 輸入

>>> x1 = jnp.array([1, 2, 3, 4])
>>> jnp.fft.rfftn(x1)
Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64)