Yuri3's Code Library

Some code templates, just for fun.

View the Project on GitHub Yuri3-xr/CP-library

✔ Verify/SumofTotientFunction2.test.cpp

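Depends on

- ModInt/Modint32.hpp
- Number_Theory/MultiplicativeFunctionPSumFast.hpp
- Number_Theory/Prime_Sieve.hpp
- Polynomial/LagrangeInterpolation.hpp
- Template/Power.hpp
- Template/Template.hpp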

Code

#define PROBLEM "https://judge.yosupo.jp/problem/sum_of_totient_function"

#include "../ModInt/Modint32.hpp"
#include "../Number_Theory/MultiplicativeFunctionPSumFast.hpp"
constexpr int mod = 998244353;
using Z = mint<mod>;

// Euler's totient at a prime power (expects c >= 1):
//   phi(p^c) = p^c - p^(c-1) = p^(c-1) * (p - 1).
// Evaluated entirely in Z, so no intermediate i64 power is formed and nothing
// can overflow regardless of how large p^c is.
Z f(i64 p, i64 c) { return power(Z(p), static_cast<int>(c - 1)) * Z(p - 1); }
int main() {
    i64 n;
    std::cin >> n;

    MfPrefixSum<Z, f> mf(n);

    auto psum = mf.prime_sum_table(1);
    auto ptable = mf.pi_table();
    for (int i = 0; i < int(psum.size()); i++) {
        psum[i] = psum[i] - ptable[i];
    }

    auto ans = mf.run(psum);

    std::cout << ans << std::endl;

    return 0;
}
#line 1 "Verify/SumofTotientFunction2.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/sum_of_totient_function"

#line 2 "ModInt/Modint32.hpp"

#line 2 "Template/Template.hpp"

#include <bits/stdc++.h>

using i64 = std::int64_t;
#line 4 "ModInt/Modint32.hpp"

template <int mod>
struct mint {
    // operator+= / operator-= form intermediate values up to 2 * mod - 1 in an
    // int before reducing, so mod must be small enough for that to fit.
    static_assert(0 < mod && mod <= (1 << 30),
                  "mint: mod must lie in [1, 2^30] so that 2 * mod - 1 fits in int");
    int x;  // canonical residue, always in [0, mod)
    mint() : x(0) {}
    // Reduces any int64_t (INT64_MIN included) into [0, mod). Built-in '%'
    // truncates toward zero, so y % mod lies in (-mod, mod); y itself is never
    // negated, because -INT64_MIN overflows.
    mint(int64_t y) : x(static_cast<int>(y % mod < 0 ? y % mod + mod : y % mod)) {}
    mint &operator+=(const mint &p) {
        if ((x += p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator-=(const mint &p) {
        if ((x += mod - p.x) >= mod) x -= mod;
        return *this;
    }
    mint &operator*=(const mint &p) {
        x = (int)(1LL * x * p.x % mod);
        return *this;
    }
    // Multiplies by p.inverse(); p must be invertible modulo mod.
    mint &operator/=(const mint &p) {
        *this *= p.inverse();
        return *this;
    }
    mint operator-() const { return mint(-x); }
    mint operator+(const mint &p) const { return mint(*this) += p; }
    mint operator-(const mint &p) const { return mint(*this) -= p; }
    mint operator*(const mint &p) const { return mint(*this) *= p; }
    mint operator/(const mint &p) const { return mint(*this) /= p; }
    bool operator==(const mint &p) const { return x == p.x; }
    bool operator!=(const mint &p) const { return x != p.x; }
    // Modular inverse via the extended Euclidean algorithm. Only meaningful when
    // gcd(x, mod) == 1 (e.g. x != 0 for a prime mod); for x == 0 it yields 0.
    mint inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        return mint(u);
    }
    friend std::ostream &operator<<(std::ostream &os, const mint &p) {
        return os << p.x;
    }
    // Reads a signed 64-bit integer and reduces it modulo mod.
    friend std::istream &operator>>(std::istream &is, mint &a) {
        int64_t t;
        is >> t;
        a = mint<mod>(t);
        return (is);
    }
    int get() const { return x; }
    static constexpr int get_mod() { return mod; }
};
#line 2 "Polynomial/LagrangeInterpolation.hpp"

#line 4 "Polynomial/LagrangeInterpolation.hpp"
template <typename T>
T LagrangeInterpolation(const std::vector<T> &y, i64 x) {
    // Evaluates at x the polynomial P of degree <= N through the samples
    // (0, y[0]), (1, y[1]), ..., (N, y[N]), where N = y.size() - 1.
    // Uses O(N) multiplications and a single inverse(); T must provide
    // inverse(), and 1..N must be invertible in T (e.g. mint with N < mod).
    int N = (int)y.size() - 1;
    if (N < 0) return T(0);             // no samples: the empty interpolant is the zero polynomial
    if (0 <= x && x <= N) return y[x];  // x is a sample point (negative x must not index y)
    T ret = 0;
    std::vector<T> dp(N + 1, 1), pd(N + 1, 1), finv(N + 1, 0);
    T a = x, one = 1;
    // finv[i] = 1 / i!
    finv[N] = T(1);
    for (int i = 1; i <= N; i++) finv[N] *= T(i);
    finv[N] = finv[N].inverse();
    for (int i = N - 1; i >= 0; i--) finv[i] = finv[i + 1] * T(i + 1);
    // dp[i] = prod_{j < i} (x - j),  pd[i] = prod_{j > i} (x - j)
    for (int i = 0; i < N; i++) dp[i + 1] = dp[i] * a, a -= one;
    for (int i = N; i > 0; i--) pd[i - 1] = pd[i] * a, a += one;
    // Denominator prod_{j != i} (i - j) = (-1)^(N - i) * i! * (N - i)!.
    for (int i = 0; i <= N; i++) {
        T tmp = y[i] * dp[i] * pd[i] * finv[i] * finv[N - i];
        ret += ((N - i) & 1) ? -tmp : tmp;
    }
    return ret;
}
#line 1 "Template/Power.hpp"
template <class T>
T power(T a, int b) {
    // Returns a^b by binary exponentiation (O(log b) multiplications).
    // Intended for b >= 0 (b == 0 gives T(1)); a negative b is not meaningful
    // and yields a^|b|.
    // The base is squared only while exponent bits remain: squaring after the
    // last bit would form an unused a^(2^k) > a^b, which for built-in signed T
    // can overflow (undefined behavior) even when a^b itself fits.
    T res = 1;
    while (b != 0) {
        if (b % 2 != 0) res *= a;
        b /= 2;
        if (b != 0) a *= a;
    }
    return res;
}
#line 2 "Number_Theory/Prime_Sieve.hpp"

#line 4 "Number_Theory/Prime_Sieve.hpp"

std::vector<int> prime_sieve(int N) {
    // Returns all primes <= N in increasing order.
    // Only candidates coprime to 6 (5, 7, 11, 13, ...) are stored; candidate c
    // lives at slot c / 3, and alive[slot] == false marks a composite.
    std::vector<bool> alive(N / 3 + 1, 1);
    const int slots = alive.size();
    const int root = sqrt(N);

    int gap = 4;  // distance from the previous candidate: 4 for 6k-1, 2 for 6k+1
    int cand = 5;
    for (int slot = 1; cand <= root; ++slot) {
        if (alive[slot]) {
            // Multiples of cand that are coprime to 6 alternate between two slot
            // strides summing to 2 * cand; striking starts at cand^2.
            const int period = 2 * cand;
            int stride = gap * cand / 3 + (gap * cand % 3 == 2 ? 1 : 0);
            int q = cand * cand / 3;
            while (q < slots) {
                alive[q] = 0;
                stride = period - stride;
                q += stride;
            }
        }
        gap = 6 - gap;
        cand += gap;
    }

    std::vector<int> primes = {2, 3};
    for (int c = 5, g = 4; c <= N; g = 6 - g, c += g) {
        if (alive[c / 3]) primes.push_back(c);
    }
    // 2 and 3 were seeded unconditionally; drop them when N is smaller.
    while (!primes.empty() && primes.back() > N) primes.pop_back();
    return primes;
}
#line 6 "Number_Theory/MultiplicativeFunctionPSumFast.hpp"

template <typename T, T (*f)(i64, i64)>
struct MfPrefixSum {
    i64 M, sq, s;
    std::vector<int> p;
    int ps;
    std::vector<T> buf;
    T ans;

    MfPrefixSum(i64 m) : M(m) {
        assert(m < (1LL << 42));
        sq = sqrt(M);
        while (sq * sq > M) sq--;
        while ((sq + 1) * (sq + 1) <= M) sq++;

        if (M != 0) {
            i64 hls = md(M, sq);
            if (hls != 1 && md(M, hls - 1) == sq) hls--;
            s = hls + sq;

            p = prime_sieve(sq);
            ps = p.size();
            ans = T{};
        }
    }
    T PSumPower(i64 n, int k) {
        std::vector<T> now(k + 2);
        now[0] = T(0);
        for (int i = 1; i < k + 2; i++) {
            T res = i;
            now[i] = now[i - 1] + power(res, k);
        }
        return LagrangeInterpolation<T>(now, n);
    }
    std::vector<T> pi_table() {
        //\sum_{p\in prime \and p\leq m} p^0
        if (M == 0) return {};
        i64 hls = md(M, sq);
        if (hls != 1 && md(M, hls - 1) == sq) hls--;

        std::vector<i64> hl(hls);
        for (int i = 1; i < hls; i++) hl[i] = md(M, i) - 1;

        std::vector<int> hs(sq + 1);
        std::iota(begin(hs), end(hs), -1);

        int pi = 0;
        for (auto &x : p) {
            i64 x2 = i64(x) * x;
            i64 imax = std::min<i64>(hls, md(M, x2) + 1);
            for (i64 i = 1, ix = x; i < imax; ++i, ix += x) {
                hl[i] -= (ix < hls ? hl[ix] : hs[md(M, ix)]) - pi;
            }
            for (int n = sq; n >= x2; n--) hs[n] -= hs[md(n, x)] - pi;
            pi++;
        }

        std::vector<T> res;
        res.reserve(2 * sq + 10);
        for (auto &x : hl) res.push_back(x);
        for (int i = hs.size(); --i;) res.push_back(hs[i]);
        assert((int)res.size() == s);
        return res;
    }
    std::vector<T> prime_sum_table(int k) {
        //\sum_{p\in prime \and p\leq m} p^k
        if (M == 0) return {};
        i64 hls = md(M, sq);
        if (hls != 1 && md(M, hls - 1) == sq) hls--;

        std::vector<T> h(s);
        T inv2 = T{2}.inverse();
        for (int i = 1; i < hls; i++) {
            T x = md(M, i);
            h[i] = PSumPower(x.get(), k) - 1;
        }
        for (int i = 1; i <= sq; i++) {
            T x = i;
            h[s - i] = PSumPower(x.get(), k) - 1;
        }

        for (auto &x : p) {
            T xt = x;
            xt = power(xt, k);
            T pi = h[s - x + 1];
            i64 x2 = i64(x) * x;
            i64 imax = std::min<i64>(hls, md(M, x2) + 1);
            i64 ix = x;
            for (i64 i = 1; i < imax; ++i, ix += x) {
                h[i] -= ((ix < hls ? h[ix] : h[s - md(M, ix)]) - pi) * xt;
            }
            for (int n = sq; n >= x2; n--) {
                h[s - n] -= (h[s - md(n, x)] - pi) * xt;
            }
        }

        assert((int)h.size() == s);
        return h;
    }

    void dfs(int i, int c, i64 prod, T cur) {
        ans += cur * f(p[i], c + 1);
        i64 lim = md(M, prod);
        if (lim >= 1LL * p[i] * p[i]) dfs(i, c + 1, p[i] * prod, cur);
        cur *= f(p[i], c);
        ans += cur * (buf[idx(lim)] - buf[idx(p[i])]);
        int j = i + 1;
        // M < 2**42 -> p_j < 2**21 -> (p_j)^3 < 2**63
        for (; j < ps && 1LL * p[j] * p[j] * p[j] <= lim; j++) {
            dfs(j, 1, prod * p[j], cur);
        }
        for (; j < ps && 1LL * p[j] * p[j] <= lim; j++) {
            T sm = f(p[j], 2);
            int id1 = idx(md(lim, p[j])), id2 = idx(p[j]);
            sm += f(p[j], 1) * (buf[id1] - buf[id2]);
            ans += cur * sm;
        }
    }

    T run(std::vector<T> &fprime) {
        if (M == 0) return {};
        set_buf(fprime);
        assert((int)buf.size() == s);
        ans = buf[idx(M)] + 1;
        for (int i = 0; i < ps; i++) dfs(i, 1, p[i], 1);
        return ans;
    }

   private:
    i64 md(i64 n, i64 d) { return double(n) / d; }
    i64 idx(i64 n) { return n <= sq ? s - n : md(M, n); }
    void set_buf(std::vector<T> &_buf) { swap(buf, _buf); }
};
#line 5 "Verify/SumofTotientFunction2.test.cpp"
constexpr int mod = 998244353;
using Z = mint<mod>;

// Euler's totient at a prime power (expects c >= 1):
//   phi(p^c) = p^c - p^(c-1) = p^(c-1) * (p - 1).
// Evaluated entirely in Z, so no intermediate i64 power is formed and nothing
// can overflow regardless of how large p^c is.
Z f(i64 p, i64 c) { return power(Z(p), static_cast<int>(c - 1)) * Z(p - 1); }
int main() {
    i64 n;
    std::cin >> n;

    MfPrefixSum<Z, f> mf(n);

    auto psum = mf.prime_sum_table(1);
    auto ptable = mf.pi_table();
    for (int i = 0; i < int(psum.size()); i++) {
        psum[i] = psum[i] - ptable[i];
    }

    auto ans = mf.run(psum);

    std::cout << ans << std::endl;

    return 0;
}
Back to top page