Yuri3's Code Library

A collection of code templates, written just for fun.

View the Project on GitHub Yuri3-xr/CP-library

:heavy_check_mark: Polynomial/NttLarge.hpp

Depends on

Verified with

Code

#include "Ntt.hpp"

// Multiplies two polynomials (coefficient vectors, lowest degree first) whose
// product may be longer than a single transform of length L = 2^23.
//
// Each input is cut into chunks of L/2 coefficients, every chunk is
// zero-padded to length L and transformed once. The product of two chunks has
// at most L - 1 coefficients, so the length-L cyclic convolution never wraps.
// Chunk products that land on the same offset (x + y) are accumulated in the
// frequency domain, so only one inverse transform per output chunk is needed.
//
// Returns an empty vector when either input is empty; otherwise a vector of
// v1.size() + v2.size() - 1 coefficients.
template <class Z>
std::vector<Z> multiplyLarge(const std::vector<Z>& v1,
                             const std::vector<Z>& v2) {
    using FPS = std::vector<Z>;

    // Handle empty operands here: v1.size() + v2.size() - 1 would wrap around
    // for two empty inputs.
    if (v1.empty() || v2.empty()) return {};

    static NTT<Z> ntt;
    constexpr std::size_t L = std::size_t(1) << 23;
    constexpr std::size_t half = L / 2;
    if (v1.size() + v2.size() - 1 <= L) return ntt.multiply(v1, v2);

    // Splits v into chunks of `half` coefficients and returns the length-L
    // transform of each chunk.
    const auto transformChunks = [&](const FPS& v) {
        std::vector<FPS> chunks;
        chunks.reserve((v.size() + half - 1) / half);
        for (std::size_t i = 0; i < v.size(); i += half) {
            const std::size_t end = std::min(i + half, v.size());
            FPS nxt(v.begin() + i, v.begin() + end);
            nxt.resize(L);
            ntt.dft(nxt);
            // Move instead of copy: every chunk holds 2^23 elements.
            chunks.push_back(std::move(nxt));
        }
        return chunks;
    };

    const std::vector<FPS> v1s = transformChunks(v1);
    const std::vector<FPS> v2s = transformChunks(v2);

    // cs[k] accumulates (in the frequency domain) every chunk product whose
    // result starts at coefficient offset k * half.
    std::vector<FPS> cs(v1s.size() + v2s.size() - 1, FPS(L));
    for (std::size_t x = 0; x < v1s.size(); ++x) {
        for (std::size_t y = 0; y < v2s.size(); ++y) {
            FPS& acc = cs[x + y];
            const FPS& fa = v1s[x];
            const FPS& fb = v2s[y];
            for (std::size_t i = 0; i < L; ++i) acc[i] += fa[i] * fb[i];
        }
    }

    for (FPS& c : cs) ntt.idft(c);

    // Overlap-add: consecutive output chunks overlap by `half` coefficients.
    FPS ret(v1.size() + v2.size() - 1);
    for (std::size_t i = 0; i < cs.size(); ++i) {
        // size_t arithmetic: the original int expression i * L could overflow.
        const std::size_t base = i * half;
        for (std::size_t j = 0; j < L && base + j < ret.size(); ++j) {
            ret[base + j] += cs[i][j];
        }
    }
    return ret;
}
#line 2 "Polynomial/Ntt.hpp"

#line 1 "Template/Power.hpp"
// Computes a raised to the b-th power by binary exponentiation, using only
// T's `*=` and construction from the integer 1.
//
// b is expected to be non-negative; power(a, 0) returns T(1).
// The base is squared only while exponent bits remain, so no multiplication
// is performed beyond what the result needs (the former trailing a *= a could
// overflow for built-in signed integer types even when the result fits).
template <class T>
T power(T a, int b) {
    T res = 1;
    while (b != 0) {
        if (b % 2 != 0) {
            res *= a;
        }
        b /= 2;
        if (b != 0) {
            a *= a;
        }
    }
    return res;
}
#line 2 "Template/Template.hpp"

#include <bits/stdc++.h>

using i64 = std::int64_t;
#line 5 "Polynomial/Ntt.hpp"
// Number-theoretic transform over a modular-integer type Z.
//
// Requirements on Z that this struct uses: a static get_mod(), construction
// from integers, and the operators +, -, *, += and *=.
// Transform lengths must be powers of two.
// NOTE(review): presumably the modulus is an NTT-friendly prime (mod - 1
// divisible by the transform length) -- confirm against the Z types in use.
template <class Z>
struct NTT {
    // Bit-reversal permutation cached for the most recent transform length.
    std::vector<int> rev;
    // Twiddle table: for a butterfly level with half-size k, roots[k + j] is
    // the j-th power of the root of unity used at that level. Grown lazily.
    std::vector<Z> roots{0, 1};

    // Returns the smallest g >= 2 such that g^((mod - 1) / p) != 1 for every
    // prime factor p of mod - 1 (a primitive root when mod is prime).
    static constexpr int getRoot() {
        auto _mod = Z::get_mod();
        using u64 = uint64_t;
        u64 ds[32] = {};  // distinct prime factors of mod - 1
        int idx = 0;
        u64 m = _mod - 1;
        for (u64 i = 2; i * i <= m; ++i) {
            if (m % i == 0) {
                ds[idx++] = i;
                while (m % i == 0) m /= i;
            }
        }
        if (m != 1) ds[idx++] = m;

        int _pr = 2;
        for (;;) {
            int flg = 1;
            for (int i = 0; i < idx; ++i) {
                // r = _pr ^ ((mod - 1) / ds[i]) by binary exponentiation.
                u64 a = _pr, b = (_mod - 1) / ds[i], r = 1;
                for (; b; a = a * a % _mod, b /= 2) {
                    if (b % 2 == 1) r = r * a % _mod;
                }
                if (r == 1) {
                    flg = 0;
                    break;
                }
            }
            if (flg == 1) break;
            ++_pr;
        }
        return _pr;
    }

    static constexpr int rt = getRoot();

    // In-place forward transform of a; a.size() must be a power of two.
    void dft(std::vector<Z> &a) {
        const int n = static_cast<int>(a.size());
        // Length 0 or 1 is the identity. Without this guard n == 1 would
        // shift by -1 below and n == 0 would call __builtin_ctz(0), both
        // undefined.
        if (n <= 1) return;

        if (int(rev.size()) != n) {
            const int k = __builtin_ctz(n) - 1;
            rev.resize(n);
            for (int i = 0; i < n; i++) {
                rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
            }
        }

        for (int i = 0; i < n; i++) {
            if (rev[i] < i) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        if (int(roots.size()) < n) {
            int k = __builtin_ctz(roots.size());
            roots.resize(n);
            while ((1 << k) < n) {
                // e = rt ^ ((mod - 1) / 2^(k + 1)); extends the table by one
                // level from the entries of the previous level.
                Z e = power(Z(rt), (Z::get_mod() - 1) >> (k + 1));
                for (int i = 1 << (k - 1); i < (1 << k); i++) {
                    roots[2 * i] = roots[i];
                    roots[2 * i + 1] = roots[i] * e;
                }
                k++;
            }
        }

        // Butterflies over blocks of size 2k, on bit-reversed input.
        for (int k = 1; k < n; k *= 2) {
            for (int i = 0; i < n; i += 2 * k) {
                for (int j = 0; j < k; j++) {
                    Z u = a[i + j];
                    Z v = a[i + j + k] * roots[k + j];
                    a[i + j] = u + v;
                    a[i + j + k] = u - v;
                }
            }
        }
    }

    // In-place inverse transform of a; a.size() must be a power of two.
    void idft(std::vector<Z> &a) {
        const int n = static_cast<int>(a.size());
        // Nothing to do for an empty vector; also avoids begin() + 1 on an
        // empty range and the division by n == 0 below.
        if (n == 0) return;

        // Reversing a[1..n) and running the forward transform evaluates at
        // the inverse roots.
        std::reverse(a.begin() + 1, a.end());
        dft(a);
        // (1 - mod) / n is an exact integer when n divides mod - 1, and
        // n * ((1 - mod) / n) = 1 - mod, which is 1 modulo mod: so this is
        // the inverse of n.
        // NOTE(review): relies on Z::get_mod() returning a signed type --
        // confirm.
        Z inv = (1 - Z::get_mod()) / n;
        for (int i = 0; i < n; i++) {
            a[i] *= inv;
        }
    }

    // Returns the product of polynomials a and b (coefficients, lowest degree
    // first). An empty operand yields an empty result; otherwise the result
    // has a.size() + b.size() - 1 coefficients.
    std::vector<Z> multiply(std::vector<Z> a, std::vector<Z> b) {
        // With an empty operand, a.size() + b.size() - 1 is not a valid
        // result length (it is -1 when both are empty).
        if (a.empty() || b.empty()) return {};

        const int tot = static_cast<int>(a.size() + b.size()) - 1;

        // Schoolbook multiplication is cheaper for tiny products.
        if (tot <= 20) {
            std::vector<Z> ret(tot);
            for (size_t i = 0; i < a.size(); i++)
                for (size_t j = 0; j < b.size(); j++) ret[i + j] += a[i] * b[j];
            return ret;
        }

        int sz = 1;
        while (sz < tot) {
            sz *= 2;
        }

        a.resize(sz), b.resize(sz);
        dft(a), dft(b);

        for (int i = 0; i < sz; ++i) {
            a[i] = a[i] * b[i];
        }

        idft(a);
        a.resize(tot);
        return a;
    }
};
#line 2 "Polynomial/NttLarge.hpp"

// Multiplies two polynomials (coefficient vectors, lowest degree first) whose
// product may be longer than a single transform of length L = 2^23.
//
// Each input is cut into chunks of L/2 coefficients, every chunk is
// zero-padded to length L and transformed once. The product of two chunks has
// at most L - 1 coefficients, so the length-L cyclic convolution never wraps.
// Chunk products that land on the same offset (x + y) are accumulated in the
// frequency domain, so only one inverse transform per output chunk is needed.
//
// Returns an empty vector when either input is empty; otherwise a vector of
// v1.size() + v2.size() - 1 coefficients.
template <class Z>
std::vector<Z> multiplyLarge(const std::vector<Z>& v1,
                             const std::vector<Z>& v2) {
    using FPS = std::vector<Z>;

    // Handle empty operands here: v1.size() + v2.size() - 1 would wrap around
    // for two empty inputs.
    if (v1.empty() || v2.empty()) return {};

    static NTT<Z> ntt;
    constexpr std::size_t L = std::size_t(1) << 23;
    constexpr std::size_t half = L / 2;
    if (v1.size() + v2.size() - 1 <= L) return ntt.multiply(v1, v2);

    // Splits v into chunks of `half` coefficients and returns the length-L
    // transform of each chunk.
    const auto transformChunks = [&](const FPS& v) {
        std::vector<FPS> chunks;
        chunks.reserve((v.size() + half - 1) / half);
        for (std::size_t i = 0; i < v.size(); i += half) {
            const std::size_t end = std::min(i + half, v.size());
            FPS nxt(v.begin() + i, v.begin() + end);
            nxt.resize(L);
            ntt.dft(nxt);
            // Move instead of copy: every chunk holds 2^23 elements.
            chunks.push_back(std::move(nxt));
        }
        return chunks;
    };

    const std::vector<FPS> v1s = transformChunks(v1);
    const std::vector<FPS> v2s = transformChunks(v2);

    // cs[k] accumulates (in the frequency domain) every chunk product whose
    // result starts at coefficient offset k * half.
    std::vector<FPS> cs(v1s.size() + v2s.size() - 1, FPS(L));
    for (std::size_t x = 0; x < v1s.size(); ++x) {
        for (std::size_t y = 0; y < v2s.size(); ++y) {
            FPS& acc = cs[x + y];
            const FPS& fa = v1s[x];
            const FPS& fb = v2s[y];
            for (std::size_t i = 0; i < L; ++i) acc[i] += fa[i] * fb[i];
        }
    }

    for (FPS& c : cs) ntt.idft(c);

    // Overlap-add: consecutive output chunks overlap by `half` coefficients.
    FPS ret(v1.size() + v2.size() - 1);
    for (std::size_t i = 0; i < cs.size(); ++i) {
        // size_t arithmetic: the original int expression i * L could overflow.
        const std::size_t base = i * half;
        for (std::size_t j = 0; j < L && base + j < ret.size(); ++j) {
            ret[base + j] += cs[i][j];
        }
    }
    return ret;
}
Back to top page