FFT学习笔记<代码篇>

对FFT理论不明白的童鞋可以来这里( FFT学习笔记<理论篇>):
http://blog.csdn.net/Monkey_king2017cn/article/details/77542160

在了解完FFT的理论与算法流程之后,最重要的当然就是写代码啦,下面的两份代码将展示FFT在多项式乘法与高精度乘法中的运用。

在那之前,还有一个重要的东西:
因为下面写的是迭代的FFT代码,而不是采用递归,所以多了一个对rev[]的处理:
我们假设每次将奇数项元素提出来之后,将其放到了序列的最后,如下:

01234567

变成:


02461357

我们一直这样分下去,就变成了:


04261537

我们把其中每个数的二进制翻转过来,发现是递增的:从0到7

那么,rev[i]存的就是将i的二进制位翻转过来的值,也就值i在底层的位置。

题目描述

给定一个n次多项式F(x),和一个m次多项式G(x)。

请求出F(x)和G(x)的卷积。

输入输出格式

输入格式:
第一行2个正整数n,m。

接下来一行n+1个数字,从低到高表示F(x)的系数。

接下来一行m+1个数字,从低到高表示G(x))的系数。

输出格式:
一行n+m+1个数字,从低到高表示F(x)∗G(x)的系数。

输入输出样例

输入样例#1:
1 2
1 2
1 2 1
输出样例#1:
1 4 5 2

#include<complex>
#prag\
ma GCC optimize("O2")

#define N 2097152
#define For(a, b, c) for(int a = b; a <= c; ++a)
using namespace std;
typedef complex <double> CP;//定义复数类
const double pi = 2.0 * acos(-1.0);//这里π是两倍,因为单位复根指数上有2
CP A1[N], A2[N];//将两个多项式定义为复数类,方便运算
int n, m, n3, s, rev[N];

int read(){
    int u = 0;
    char x = getchar();
    while(!isdigit(x)) x = getchar();
    while(isdigit(x)) u = (u << 3) + (u << 1) + (x ^ 48), x = getchar();
    return u;
}//读入优化

void pre_work(){//预处理
    int bit;
    n3 = n + m - 1;//结果的次数界
    for(s = 1, bit = 0; s < n3; s <<= 1, ++bit) ;//将长度补为2的整数次幂
    For(i, 1, s - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
       //这可以表述成:将i的末尾去掉翻转过来,如果末尾是1就在最前面填1
}

void FFT(CP *A, int l, int f){//f==-1时,则为IDFT
    For(i, 0, l - 1) if(i < rev[i]) swap(A[i], A[rev[i]]);
    //先让元素处于底层位置,这样相邻的一段就是在递归中的同一层的多项式
    for(int i = 2; i <= l; i <<= 1){
    //枚举区间长度,也就是当前处理的多项式的次数(从底往上->由短到长)
        CP wi(cos(pi / i), f * sin(pi / i));//计算主根
        for(int j = 0; j < l; j += i){
        //枚举处理的多项式的开端
            CP w(1, 0);//初始值为1
            For(k, j, j + (i >> 1) - 1){//计算当前多项式的结果
                CP x = A[k];//这里跟理论篇相同,利用上一层计算当前
                CP y = w * A[k + (i >> 1)];
                A[k] = x + y;
                A[k + (i >> 1)] = x - y;
                w *= wi;
            }
        }
    }
    if(f == -1) For(i, 0, l - 1) A[i] /= l;//若为IDFT,就除一下
}

void solve(){
    pre_work();
    FFT(A1, s, 1);
    FFT(A2, s, 1);//求值
    For(i, 0, s - 1) A1[i] *= A2[i];//逐点相乘
    FFT(A1, s, -1);//插值
    For(i, 0, n3 - 1) printf("%d ", (int)(A1[i].real() + 0.5));
    //输出答案
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("A.in", "r", stdin);
    freopen("A.out","w",stdout);
#endif
    n = read(), m = read();//n,m是A1,A2的次数
    ++n, ++m;//现在n,m是a1,A2的次数界
    For(i, 0, n - 1) A1[i] = read();
    For(i, 0, m - 1) A2[i] = read();//读入多项式的系数
    solve();
    return 0;
}

题目描述

给出两个n位10进制整数x和y,你需要计算x*y。

输入输出格式

输入格式:
第一行一个正整数n。 第二行描述一个位数为n的正整数x。 第三行描述一个位数为n的正整数y。

输出格式:
输出一行,即x*y的结果。(注意判断前导0)

输入输出样例

输入样例#1:
1
3
4
输出样例#1:
12

#include<cstdio>
#include<complex>
#include<iostream>

#define N 180005
#define For(a, b, c) for(int a = b; a <= c; ++a)
using namespace std;
typedef complex <double> CP;
const double pi = 2.0 * acos(-1.0);
int n, an, s, rev[N], y[N];
CP A1[N], A2[N];

void pre_work(){
    an = (n << 1) - 1;
    int bit = 0;
    for(s = 1; s < an; s <<= 1, ++bit) ;
    For(i, 1, s - 1)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}

void FFT(CP *a, int l, int f){
    For(i, 0, l - 1) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int i = 2; i <= l; i <<= 1){
        CP wi (cos(pi / i), f * sin(pi / i));
        for(int j = 0; j < l; j += i){
            CP w (1, 0);
            For(k, j, j + (i >> 1) - 1){
                CP x = a[k];
                CP y = w * a[k + (i >> 1)];
                a[k] = x + y;
                a[k + (i >> 1)] = x - y;
                w *= wi;
            }
        }
    }
    if(f == -1) For(i, 0, l - 1) a[i] /= l;
}

void solve(){
    pre_work();
    FFT(A1, s, 1);
    FFT(A2, s, 1);
    For(i, 0, s - 1) A1[i] *= A2[i];
    FFT(A1, s, -1);
    For(i, 0, s - 1) y[i] = (int)(A1[i].real() + 0.5);
    //求出每一位上的值,并赋予y[]
    For(i, 0, an - 1){//进位
        while(y[i] >= 10){
            y[i + 1] += y[i] / 10;
            y[i] %= 10;
            ++i;
        }
        if(i >= an) an = i + 1;
    }
    while(!y[an - 1]) --an;//去前导0
    if(!an) ++an;//特判结果为0的情况
    For(i, 1, an) printf("%d", y[an - i]);//输出
}

int main(){
#ifndef ONLINE_JUDGE
    freopen("pro.in", "r", stdin);
    freopen("pro.out","w",stdout);
#endif
    char c;
    scanf("%d", &n);
    For(i, 1, n) cin >> c, A1[n - i] = CP ((c ^ 48), 0);
    For(i, 1, n) cin >> c, A2[n - i] = CP ((c ^ 48), 0);
    //将高位置于后面
    solve();
    return 0;
}
点赞