本博客学习于洛谷P3803各路大佬题解,所以你肯定会找到很多相似的地方QWQ
什么是FFT 傅里叶变换
傅里叶说明了:一切函数/波形都可以用有限或无限个弦函数/波形叠加形成。
例如,现在有一个由三个不同正弦波组成的近似矩形波,将这些波形关于 $x$ 轴平行地放在一起,就可以得到下图:
其中红色的是合成出来的近似矩形波。
(当然,图肯定是不太标准的,意思一下就行)
从 $xOz$ 平面看,我们可以获得各个波的时域信号,也就是 $t-f(t)$ 图像。而从 $zOy$ 平面看,则可以获得各个波的频域信号。
先不看频域信号中的红色部分,注意其他三个波其实是有排序规则的:按照频率大小排序。也就是说,频率越快, $y$ 越大。且每个频率下的 $z$ 将反映该波的振幅。两种不同的信号包含了不同的信息。
而傅里叶变换,就是对一个时域信号作变换,生成频域信号。
离散傅里叶变换(DFT)
顾名思义,离散傅里叶变换就是在确定时域信号和频域信号都是离散的时候作的傅里叶变换。
例如,现在有一个多项式函数 $f(x)=\sum_{i=0}^{n}a_ix^i$,那么这个函数其实就是“时域信号”,我们也可以将其表示成点集$\{P_0,P_1,…,P_n\}$,其中 $P_i$ 是函数图像上的某一个点。容易知道 $n$ 个点是可以确定一个 $n-1$ 次多项式的。那么这个点集其实就是“频域信号”。从函数式到点集的变化,就是一种离散傅里叶变换。
另外,我们称IDFT为DFT的逆运算,即从点集得到函数式的运算。
为什么需要FFT 离散型傅里叶变换解决的问题一般是多项式乘法问题。
最经典的FFT问题即多项式卷积问题:
P3803 【模板】多项式乘法(FFT) - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
给一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$,求 $F(x)$ 和 $G(x)$ 的卷积。
如果直接暴力计算,第一种就是直接展开函数式算,第二种就是执行一次DFT,将函数式转换为点集,通过对 $x$ 相同的点的 $y$ 值相乘,得到目标函数的点集,再通过 DFT 的逆运算 IDFT 转化为答案。
显然,不管采取哪种方法,一共都有 $m · n$ 项,复杂度为 $O((max\{n, m\})^2)$. 在 $nm\geq 1e9$ 的时候是无法接受的。虽然朴素方法下第二种方法的常数肯定比第一种还要大,但其实我们可以在DFT变换的基础上使用FFT来加速多项式乘法。实际上,FFT的加速原因在于其采用了分治的思想。
多项式乘法问题 前置知识-复数根 考虑欧拉公式
,容易知道 $e^{2\pi i}=1$.
不妨令 $w_n=e^{2\pi i\over n}=\sin{2\pi\over n}+i\cos{2\pi\over n}$, 则 $w_n^n=1$.
则称 $w_n$ 为 $n$ 的一个复数z单位根. 容易知道 $n$ 的复数根最多有 $n$ 个,为 $\{w_i|i\in [0,n-1]\}$. 这是因为复数根具有性质 $w_n^{n+k}=w_n^k$.
另外,还有一个通过消去定理得到的引理:$w_n^k=w_{n\over2}^{k\over2}$.
使用FFT解决多项式乘法问题 例题:
P3803 【模板】多项式乘法(FFT)
DFT
现在要将$F(x)$的表达式转化为点值表示法,那么我们需要取$n$个点。前面已经讨论过,朴素的取法复杂度是 $O(n^2)$ 的,现在我们需要构造一个可以分治的取法:
设$ F(x)$ 的项数 $n$ 为偶数,并令
$A(x)=a_0+a_2x+a_4x^2+…+a_nx^{n\over 2}$,
$B(x)=a_1+a_3x+a_5x^2+…+a_{n-1}x^{\frac{n}{2}-1}$.
则 $F(x)=A(x^2)+xB(x^2)$.
代入 $n$ 个单位复数根 $w_n^k, k\in[0, n - 1]$,有
$F(w_n^k)=A(w_{n}^{2k})+w_n^kB(w_n^{2k})$
而 $A(w_n^{2k})=a_0+a_2w_n^{2k}+…+a_nw_n^{2k\over n}$
$=a_0+a_2w_{n\over2}^k+…+a_nw_{n\over2}^{k\over n}$.
且对 $B(w_n^{2k})$,我们也有相似的结论。
故
$F(w_n^k)=A(w_{n\over2}^k)+w_n^kB(w_{n\over2}^{k})\ (k<{n\over2})$.
考虑 $k\geq {n\over2}$ 的情况,令 $k+{n\over 2}$ 取代原来位置上的 $k$,
$F(w_n^{{n\over 2}+k})=A(w_n^{n+2k})+w_n^{{n\over 2}+k}B(w_n^{2k})$,
$=A(w_n^{2k})-w_n^{k}B(w_n^{2k})$,
$=A(w_{n\over2}^k)-w_n^kB(w_{n\over2}^k)$.
那么在求 $F(w_n^k)$ 时,我们可以先求 $A(w_{n\over2}^{k})$ 以及 $B(w_{n\over2}^{k})$,然后再合并出 $F(w_n^k)$. 到这里,分治递归的可行性就十分显然了。由主定理可知复杂度为 $T(n)=T(\frac{n}{2})+O(n)=O(n\log n)$.
**IDFT**
IDFT 即 DFT 的逆变换。我们的目的是求最终表达式的各个系数,而通过 DFT 得到最终的点值表达式后,我们还需要将其逆向变回系数表达式。
其实,这里只需要取 $w'=\overline {w_n^k}$. 即单位根的共轭复数,然后再执行一遍相似的分治,并且最后得到的多项式系数都除以一个 $n $ 就可以了。(不再展开证明 QWQ,了解就好,因为本蒟蒻也不会)
CPP:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 #define PI acos(-1.0) void FFT (complex<double > F[], int n, int op) { if (!n) return ; complex<double > a[n], b[n]; for (int i = 0 ; i < n; ++i) { a[i] = F[i << 1 ], b[i] = F[i << 1 | 1 ]; } FFT (a, n >> 1 , op), FFT (b, n >> 1 , op); complex<double > wn (cos(PI / n), sin(PI / n) * op) , w (1 , 0 ) ; for (int i = 0 ; i < n; ++i) { F[i] = a[i] + w * b[i]; F[i + n] = a[i] - w * b[i]; w *= wn; } }
此为FFT的递归写法,由于单位根的原因需要使用 `double` 数据类型,且需要一个复数结构体,每层分治还要额外开数组,其常数比较大。另外还有一种FFT的迭代写法,常数更小。(需要的可以去别处找,或者我说不定会更?QAQ)
****
回到该问题,由于需要分治的任何时候都要保证 $n$ 为偶数,故我们需要把 $n,m$ 补成相同的一个 $2$ 的幂次方。项数补全不影响结果,只要让多出来的项系数都为0就行了。
故整个流程即:
1. 对 $F(x),G(x)$ 分别求 DFT
2. 直接将每个点值相乘
3. 通过IDFT将点值还原成系数表达式,得到答案
对于例题[P3803](https://www.luogu.com.cn/problem/P3803),AC代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 #include <bits/stdc++.h> #define PI acos(-1.0) using namespace std;const int maxn = 3e6 + 10 ;complex <double > F[maxn], G[maxn]; void fft (complex<double > f[], int n, int op) { if (!n) return ; complex <double > a[n], b[n]; for (int k = 0 ; k < n; ++k) { a[k] = f[k << 1 ], b[k] = f[k << 1 | 1 ]; } fft (a, n >> 1 , op), fft (b, n >> 1 , op); complex<double > wn (cos(PI / n), sin(PI / n) * op) , w (1 , 0 ) ; for (int k = 0 ; k < n; ++k, w *= wn) { f[k] = a[k] + w * b[k], f[k + n] = a[k] - w * b[k]; } } int n, m;int main () { scanf ("%d%d" , &n, &m); for (int i = 0 ; i <= n; ++i) { scanf ("%lf" , &F[i]); } for (int i = 0 ; i <= m; ++i) { scanf ("%lf" , &G[i]); } m += n, n = 1 ; while (n <= m) n <<= 1 ; fft (F, n >> 1 , 1 ); fft (G, n >> 1 , 1 ); for (int i = 0 ; i < n; ++i) F[i] *= G[i]; fft (F, n >> 1 , -1 ); for (int i = 0 ; i <= m; ++i) { printf ("%.0f " , fabs (F[i].real ()) / n); } return 0 ; }
基本上和洛谷上 FlashHu 的题解代码一样...因为我的FFT基本上是从这里理解的
## NTT-快速数论变换
很多时候,要求的多项式系数并不会有多大,`double` 显得有些浪费空间了,而且很容易产生精度问题。故可以使用另一种 “原根” 替代复数来解决多项式乘法问题,这样就可以避免使用 `double` 以及 `complex`. 如果有系数较大的多项式需要计算且考察 NTT 时,一般会要求对结果取模。
### 前置知识-原根
在 NTT 算法中,原根是 FFT 中复数的替代品。~~(?又说一遍)~~
要理解原根,首先我们需要群论中的若干定义。
**循环群**
如果一个群 $A$ 的所有元素 $x$ 都是 $a$ 的幂次,即$\forall x\in A, x=a^i(i\in \Z)$,则称 $A$ 是一个循环群,$a$ 是它的生成元。
**阶**
假设 $e$ 是某个循环群的单位元,则若 $a^p=1$,则 $p$ 的最小值即为 $a$ 的阶,记作 $Ord(a)$.
**原根**
原根的一般定义是:对于群 $G$, $\exist g\in G$, $Ord(g)=|G|$,则 $g$ 是 $G$ 的原根。
现在有一个以加法取模为运算,元素为整数的群,容易知道这是一个循环群,单位元是1. 我们约定,如果这样的一个群的模数是 $p$,则其所在的群是 $G_p$.
那么如果我们的问题是这样:求多项式系数,同时多项式系数对一个模数 $p$ 取余,显然系数构成一个群 $G_p$. 设 $n=|G|$,容易知道 $g^{p-1}=1$, 则不妨令 $g_n=g^{\frac{p-1}{n}}$, 有 $g_n^n=1$.
从这里我们就看出,$g_n$ 和 FFT 中的 $w_n$ 是有一一对应关系的,所以我们只需要将 FFT 的点值取值修改为 $g_n^0,g_n^1,...g_n^{n-1}$ 即可。
至于原根如何求出,我们可以找一个 $p-1$ 的质因子 $q$, 则 $g$ 为原根 $\lrArr$ $g^{\frac{p-1}{q}}\neq 1$ (由欧拉定理可证),将 $p-1$ 的质因子逐个找出再判断一下即可。通常题目会要求 $p=119×2^{23}+1=998244353,\ g=3$.
> 据说,因为大多数NTT题目都会取更方便造数据的998244353为模数,这就导致很多题一看模数就知道这题是要做NTT。所以后来许多和NTT无关甚至完全不需要取模的题目也会加上一句“对998244353取模”。
理解粗浅,还是老老实实看模板吧。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 #include <bits/stdc++.h> #define int long long using namespace std;const int maxn = 3e6 + 10 ;int F[maxn], G[maxn];int r[maxn];const int mod = 998244353 ;int fpow (int a, int p) { int ans = 1 ; while (p) { if (p & 1 ) ans = (ans * a) % mod; a = (a * a) % mod; p >>= 1 ; } return ans; } void ntt (int f[], int n, int op) { if (!n) return ; for (int i = 0 ; i < n; ++i) { if (i < r[i]) swap (f[i], f[r[i]]); } for (int mid = 1 ; mid < n; mid <<= 1 ) { int wn = fpow (op == 1 ? 3 : 332748118 , (mod - 1 ) / (mid << 1 )); for (int j = 0 ; j < n; j += (mid << 1 )) { int w = 1 ; for (int k = 0 ; k < mid; ++k, w = w * wn % mod) { int x = f[j + k], y = w * f[j + k + mid] % mod; f[j + k] = (x + y) % mod; f[j + k + mid] = ((x - y) % mod + mod) % mod; } } } } int n, m;signed main () { scanf ("%lld%lld" , &n, &m); for (int i = 0 ; i <= n; ++i) { scanf ("%lld" , &F[i]); F[i] = (F[i] + mod) % mod; } for (int i = 0 ; i <= m; ++i) { G[i] = (G[i] + mod) % mod; scanf ("%lld" , &G[i]); } m += n, n = 1 ; int cnt = 0 ; while (n <= m) n <<= 1 , cnt++; for (int i = 0 ; i < n; ++i) r[i] = (r[i >> 1 ] >> 1 ) | ((i & 1 ) << (cnt - 1 )); ntt (F, n, 1 ), ntt (G, n, 1 ); for (int i = 0 ; i < n; ++i) F[i] = (F[i] * G[i]) % mod; ntt (F, n, -1 ); int inv = fpow (n, mod - 2 ); for (int i = 0 ; i <= m; ++i) { printf ("%lld " , (F[i] * inv) % mod); } return 0 ; }