FFT

【模板】多项式乘法(FFT)

#include <bits/stdc++.h>
#define fo(i, a, b) for (int i = a; i <= b; i++)
using namespace std;
typedef complex<double> cp;

const int N = 4e6 + 5;
const double pi = acos(-1);

cp a[N], b[N];

int n, m, limit, rev[N];

void fft(cp *a, int n, int inv)
{
    int bit = 0;
    while ((1 << bit) < n) bit++;
    for (int i = 0; i < n; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
        if (i < rev[i]) swap(a[i], a[rev[i]]);//不加这条if会交换两次(就是没交换)
    }
    for (int mid = 1; mid < n; mid *= 2) {//mid是准备合并序列的长度的二分之一
        cp temp(cos(pi / mid), inv * sin(pi / mid));//单位根,pi的系数2已经约掉了
        for (int i=0;i<n;i+=mid*2)//mid*2是准备合并序列的长度,i是合并到了哪一位
        {
            cp omega(1,0);
            for (int j=0;j<mid;j++,omega*=temp)//只扫左半部分,得到右半部分的答案
            {
                cp x=a[i+j],y=omega*a[i+j+mid];
                a[i+j]=x+y,a[i+j+mid]=x-y;
            }
        }
    }
}

int main() {
    scanf("%d %d", &n, &m);
    limit = 1; while(limit < n + m + 1) limit <<= 1;
    for (int i = 0; i <= n; i++) {
        int x; scanf("%d", &x);
        a[i].real(x);
    }
    for (int i = 0; i <= m; i++) {
        int x; scanf("%d", &x);
        b[i].real(x);
    }
    fft(a, limit, 1); fft(b, limit, 1);
    for (int i = 0; i <= limit; i++) a[i] *= b[i];
    fft(a, limit, -1);
    for (int i = 0; i <= n + m; i++) printf("%d ", (int)(a[i].real() / limit + 0.5));
    return 0;
}