Yuri3's Code Library

Some code templates, just for fun.

View the Project on GitHub Yuri3-xr/CP-library

✔ Verify/ConvolutionLarge.test.cpp

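Depends on: ModInt/Modint32.hpp, Polynomial/Ntt.hpp, Polynomial/NttLarge.hpp, Template/Power.hpp, Template/Template.hpp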

Code

#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_large"

#include "../ModInt/Modint32.hpp"
#include "../Polynomial/NttLarge.hpp"

int main() {
    // Bulk I/O: decouple the C++ streams from C stdio and stop flushing
    // std::cout before every read, otherwise large inputs are slow to parse.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    constexpr int P = 998244353;
    using Z = mint<P>;

    int n, m;
    std::cin >> n >> m;

    std::vector<Z> a(n), b(m);

    for (int i = 0; i < n; i++) std::cin >> a[i];
    for (int i = 0; i < m; i++) std::cin >> b[i];

    // Product has n + m - 1 coefficients.
    auto ans = multiplyLarge(a, b);

    // Space-separated, newline after the last coefficient.
    for (int i = 0; i < n + m - 1; i++) {
        std::cout << ans[i] << " \n"[i == n + m - 2];
    }

    return 0;
}
#line 1 "Verify/ConvolutionLarge.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_large"

#line 2 "ModInt/Modint32.hpp"

#line 2 "Template/Template.hpp"

#include <bits/stdc++.h>

using i64 = std::int64_t;
#line 4 "ModInt/Modint32.hpp"

template <int mod>
struct mint {
    // Canonical residue, always kept in [0, mod).
    int x;
    mint() : x(0) {}
    // Non-explicit so integer values convert implicitly (e.g. `mint<P> z = 1;`).
    // y % mod lies in (-mod, mod), so adding mod cannot overflow; unlike
    // negating y first, this is also well-defined for INT64_MIN.
    mint(int64_t y) : x(int((y % mod + mod) % mod)) {}
    mint &operator+=(const mint &p) {
        // x + p.x can reach 2 * mod - 2, which overflows int once mod > 2^30,
        // so add in 32-bit unsigned arithmetic (mod <= INT_MAX keeps it < 2^32).
        std::uint32_t s = std::uint32_t(x) + std::uint32_t(p.x);
        if (s >= std::uint32_t(mod)) s -= std::uint32_t(mod);
        x = int(s);
        return *this;
    }
    mint &operator-=(const mint &p) {
        // Both operands are in [0, mod), so x - p.x lies in (-mod, mod).
        x -= p.x;
        if (x < 0) x += mod;
        return *this;
    }
    mint &operator*=(const mint &p) {
        // Widen to 64 bits: the product of two residues can exceed int.
        x = (int)(1LL * x * p.x % mod);
        return *this;
    }
    mint &operator/=(const mint &p) {
        *this *= p.inverse();
        return *this;
    }
    mint operator-() const { return mint(-x); }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    bool operator==(const mint &p) const { return x == p.x; }
    bool operator!=(const mint &p) const { return x != p.x; }
    // Multiplicative inverse via the extended Euclidean algorithm.
    // Meaningful only when gcd(x, mod) == 1 (always true for prime mod and
    // x != 0); for x == 0 the loop yields 0.
    mint inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        // u may be negative; the int64_t constructor normalizes it.
        return mint(u);
    }
    friend std::ostream &operator<<(std::ostream &os, const mint &p) {
        return os << p.x;
    }
    // Reads a signed 64-bit integer and reduces it modulo mod.
    friend std::istream &operator>>(std::istream &is, mint &a) {
        int64_t t;
        is >> t;
        a = mint<mod>(t);
        return (is);
    }
    int get() const { return x; }
    static constexpr int get_mod() { return mod; }
};
#line 2 "Polynomial/Ntt.hpp"

#line 1 "Template/Power.hpp"
template <class T>
T power(T a, int b) {
    // Binary exponentiation: walk the exponent's bits from least significant,
    // multiplying the current power of `a` into the result where a bit is set.
    T result = 1;
    T base = a;
    int e = b;
    while (e != 0) {
        if (e % 2 != 0) result *= base;
        e /= 2;
        base *= base;
    }
    return result;
}
#line 5 "Polynomial/Ntt.hpp"
template <class Z>
struct NTT {
    // Cached bit-reversal permutation for the most recent transform length.
    std::vector<int> rev;
    // Twiddle table, grown lazily by dft(): for each power of two k,
    // roots[k + j] (0 <= j < k) equals w^j where w is a primitive (2k)-th
    // root of unity.
    std::vector<Z> roots{0, 1};

    // Smallest primitive root >= 2 of the modulus Z::get_mod(), found by
    // factoring mod - 1 and testing candidates 2, 3, ...; assumes the
    // modulus is prime.
    static constexpr int getRoot() {
        auto _mod = Z::get_mod();
        using u64 = uint64_t;
        u64 ds[32] = {};  // distinct prime factors of mod - 1
        int idx = 0;
        u64 m = _mod - 1;
        for (u64 i = 2; i * i <= m; ++i) {
            if (m % i == 0) {
                ds[idx++] = i;
                while (m % i == 0) m /= i;
            }
        }
        if (m != 1) ds[idx++] = m;

        int _pr = 2;
        for (;;) {
            // _pr generates the group iff _pr^((mod - 1) / q) != 1 for every
            // prime factor q of mod - 1.
            int flg = 1;
            for (int i = 0; i < idx; ++i) {
                u64 a = _pr, b = (_mod - 1) / ds[i], r = 1;
                for (; b; a = a * a % _mod, b /= 2) {
                    if (b % 2 == 1) r = r * a % _mod;
                }
                if (r == 1) {
                    flg = 0;
                    break;
                }
            }
            if (flg == 1) break;
            ++_pr;
        }
        return _pr;
    }

    static constexpr int rt = getRoot();

    // In-place forward transform. a.size() must be a power of two dividing
    // mod - 1 (not checked). The input is bit-reversal permuted first, so the
    // output comes out in natural order.
    void dft(std::vector<Z> &a) {
        int n = a.size();
        // Lengths 0 and 1 are the identity. Returning early also avoids
        // __builtin_ctz(0) (undefined) and a negative shift count below.
        if (n <= 1) return;

        if (int(rev.size()) != n) {
            int k = __builtin_ctz(n) - 1;
            rev.resize(n);
            for (int i = 0; i < n; i++) {
                rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
            }
        }

        for (int i = 0; i < n; i++) {
            if (rev[i] < i) {
                std::swap(a[i], a[rev[i]]);
            }
        }
        if (int(roots.size()) < n) {
            int k = __builtin_ctz(roots.size());
            roots.resize(n);
            while ((1 << k) < n) {
                // e is a primitive 2^(k+1)-th root of unity.
                Z e = power(Z(rt), (Z::get_mod() - 1) >> (k + 1));
                for (int i = 1 << (k - 1); i < (1 << k); i++) {
                    roots[2 * i] = roots[i];
                    roots[2 * i + 1] = roots[i] * e;
                }
                k++;
            }
        }
        // Iterative Cooley-Tukey butterflies over blocks of size 2k.
        for (int k = 1; k < n; k *= 2) {
            for (int i = 0; i < n; i += 2 * k) {
                for (int j = 0; j < k; j++) {
                    Z u = a[i + j];
                    Z v = a[i + j + k] * roots[k + j];
                    a[i + j] = u + v;
                    a[i + j + k] = u - v;
                }
            }
        }
    }

    // In-place inverse transform, including the 1/n scaling.
    void idft(std::vector<Z> &a) {
        int n = a.size();
        // Lengths 0 and 1 are the identity; this also keeps a.begin() + 1
        // within range.
        if (n <= 1) return;
        // Reversing a[1..n-1] turns the forward transform into the inverse one.
        std::reverse(a.begin() + 1, a.end());
        dft(a);
        // n divides mod - 1, so (1 - mod) / n is exact in integer arithmetic,
        // and n * inv == 1 - mod, i.e. inv is n^-1 modulo mod.
        Z inv = (1 - Z::get_mod()) / n;
        for (int i = 0; i < n; i++) {
            a[i] *= inv;
        }
    }

    // Linear convolution of a and b, of length a.size() + b.size() - 1, or
    // an empty vector if either operand is empty. The padded transform length
    // must not exceed the largest power of two dividing mod - 1 (not checked).
    std::vector<Z> multiply(std::vector<Z> a, std::vector<Z> b) {
        // Returning early also stops a.size() + b.size() - 1 from wrapping
        // around when both inputs are empty.
        if (a.empty() || b.empty()) return {};

        int sz = 1, tot = a.size() + b.size() - 1;

        // Small outputs use the direct O(a.size() * b.size()) product.
        if (tot <= 20) {
            std::vector<Z> ret(tot);
            for (size_t i = 0; i < a.size(); i++)
                for (size_t j = 0; j < b.size(); j++) ret[i + j] += a[i] * b[j];
            return ret;
        }

        // Smallest power of two that holds the whole product (no wrap-around).
        while (sz < tot) {
            sz *= 2;
        }

        a.resize(sz), b.resize(sz);
        dft(a), dft(b);

        for (int i = 0; i < sz; ++i) {
            a[i] *= b[i];
        }

        idft(a);
        a.resize(tot);
        return a;
    }
};
#line 2 "Polynomial/NttLarge.hpp"

template <class Z>
std::vector<Z> multiplyLarge(const std::vector<Z>& v1,
                             const std::vector<Z>& v2) {
    using FPS = std::vector<Z>;
    static NTT<Z> ntt;
    // Transform length. Inputs are cut into blocks of L/2 coefficients, so
    // the product of any two blocks (< L coefficients) fits without wrap-around.
    const std::size_t L = std::size_t(1) << 23;
    const std::size_t half = L / 2;

    // The convolution of an empty sequence is empty.
    if (v1.empty() || v2.empty()) return {};
    if (v1.size() + v2.size() - 1 <= L) return ntt.multiply(v1, v2);

    // Splits v into consecutive blocks of L/2 coefficients, zero-pads each
    // block to length L and transforms it.
    auto transform_blocks = [&](const FPS& v) {
        std::vector<FPS> blocks;
        blocks.reserve((v.size() + half - 1) / half);
        for (std::size_t i = 0; i < v.size(); i += half) {
            const auto first = v.begin() + static_cast<std::ptrdiff_t>(i);
            const auto last =
                v.begin() + static_cast<std::ptrdiff_t>(std::min(i + half, v.size()));
            FPS blk(first, last);
            blk.resize(L);
            ntt.dft(blk);
            // Move rather than copy: each block holds L elements.
            blocks.push_back(std::move(blk));
        }
        return blocks;
    };

    const std::vector<FPS> v1s = transform_blocks(v1);
    const std::vector<FPS> v2s = transform_blocks(v2);

    // cs[s] accumulates, in the transformed domain, the products of all block
    // pairs (x, y) with x + y == s; their coefficients start at s * L/2.
    std::vector<FPS> cs(v1s.size() + v2s.size() - 1, FPS(L));
    for (std::size_t x = 0; x < v1s.size(); ++x)
        for (std::size_t y = 0; y < v2s.size(); ++y)
            for (std::size_t i = 0; i < L; ++i) cs[x + y][i] += v1s[x][i] * v2s[y][i];

    for (FPS& c : cs) {
        ntt.idft(c);
    }

    // Overlap-add the partial products into the full result.
    FPS ret(v1.size() + v2.size() - 1);
    for (std::size_t i = 0; i < cs.size(); ++i) {
        const std::size_t offset = i * half;
        for (std::size_t j = 0; j < L && offset + j < ret.size(); ++j) {
            ret[offset + j] += cs[i][j];
        }
    }
    return ret;
}
#line 5 "Verify/ConvolutionLarge.test.cpp"

int main() {
    // Bulk I/O: decouple the C++ streams from C stdio and stop flushing
    // std::cout before every read, otherwise large inputs are slow to parse.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    constexpr int P = 998244353;
    using Z = mint<P>;

    int n, m;
    std::cin >> n >> m;

    std::vector<Z> a(n), b(m);

    for (int i = 0; i < n; i++) std::cin >> a[i];
    for (int i = 0; i < m; i++) std::cin >> b[i];

    // Product has n + m - 1 coefficients.
    auto ans = multiplyLarge(a, b);

    // Space-separated, newline after the last coefficient.
    for (int i = 0; i < n + m - 1; i++) {
        std::cout << ans[i] << " \n"[i == n + m - 2];
    }

    return 0;
}
Back to top page