快速傅里叶变换(FFT)

chzhc / 2025-02-18 / 原文

快速傅里叶变换(FFT)

前言

本文为个人学习笔记,大量参考了 oi-wiki 以及其他博客的内容。

问题

记:

\[f(x) = c_0 + c_1 x + c_2 x^2 + \cdots + c_{n}x^{n} \\ g(x) = d_0 + d_1 x + d_2 x^2 + \cdots + d_{m}x^{m} \\ h(x) = f(x) \times g(x) \]

\(\mathcal O(n \log n)\) 内解决两个多项式乘法后的系数(即给定 \(f(x)\)\(g(x)\) 的系数,要你求出 \(h(x)\) 的系数)。

分析

暴力显然是 \(\mathcal O(n^2)\) 的,优化的想法是先考虑点值表示,再考虑从点值表示转换为系数表示。

具体如下:

点值表示的意思是,你需要求出(\(\omega_n^k\) 是什么先忽略,当作是 \(n\) 个已知量即可):

\[f(\omega_n^0), f(\omega_n^1), \cdots f(\omega_n^{n-1}) \\ g(\omega_n^0), g(\omega_n^1), \cdots g(\omega_n^{n-1}) \]

那么:

\[h(\omega_n^k) = f(\omega_n^k) \times g(\omega_n^k) \]

实际上,\(n\) 个点的点值表示法也确定了一个 \(n - 1\) 次的多项式,因此,一定存在某个算法能将点值表示法转化为系数表示(这个后面再说)。

至此,FFT 的核心思想已经说清楚了,就是考虑求出 \(f, g\) 的点值表示,那么 \(h\) 的点值表示就可以在 \(\mathcal O(n)\) 的复杂度内求出,而后再考虑从点值表示转化为系数表示。

问题一:求出某个多项式的点值表示(离散傅里叶变换 DFT)

实际上这个问题的真实含义是:怎么选取 \(\omega_n^k\) 这个已知量才能使得在一个优秀的复杂度内求出多项式的点值表示。

\(\omega_n^k\) 表示将复数坐标系的单位圆平均分成 \(n\) 份,从 \(x\) 轴逆时针出发的第 \(k\) 条分界箭头的复数表示。

选取这个 \(\omega_n^k\) 的原因是它有某些性质,能 " 在一个优秀的复杂度内求出多项式的点值表示 "。

性质:

\[1) \ \omega_n^k = \omega_{\frac{n}{2}}^{\frac{k}{2}} \ \ \ \ \ \ \ \\ 2) \ \omega_n^{k + \frac{n}{2}} = - \omega_n^k \]

然后开始推式子:

\[f_1(x) = c_0 + c_2x + \cdots + c_{n - 2}x^{\frac{n - 2}{2}} \\ f_2(x) = c_1 + c_3x + \cdots + c_{n - 1}x^{\frac{n - 2}{2}} \]

显然有

\[f(x) = f_1(x^2) + xf_2(x^2) \]

\(\omega _n^k(0 \leq k < \frac{n}{2})\) 代入有:

\[\begin {split} f(\omega_{n}^k) &= f_1(\omega _n^{2k}) + \omega_n^k f_2(\omega_n^{2k}) \\ &= f_1(\omega _{\frac{n}{2}}^{k}) + \omega_n^k f_2(\omega _{\frac{n}{2}}^{k}) \end {split} \]

\(\omega_n^{k + \frac{n}{2}}(0 \leq k < \frac{n}{2})\) 代入有:

\[\begin {split} f(\omega_n^{k + \frac{n}{2}}) &= f_1(\omega _n^{2k + n}) + \omega_n^{k + \frac{n}{2}} f_2(\omega_n^{2k + n}) \\ &= f_1(\omega _{\frac{n}{2}}^{k}) - \omega_n^k f_2(\omega _{\frac{n}{2}}^{k}) \end {split} \]

递归求解即可,有 \(\log n\) 层,时间复杂度为 \(\mathcal O(n \log n)\),为了方便处理,一般把 \(n\) 处理为 \(\geq(n + m)\) 的二次幂,多出来的部分系数补为 \(0\) 即可。

具体实现中有以下注意事项:

1、虚数可以使用 C++ STL 库中的 complex 类型;

代码

#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
    cnt = 0; char ch = getchar(); bool op = 1;
    for (; ! isdigit(ch); ch = getchar())
        if (ch == '-') op = 0;
    for (; isdigit(ch); ch = getchar())
        cnt = cnt * 10 + ch - 48;
    cnt = op ? cnt : - cnt;
}

const int N = (1 << 22) + 5;
const double PI = acos(-1);

inline void FFT(std::complex < double > *A, int n) {
    if (n == 1) return;
    int m = (n >> 1);
    std::complex < double > A0[m], A1[m];
    for (int i = 0; i < m; ++ i) {
        A0[i] = A[i * 2];
        A1[i] = A[i * 2 + 1];
    }
    FFT(A0, m); FFT(A1, m); // 递归处理
    auto W = std::complex < double > (cos(2.0 * PI / n), sin(2.0 * PI / n)),
         w = std::complex < double > (1.0, 0.0); // 从 w_n^0 出发
    for (int i = 0; i < m; ++ i) { // 根据式子计算 A 即可
        A[i] = A0[i] + w * A1[i];
        A[i + m] = A0[i] - w * A1[i];
        w *= W; // 等价于 w_n^k -> w_n^{k + 1}
    }
}

int n, m;
std::complex < double > F[N], G[N];

int main() {
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    return 0;   
}

优化

递归实在太慢了!

\(8\) 项多项式为例,模拟拆分的过程:

  • 初始序列为 \(\{x_0, x_1, x_2, x_3, x_4, x_5, x_6, x_7\}\)
  • 一次二分之后 \(\{x_0, x_2, x_4, x_6\},\{x_1, x_3, x_5, x_7 \}\)
  • 两次二分之后 \(\{x_0,x_4\} \{x_2, x_6\},\{x_1, x_5\},\{x_3, x_7 \}\)
  • 三次二分之后 \(\{x_0\}\{x_4\}\{x_2\}\{x_6\}\{x_1\}\{x_5\}\{x_3\}\{x_7 \}\)

规律:其实就是原来的那个序列,每个数用二进制表示,然后把二进制翻转对称一下,就是最终那个位置的下标。比如 \(x_1\) 是 001,翻转是 100,也就是 4,而且最后那个位置确实是 4。我们称这个变换为位逆序置换(bit-reversal permutation),证明留给读者自证。

实际上,位逆序置换可以 \(\mathcal O(n)\) 从小到大递推实现,设 \(len=2^k\),其中 k 表示二进制数的长度,设 \(R(x)\) 表示长度为 \(k\) 的二进制数 \(x\) 翻转后的数(高位补 0)。我们要求的是 \(R(0),R(1),\cdots,R(n-1)\)

首先 \(R(0)=0\)

我们从小到大求 \(R(x)\)。因此在求 \(R(x)\) 时,\(R\left(\left\lfloor \dfrac{x}{2} \right\rfloor\right)\) 的值是已知的。因此我们把 \(x\) 右移一位(除以 \(2\)),然后翻转,再右移一位,就得到了 \(x\) 除了(二进制)个位之外其它位的翻转结果。

考虑个位的翻转结果:如果个位是 0,翻转之后最高位就是 0。如果个位是 1,则翻转后最高位是 1,因此还要加上 \(\dfrac{len}{2}=2^{k-1}\)。综上

\[R(x)=\left\lfloor \frac{R\left(\left\lfloor \frac{x}{2} \right\rfloor\right)}{2} \right\rfloor + (x\bmod 2)\times \frac{len}{2} \]

举个例子:设 \(k=5\)\(len=(100000)_2\)。为了翻转 \((11001)_2\)

  1. 考虑 \((1100)_2\),我们知道 \(R((1100)_2)=R((01100)_2)=(00110)_2\),再右移一位就得到了 \((00011)_2\)
  2. 考虑个位,如果是 \(1\),它就要翻转到数的最高位,即翻转数加上 \((10000)_2=2^{k-1}\),如果是 \(0\) 则不用更改。

蝶形运算优化

已知 \(f_1(\omega_{n/2}^k)\)\(f_2(\omega_{n/2}^k)\) 后,需要使用下面两个式子求出 \(f(\omega_n^k)\)\(f(\omega_n^{k+n/2})\)

\[\begin{aligned} f(\omega_n^k) & = f_1(\omega_{n/2}^k) + \omega_n^k \times f_2(\omega_{n/2}^k) \\ f(\omega_n^{k+n/2}) & = f_1(\omega_{n/2}^k) - \omega_n^k \times f_2(\omega_{n/2}^k) \end{aligned} \]

使用位逆序置换后,对于给定的 \(n, k\)

  • \(f_1(\omega_{n/2}^k)\) 的值存储在数组下标为 \(k\) 的位置,\(f_2(\omega_{n/2}^k)\) 的值存储在数组下标为 \(k + \dfrac{n}{2}\) 的位置。
  • \(f(\omega_n^k)\) 的值将存储在数组下标为 \(k\) 的位置,\(f(\omega_n^{k+n/2})\) 的值将存储在数组下标为 \(k + \dfrac{n}{2}\) 的位置。

因此可以直接在数组下标为 \(k\)\(k + \frac{n}{2}\) 的位置进行覆写,而不用开额外的数组保存值。此方法即称为 蝶形运算,或更准确的,基 - 2 蝶形运算。

再详细说明一下如何借助蝶形运算完成所有段长度为 \(\frac{n}{2}\) 的合并操作:

1、令段长度为 \(s = \frac{n}{2}\)
2、同时枚举序列 \(\{f_1(\omega_{n/2}^k)\}\) 的左端点 \(l_g = 0, 2s, 4s, \cdots, N-2s\) 和序列 \(\{f_2(\omega_{n/2}^k)\}\) 的左端点 \(l_h = s, 3s, 5s, \cdots, N-s\)
3、合并两个段时,枚举 \(k = 0, 1, 2, \cdots, s-1\),此时 \(f_1(\omega_{n/2}^k)\) 存储在数组下标为 \(l_g + k\) 的位置,\(f_2(\omega_{n/2}^k)\) 存储在数组下标为 \(l_h + k\) 的位置;
4、使用蝶形运算求出 \(f(\omega_n^k)\)\(f(\omega_n^{k+n/2})\),然后直接在原位置覆写。

代码

#include <bits/stdc++.h>
template < typename T >
inline void read(T &cnt) {
    cnt = 0; char ch = getchar(); bool op = 1;
    for (; ! isdigit(ch); ch = getchar())
        if (ch == '-') op = 0;
    for (; isdigit(ch); ch = getchar())
        cnt = cnt * 10 + ch - 48;
    cnt = op ? cnt : - cnt;
}

const int N = (1 << 22) + 5;
const double PI = acos(-1);

int rev[N];

inline void change(std::complex < double > *A, int n) {
    for (int i = 0; i < n; ++ i) { // 求 R 数组
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) {
            rev[i] |= (n >> 1);
        }
    }

    for (int i = 0; i < n; ++ i) // 将原序列 变为 底层对应的序列
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
} 
inline void FFT(std::complex < double > *A, int n) {
    change(A, n);
    for (int m = 2; m <= n; m *= 2) { // m 是当前处理的每段长度
        auto W = std::complex < double > 
          (cos(2.0 * PI / m), sin(2.0 * PI / m));  
        for (int x = 0; x < n; x += m) { // x 是每段的开头
            auto w = std::complex < double > (1.0, 0.0);
            for (int i = x; i < x + m / 2; ++ i) { // 求出每段的点值表示 根据公式求即可
                auto A0 = A[i], A1 = A[i + m / 2];
                A[i] = A0 + w * A1;
                A[i + m / 2] = A0 - w * A1;
                w *= W;
            }
        }
    }
}


int n, m;
std::complex < double > F[N], G[N];

int main() {
    change(F, 8);
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    FFT(F, sum);
    return 0;   
}

问题二:将点值表示转化为系数表示(傅里叶反变换 IDFT)

点值表示的矩阵形式为:

\[\begin{bmatrix}f(\omega_n^0) \\ f(\omega_n^1) \\ f(\omega_n^2) \\ f(\omega_n^3) \\ \vdots \\ f(\omega_n^{n-1}) \end{bmatrix} = \begin{bmatrix}1 & 1 & 1 & 1 & \cdots & 1 \\ 1 & \omega_n^1 & \omega_n^2 & \omega_n^3 & \cdots & \omega_n^{n-1} \\ 1 & \omega_n^2 & \omega_n^4 & \omega_n^6 & \cdots & \omega_n^{2(n-1)} \\ 1 & \omega_n^3 & \omega_n^6 & \omega_n^9 & \cdots & \omega_n^{3(n-1)} \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ 1 & \omega_n^{n-1} & \omega_n^{2(n-1)} & \omega_n^{3(n-1)} & \cdots & \omega_n^{(n-1)^2} \end{bmatrix} \begin{bmatrix} a_0 \\ a_1 \\ a_2 \\ a_3 \\ \vdots \\ a_{n-1} \end{bmatrix} \]

怎么求系数 \(a\) 呢?根据线性代数的知识:

\[Ax = b \\ x = A^{-1}b \]

如果能求出 \(A^{-1}\),那么 \(A^{-1} b\) 也是两个多项式相乘的结果,FFT 即可。

唯一的问题变为怎么求解 \(A^{-1}\)

根据矩阵的逆的定义,有

\[A^{-1} \cdot A = E \]

\(V\) 为原矩阵,\(G\) 为逆矩阵,考虑最终落在 \(E(i, j)\) 的值:

\[E(i, j) = \sum_{k=0}^{n-1} G(i, k) \cdot V(k, j) = \sum_{k=0}^{n-1} G(i, k) \cdot \omega_n^{kj} = [i == j] \]

引理

\(k\) 不是 \(n\) 的倍数时,

\[\sum_{i=0}^{n-1}\omega_n^{ki} = 0 \]

证明如下:

\[\sum_{i=0}^{n-1}\omega_n^{ki} = \frac{\omega_n^{kn} - 1}{1 - \omega_n^{k}} = \frac{1 - 1}{1 - \omega_n^{k}} = 0 \]

\(G(i, k) = \omega_n^{-ik}\),则:

\[\sum_{k=0}^{n-1} G(i, k) \cdot \omega_n^{kj} = \sum_{k=0}^{n-1} \omega_n^{-ik} \cdot \omega_n^{kj} = \sum_{k=0}^{n-1} \omega_n^{k(j-i)} \]

\(j-i\) 不为 \(n\) 的倍数(0)时,上式为 0;

反之,有:

\[\sum_{k=0}^{n-1} \omega_n^{k(j-i)} = \sum_{k=0}^{n-1} \omega_n^{0} = n \]

再前面补个系数 \(\frac{1}{n}\) 即可,故:

\[G(i, k) = \frac{1}{n}\omega_n^{-ik} \]

int main() {
    change(F, 8);
    read(n), read(m);
    for (int i = 0; i <= n; ++ i) {
        int x; read(x);
        F[i] = x;
    }
    for (int i = 0; i <= m; ++ i) {
        int x; read(x);
        G[i] = x;
    }
    int sum = 1;
    while (sum <= n + m) sum *= 2; // 补齐成二次幂
    FFT(F, sum);
    FFT(G, sum);
    for (int i = 0; i < sum; ++ i)
        F[i] *= G[i];
    FFT(F, sum);
    std::reverse(F + 1, F + sum); // 从第一位开始翻转
								  // 翻转后变为 0 1-n, 2-n, ..., -1 
    							  // 实际上等价于 0, 1, 2, ..., n-1
    for (int i = 0; i <= n + m; ++ i) { // 四舍五入
        std::cout << (int)(F[i].real() / sum + 0.5) << ' ';
    }
    return 0;   
}