jax.scipy.fft.dct#

jax.scipy.fft.dct(x, type=2, n=None, axis=-1, norm=None)[source]#

計算輸入的離散餘弦轉換

JAX 實現的 scipy.fft.dct()

參數:
  • x (Array) – 陣列

  • type (int) – 整數,預設值 = 2。目前僅支援 type 2。

  • n (int | None | None) – 整數,預設值 = x.shape[axis]。轉換的長度。如果大於 x.shape[axis],輸入將會補零;如果小於,輸入將會被截斷。

  • axis (int) – 整數,預設值=-1。執行 dct 的軸。

  • norm (str | None | None) – 字串。正規化模式:[None, "backward", "ortho"] 其中之一。預設值為 None,相當於 "backward"

回傳:

包含 x 的離散餘弦轉換的陣列

回傳型別:

Array

另請參閱

範例

>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x))
[[ 6.43  3.56 -2.86]
 [-1.75  1.55 -1.4 ]
 [ 1.33 -2.01 -0.82]]

n 小於 x.shape[axis]

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2))
[[ 7.3  -0.57]
 [ 0.19 -0.36]
 [-0.   -1.4 ]]

n 小於 x.shape[axis]axis=0

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2, axis=0))
[[ 3.09  4.4  -2.81]
 [ 2.41  2.62  0.76]]

n 大於 x.shape[axis]axis=1

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=4, axis=1))
[[ 6.43  4.88  0.04 -3.3 ]
 [-1.75  0.73  1.01 -2.18]
 [ 1.33 -1.05 -2.34 -0.07]]