Some Code Template Just for Fun.
#include "Polynomial/NttLarge.hpp"
#include "Ntt.hpp"
/// Multiplies two coefficient vectors whose product may be too long for a
/// single transform of length 2^23.
///
/// Products with at most 2^23 coefficients are delegated to
/// `NTT<Z>::multiply`. Longer products are computed block-wise: each input is
/// cut into chunks of 2^22 coefficients, every chunk is zero-padded to 2^23
/// and transformed once, chunk pairs with the same index sum are multiplied
/// pointwise and accumulated in the frequency domain, and each accumulated
/// block is inverse-transformed and added into the result at a stride of 2^22.
///
/// A function-local static `NTT<Z>` is shared by all calls for the same `Z`;
/// calls use that shared object without any synchronization.
///
/// @tparam Z  Modular integer type accepted by `NTT<Z>`.
/// @param v1  Coefficients of the first polynomial (lowest degree first).
/// @param v2  Coefficients of the second polynomial (lowest degree first).
/// @return    The `v1.size() + v2.size() - 1` product coefficients, or an
///            empty vector when either input is empty.
template <class Z>
std::vector<Z> multiplyLarge(const std::vector<Z>& v1,
                             const std::vector<Z>& v2) {
  using FPS = std::vector<Z>;

  // An empty operand is the zero polynomial. Answer directly instead of
  // letting `size() + size() - 1` wrap around further down.
  if (v1.empty() || v2.empty()) return {};

  static NTT<Z> ntt;
  static constexpr std::size_t L = std::size_t{1} << 23;  // transform length
  static constexpr std::size_t H = L / 2;                 // chunk length

  const std::size_t total = v1.size() + v2.size() - 1;
  if (total <= L) return ntt.multiply(v1, v2);

  // Cuts `src` into chunks of H coefficients, zero-pads each to L and
  // transforms it. A chunk pair's product has fewer than L coefficients, so
  // the cyclic convolution of two padded chunks never wraps around.
  const auto transformChunks = [](const FPS& src) {
    std::vector<FPS> out;
    out.reserve((src.size() + H - 1) / H);
    for (std::size_t i = 0; i < src.size(); i += H) {
      const std::size_t len = std::min(H, src.size() - i);
      FPS chunk(src.begin() + i, src.begin() + i + len);
      chunk.resize(L);
      ntt.dft(chunk);
      out.push_back(std::move(chunk));  // move: each chunk holds 2^23 values
    }
    return out;
  };
  const std::vector<FPS> v1s = transformChunks(v1);
  const std::vector<FPS> v2s = transformChunks(v2);

  // Accumulate in the frequency domain so that only one inverse transform is
  // needed per output block rather than one per chunk pair.
  std::vector<FPS> cs(v1s.size() + v2s.size() - 1, FPS(L));
  for (std::size_t x = 0; x < v1s.size(); ++x) {
    for (std::size_t y = 0; y < v2s.size(); ++y) {
      FPS& acc = cs[x + y];
      for (std::size_t i = 0; i < L; ++i) acc[i] += v1s[x][i] * v2s[y][i];
    }
  }
  for (FPS& block : cs) ntt.idft(block);

  // Overlap-add: block i starts at coefficient i * H of the result.
  FPS ret(total);
  for (std::size_t i = 0; i < cs.size(); ++i) {
    const std::size_t base = i * H;
    for (std::size_t j = 0; j < L && base + j < total; ++j) {
      ret[base + j] += cs[i][j];
    }
  }
  return ret;
}
#line 2 "Polynomial/Ntt.hpp"
#line 1 "Template/Power.hpp"
/// Raises `a` to the power `b` by binary exponentiation.
///
/// @tparam T  Type that can be initialised from `1` and supports `*=`.
/// @param a   Base.
/// @param b   Exponent, intended to be non-negative. A negative value is not
///            treated as an inverse: the loop halves toward zero, so it is
///            processed like its magnitude.
/// @return    The product of `b` copies of `a`; `T(1)` when `b == 0`.
template <class T>
T power(T a, int b) {
  T res = 1;
  while (b) {
    if (b % 2) {
      res *= a;
    }
    b /= 2;
    // Square only while exponent bits remain. The squaring after the last bit
    // is never read, and for built-in signed T it could overflow even though
    // the result itself is representable.
    if (b) {
      a *= a;
    }
  }
  return res;
}
#line 2 "Template/Template.hpp"
#include <bits/stdc++.h>
using i64 = std::int64_t;
#line 5 "Polynomial/Ntt.hpp"
/// Number-theoretic transform and convolution over the residues modelled by
/// `Z`.
///
/// `Z` must expose a static `get_mod()` usable in constant expressions, be
/// constructible from integers, and support `+`, `-`, `*`, `+=` and `*=`.
/// Transform lengths must be powers of two; the twiddle factors are only
/// roots of unity when the length divides `get_mod() - 1`.
///
/// The object caches the bit-reversal table and the twiddle factors between
/// calls, so `dft`, `idft` and `multiply` all mutate it.
template <class Z>
struct NTT {
  /// Bit-reversal permutation for the most recently used transform length.
  std::vector<int> rev;
  /// Twiddle factors: for a power of two `k`, `roots[k + j]` (0 <= j < k) is
  /// the j-th power of a primitive 2k-th root of unity. Index 0 is unused.
  std::vector<Z> roots{0, 1};

  /// Finds the smallest integer >= 2 that generates the multiplicative group
  /// modulo `Z::get_mod()`. The search assumes the modulus is prime.
  static constexpr int getRoot() {
    auto _mod = Z::get_mod();
    using u64 = uint64_t;
    // Distinct prime factors of _mod - 1, collected by trial division.
    u64 ds[32] = {};
    int idx = 0;
    u64 m = _mod - 1;
    for (u64 i = 2; i * i <= m; ++i) {
      if (m % i == 0) {
        ds[idx++] = i;
        while (m % i == 0) m /= i;
      }
    }
    if (m != 1) ds[idx++] = m;
    // A candidate is rejected as soon as candidate^((mod - 1) / q) == 1 for
    // some prime factor q, because its order is then a proper divisor.
    int _pr = 2;
    for (;;) {
      int flg = 1;
      for (int i = 0; i < idx; ++i) {
        u64 a = _pr, b = (_mod - 1) / ds[i], r = 1;
        for (; b; a = a * a % _mod, b /= 2) {
          if (b % 2 == 1) r = r * a % _mod;
        }
        if (r == 1) {
          flg = 0;
          break;
        }
      }
      if (flg == 1) break;
      ++_pr;
    }
    return _pr;
  }

  /// Generator found by `getRoot()`, evaluated at compile time.
  static constexpr int rt = getRoot();

  /// Forward transform of `a`, in place.
  ///
  /// @param a  Values to transform; `a.size()` must be a power of two.
  ///           Lengths 0 and 1 are returned unchanged.
  void dft(std::vector<Z> &a) {
    int n = a.size();
    // A transform of length 0 or 1 is the identity. Returning here also keeps
    // __builtin_ctz(0) and a shift by -1 out of the table setup below.
    if (n <= 1) return;
    if (int(rev.size()) != n) {
      int k = __builtin_ctz(n) - 1;
      rev.resize(n);
      for (int i = 0; i < n; i++) {
        rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
      }
    }
    // Apply the bit-reversal permutation, swapping each pair once.
    for (int i = 0; i < n; i++) {
      if (rev[i] < i) {
        std::swap(a[i], a[rev[i]]);
      }
    }
    // Grow the twiddle table one level at a time until it covers length n.
    if (int(roots.size()) < n) {
      int k = __builtin_ctz(roots.size());
      roots.resize(n);
      while ((1 << k) < n) {
        // e is a primitive 2^(k+1)-th root of unity.
        Z e = power(Z(rt), (Z::get_mod() - 1) >> (k + 1));
        for (int i = 1 << (k - 1); i < (1 << k); i++) {
          roots[2 * i] = roots[i];
          roots[2 * i + 1] = roots[i] * e;
        }
        k++;
      }
    }
    // Butterfly passes over blocks of width 2k.
    for (int k = 1; k < n; k *= 2) {
      for (int i = 0; i < n; i += 2 * k) {
        for (int j = 0; j < k; j++) {
          Z u = a[i + j];
          Z v = a[i + j + k] * roots[k + j];
          a[i + j] = u + v;
          a[i + j + k] = u - v;
        }
      }
    }
  }

  /// Inverse transform of `a`, in place.
  ///
  /// Reverses `a[1..n)`, runs the forward transform and scales by 1/n.
  ///
  /// @param a  Values to transform; `a.size()` must be a power of two that
  ///           divides `Z::get_mod() - 1`. An empty vector is left untouched.
  void idft(std::vector<Z> &a) {
    int n = a.size();
    // Nothing to invert; this also avoids `begin() + 1` on an empty vector
    // and the division by n below.
    if (n == 0) return;
    std::reverse(a.begin() + 1, a.end());
    dft(a);
    // (1 - mod) / n is an exact integer when n divides mod - 1, and
    // n * ((1 - mod) / n) = 1 - mod, which is congruent to 1, so this is 1/n.
    // NOTE(review): relies on get_mod() returning a signed type and on Z
    // accepting a negative integer -- confirm against Z's definition.
    Z inv = (1 - Z::get_mod()) / n;
    for (int i = 0; i < n; i++) {
      a[i] *= inv;
    }
  }

  /// Multiplies two polynomials given as coefficient vectors.
  ///
  /// The operands are taken by value because they are padded and transformed
  /// in place.
  ///
  /// @param a  Coefficients of the first polynomial (lowest degree first).
  /// @param b  Coefficients of the second polynomial (lowest degree first).
  /// @return   The `a.size() + b.size() - 1` product coefficients, or an empty
  ///           vector when either operand is empty.
  std::vector<Z> multiply(std::vector<Z> a, std::vector<Z> b) {
    // Without this guard two empty operands give tot == -1, and the vector
    // below would be asked for size_t(-1) elements.
    if (a.empty() || b.empty()) return {};
    int sz = 1, tot = a.size() + b.size() - 1;
    // Schoolbook multiplication for short results.
    if (tot <= 20) {
      std::vector<Z> ret(tot);
      for (size_t i = 0; i < a.size(); i++)
        for (size_t j = 0; j < b.size(); j++) ret[i + j] += a[i] * b[j];
      return ret;
    }
    // Smallest power of two that holds the whole product.
    while (sz < tot) {
      sz *= 2;
    }
    a.resize(sz), b.resize(sz);
    dft(a), dft(b);
    for (int i = 0; i < sz; ++i) {
      a[i] = a[i] * b[i];
    }
    idft(a);
    a.resize(tot);
    return a;
  }
};
#line 2 "Polynomial/NttLarge.hpp"
/// Multiplies two coefficient vectors whose product may be too long for a
/// single transform of length 2^23.
///
/// Products with at most 2^23 coefficients are delegated to
/// `NTT<Z>::multiply`. Longer products are computed block-wise: each input is
/// cut into chunks of 2^22 coefficients, every chunk is zero-padded to 2^23
/// and transformed once, chunk pairs with the same index sum are multiplied
/// pointwise and accumulated in the frequency domain, and each accumulated
/// block is inverse-transformed and added into the result at a stride of 2^22.
///
/// A function-local static `NTT<Z>` is shared by all calls for the same `Z`;
/// calls use that shared object without any synchronization.
///
/// @tparam Z  Modular integer type accepted by `NTT<Z>`.
/// @param v1  Coefficients of the first polynomial (lowest degree first).
/// @param v2  Coefficients of the second polynomial (lowest degree first).
/// @return    The `v1.size() + v2.size() - 1` product coefficients, or an
///            empty vector when either input is empty.
template <class Z>
std::vector<Z> multiplyLarge(const std::vector<Z>& v1,
                             const std::vector<Z>& v2) {
  using FPS = std::vector<Z>;

  // An empty operand is the zero polynomial. Answer directly instead of
  // letting `size() + size() - 1` wrap around further down.
  if (v1.empty() || v2.empty()) return {};

  static NTT<Z> ntt;
  static constexpr std::size_t L = std::size_t{1} << 23;  // transform length
  static constexpr std::size_t H = L / 2;                 // chunk length

  const std::size_t total = v1.size() + v2.size() - 1;
  if (total <= L) return ntt.multiply(v1, v2);

  // Cuts `src` into chunks of H coefficients, zero-pads each to L and
  // transforms it. A chunk pair's product has fewer than L coefficients, so
  // the cyclic convolution of two padded chunks never wraps around.
  const auto transformChunks = [](const FPS& src) {
    std::vector<FPS> out;
    out.reserve((src.size() + H - 1) / H);
    for (std::size_t i = 0; i < src.size(); i += H) {
      const std::size_t len = std::min(H, src.size() - i);
      FPS chunk(src.begin() + i, src.begin() + i + len);
      chunk.resize(L);
      ntt.dft(chunk);
      out.push_back(std::move(chunk));  // move: each chunk holds 2^23 values
    }
    return out;
  };
  const std::vector<FPS> v1s = transformChunks(v1);
  const std::vector<FPS> v2s = transformChunks(v2);

  // Accumulate in the frequency domain so that only one inverse transform is
  // needed per output block rather than one per chunk pair.
  std::vector<FPS> cs(v1s.size() + v2s.size() - 1, FPS(L));
  for (std::size_t x = 0; x < v1s.size(); ++x) {
    for (std::size_t y = 0; y < v2s.size(); ++y) {
      FPS& acc = cs[x + y];
      for (std::size_t i = 0; i < L; ++i) acc[i] += v1s[x][i] * v2s[y][i];
    }
  }
  for (FPS& block : cs) ntt.idft(block);

  // Overlap-add: block i starts at coefficient i * H of the result.
  FPS ret(total);
  for (std::size_t i = 0; i < cs.size(); ++i) {
    const std::size_t base = i * H;
    for (std::size_t j = 0; j < L && base + j < total; ++j) {
      ret[base + j] += cs[i][j];
    }
  }
  return ret;
}