CPLibrary

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub o06660o/CPLibrary

✔ src/math/ntt.hpp

Depends on

Verified with

Code

#include "modint.hpp"
template <int MOD = 998244353, int G = 114514, int GI = 137043501>
struct Poly {
  // Coefficients in ascending degree order: data[i] multiplies x^i.
  // NOTE: all arithmetic goes through the global `mint` type; the template
  // MOD is only used for exponents below, so it must equal mint's modulus.
  // G is expected to be a primitive root modulo MOD and GI its inverse.
  vector<mint> data;

  // Polynomial product (linear convolution) computed with NTT.
  // Returns a.data.size() + b.data.size() - 1 coefficients, or an empty
  // polynomial when either operand has no coefficients.
  friend Poly operator*(const Poly& a, const Poly& b) {
    // Without this guard two empty operands give new_sz == -1 (resize to a
    // huge size) and one empty operand yields spurious zero coefficients.
    if (a.data.empty() || b.data.empty()) return {};
    int new_sz = a.data.size() + b.data.size() - 1;
    // A cyclic convolution of length n equals the linear one once
    // n >= new_sz, so pad only up to that.
    int n = 1;
    while (n < new_sz) n *= 2;
    // (MOD - 1) / n below must be exact for a primitive n-th root to exist.
    assert((MOD - 1) % n == 0);
    vector<mint> A = a.data, B = b.data;
    A.resize(n), B.resize(n);
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < n; i++) A[i] *= B[i];
    ntt(A, -1), A.resize(new_sz);
    return {std::move(A)};
  }

 private:
  // In-place iterative radix-2 NTT; a.size() must be a power of two.
  // type == 1: forward transform using root G.
  // type == -1: inverse transform using root GI, then division by a.size().
  static void ntt(vector<mint>& a, int type) {
    int len = a.size();
    // rev[i] is i with its log2(len) low bits reversed.
    vector<int> rev(len);
    for (int i = 0; i < len; i++) {
      rev[i] = rev[i >> 1] >> 1;
      if (i & 1) rev[i] |= len >> 1;
    }
    for (int i = 0; i < len; i++)
      if (i < rev[i]) swap(a[i], a[rev[i]]);
    // Merge adjacent length-m transforms into length-n transforms.
    for (int n = 2, m = 1; n <= len; n *= 2, m *= 2) {
      // n-th root of unity (inverse root for the inverse transform).
      mint step = mint(type == 1 ? G : GI).pow((MOD - 1) / n);
      for (int j = 0; j < len; j += n) {
        mint w = 1;
        for (int k = j; k < j + m; k++) {
          mint u = a[k], v = w * a[k + m];
          a[k] = u + v, a[k + m] = u - v, w *= step;
        }
      }
    }
    if (type == -1) {
      // Divide by len via Fermat's little theorem (MOD is assumed prime).
      mint inv = mint(len).pow(MOD - 2);
      for (int i = 0; i < len; i++) a[i] *= inv;
    }
  }
};
#line 1 "src/math/modint.hpp"
template <unsigned MOD>
struct ModInt {
  unsigned data;
  ModInt(ll v = 0) : data(norm(v % MOD)) {}
  ModInt operator-() const { return MOD - data; }
  ModInt& operator+=(ModInt rhs) { return data = norm(data + rhs.data), *this; }
  ModInt& operator-=(ModInt rhs) { return data = norm(data - rhs.data), *this; }
  ModInt& operator*=(ModInt rhs) {
    return data = ull(data) * rhs.data % MOD, *this;
  }
  ModInt& operator/=(ModInt rhs) {
    return data = ull(data) * rhs.inv() % MOD, *this;
  }
  friend ModInt operator+(ModInt lhs, ModInt rhs) { return lhs += rhs; }
  friend ModInt operator-(ModInt lhs, ModInt rhs) { return lhs -= rhs; }
  friend ModInt operator*(ModInt lhs, ModInt rhs) { return lhs *= rhs; }
  friend ModInt operator/(ModInt lhs, ModInt rhs) { return lhs /= rhs; }
  unsigned inv() const {
    ll x, y;  // Inverse does not exist if gcd(data, MOD) != 1.
    assert(exgcd(data, MOD, x, y) == 1);
    return norm(x);
  }
  ModInt pow(ull n) const { return pow_mod(data, n, MOD); }
  static ll exgcd(ll a, ll b, ll& x, ll& y) {
    x = 1, y = 0;
    ll x1 = 0, y1 = 1;
    while (b) {
      ll q = a / b;
      swap(a -= q * b, b);
      swap(x -= q * x1, x1);
      swap(y -= q * y1, y1);
    }
    return a;
  }
  static unsigned pow_mod(unsigned a, ull n, unsigned p) {
    unsigned ret = 1;
    for (; n; n /= 2) {
      if (n & 1) ret = ull(ret) * a % p;
      a = ull(a) * a % p;
    }
    return ret;
  }

 private:
  static unsigned norm(unsigned x) {
    if ((x >> (8 * sizeof(unsigned) - 1)) & 1) x += MOD;
    return x >= MOD ? x -= MOD : x;
  }
};
// Default modulus 998244353 == 119 * 2^23 + 1: prime, and MOD - 1 has a
// 2^23 factor, so power-of-two NTT lengths up to 2^23 are available.
constexpr unsigned MOD = 119u * (1u << 23) + 1u;
// Library-wide modular integer type used by Poly and friends.
using mint = ModInt<MOD>;
#line 2 "src/math/ntt.hpp"
template <int MOD = 998244353, int G = 114514, int GI = 137043501>
struct Poly {
  // Coefficients in ascending degree order: data[i] multiplies x^i.
  // NOTE: all arithmetic goes through the global `mint` type; the template
  // MOD is only used for exponents below, so it must equal mint's modulus.
  // G is expected to be a primitive root modulo MOD and GI its inverse.
  vector<mint> data;

  // Polynomial product (linear convolution) computed with NTT.
  // Returns a.data.size() + b.data.size() - 1 coefficients, or an empty
  // polynomial when either operand has no coefficients.
  friend Poly operator*(const Poly& a, const Poly& b) {
    // Without this guard two empty operands give new_sz == -1 (resize to a
    // huge size) and one empty operand yields spurious zero coefficients.
    if (a.data.empty() || b.data.empty()) return {};
    int new_sz = a.data.size() + b.data.size() - 1;
    // A cyclic convolution of length n equals the linear one once
    // n >= new_sz, so pad only up to that.
    int n = 1;
    while (n < new_sz) n *= 2;
    // (MOD - 1) / n below must be exact for a primitive n-th root to exist.
    assert((MOD - 1) % n == 0);
    vector<mint> A = a.data, B = b.data;
    A.resize(n), B.resize(n);
    ntt(A, 1), ntt(B, 1);
    for (int i = 0; i < n; i++) A[i] *= B[i];
    ntt(A, -1), A.resize(new_sz);
    return {std::move(A)};
  }

 private:
  // In-place iterative radix-2 NTT; a.size() must be a power of two.
  // type == 1: forward transform using root G.
  // type == -1: inverse transform using root GI, then division by a.size().
  static void ntt(vector<mint>& a, int type) {
    int len = a.size();
    // rev[i] is i with its log2(len) low bits reversed.
    vector<int> rev(len);
    for (int i = 0; i < len; i++) {
      rev[i] = rev[i >> 1] >> 1;
      if (i & 1) rev[i] |= len >> 1;
    }
    for (int i = 0; i < len; i++)
      if (i < rev[i]) swap(a[i], a[rev[i]]);
    // Merge adjacent length-m transforms into length-n transforms.
    for (int n = 2, m = 1; n <= len; n *= 2, m *= 2) {
      // n-th root of unity (inverse root for the inverse transform).
      mint step = mint(type == 1 ? G : GI).pow((MOD - 1) / n);
      for (int j = 0; j < len; j += n) {
        mint w = 1;
        for (int k = j; k < j + m; k++) {
          mint u = a[k], v = w * a[k + m];
          a[k] = u + v, a[k + m] = u - v, w *= step;
        }
      }
    }
    if (type == -1) {
      // Divide by len via Fermat's little theorem (MOD is assumed prime).
      mint inv = mint(len).pow(MOD - 2);
      for (int i = 0; i < len; i++) a[i] *= inv;
    }
  }
};
Back to top page