diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index c01956c..ce3d272 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -14,95 +14,169 @@ namespace atcoder { namespace internal { +template , + internal::is_static_modint_t* = nullptr> +struct fft_info { + static constexpr int rank2 = bsf_constexpr(mint::mod() - 1); + std::array root; // root[i]^(2^i) == 1 + std::array iroot; // root[i] * iroot[i] == 1 + + std::array rate2; + std::array irate2; + + std::array rate3; + std::array irate3; + + fft_info() { + root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2); + iroot[rank2] = root[rank2].inv(); + for (int i = rank2 - 1; i >= 0; i--) { + root[i] = root[i + 1] * root[i + 1]; + iroot[i] = iroot[i + 1] * iroot[i + 1]; + } + + { + mint prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 2; i++) { + rate2[i] = root[i + 2] * prod; + irate2[i] = iroot[i + 2] * iprod; + prod *= iroot[i + 2]; + iprod *= root[i + 2]; + } + } + { + mint prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 3; i++) { + rate3[i] = root[i + 3] * prod; + irate3[i] = iroot[i + 3] * iprod; + prod *= iroot[i + 3]; + iprod *= root[i + 3]; + } + } + } +}; + template * = nullptr> void butterfly(std::vector& a) { - static constexpr int g = internal::primitive_root; int n = int(a.size()); int h = internal::ceil_pow2(n); - static bool first = true; - static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i] - if (first) { - first = false; - mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 - int cnt2 = bsf(mint::mod() - 1); - mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); - for (int i = cnt2; i >= 2; i--) { - // e^(2^i) == 1 - es[i - 2] = e; - ies[i - 2] = ie; - e *= e; - ie *= ie; - } - mint now = 1; - for (int i = 0; i <= cnt2 - 2; i++) { - sum_e[i] = es[i] * now; - now *= ies[i]; - } - } - for (int ph = 1; ph <= h; ph++) { - int w = 1 << (ph - 1), p = 1 << (h - ph); - mint now = 1; - for (int s = 0; s < w; s++) { - int offset = s << (h - ph + 1); - for (int i = 0; i < p; i++) { - auto l = a[i + offset]; - auto r = a[i + offset + p] * now; - a[i + offset] = l + r; - a[i + offset + p] = l - r; + static const fft_info info; + + int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed + while (len < h) { + if (h - len == 1) { + int p = 1 << (h - len - 1); + mint rot = 1; + for (int s = 0; s < (1 << len); s++) { + int offset = s << (h - len); + for (int i = 0; i < p; i++) { + auto l = a[i + offset]; + auto r = a[i + offset + p] * rot; + a[i + offset] = l + r; + a[i + offset + p] = l - r; + } + if (s + 1 != (1 << len)) + rot *= info.rate2[bsf(~(unsigned int)(s))]; + } + len++; + } else { + // 4-base + int p = 1 << (h - len - 2); + mint rot = 1, imag = info.root[2]; + for (int s = 0; s < (1 << len); s++) { + mint rot2 = rot * rot; + mint rot3 = rot2 * rot; + int offset = s << (h - len); + for (int i = 0; i < p; i++) { + auto mod2 = 1ULL * mint::mod() * mint::mod(); + auto a0 = 1ULL * a[i + offset].val(); + auto a1 = 1ULL * a[i + offset + p].val() * rot.val(); + auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val(); + auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val(); + auto a1na3imag = + 1ULL * mint(a1 + mod2 - a3).val() * imag.val(); + auto na2 = mod2 - a2; + a[i + offset] = a0 + a2 + a1 + a3; + a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3)); + a[i + offset + 2 * p] = a0 + na2 + a1na3imag; + a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); + } + if (s + 1 != (1 << len)) + rot *= info.rate3[bsf(~(unsigned int)(s))]; } - now *= sum_e[bsf(~(unsigned int)(s))]; + len += 2; } } } template * = nullptr> void butterfly_inv(std::vector& a) { - static constexpr int g = internal::primitive_root; int n = int(a.size()); int h = internal::ceil_pow2(n); - static bool first = true; - static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i] - if (first) { - first = false; - mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 - int cnt2 = bsf(mint::mod() - 1); - mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); - for (int i = cnt2; i >= 2; i--) { - // e^(2^i) == 1 - es[i - 2] = e; - ies[i - 2] = ie; - e *= e; - ie *= ie; - } - mint now = 1; - for (int i = 0; i <= cnt2 - 2; i++) { - sum_ie[i] = ies[i] * now; - now *= es[i]; - } - } + static const fft_info info; + + int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed + while (len) { + if (len == 1) { + int p = 1 << (h - len); + mint irot = 1; + for (int s = 0; s < (1 << (len - 1)); s++) { + int offset = s << (h - len + 1); + for (int i = 0; i < p; i++) { + auto l = a[i + offset]; + auto r = a[i + offset + p]; + a[i + offset] = l + r; + a[i + offset + p] = + (unsigned long long)(mint::mod() + l.val() - r.val()) * + irot.val(); + ; + } + if (s + 1 != (1 << (len - 1))) + irot *= info.irate2[bsf(~(unsigned int)(s))]; + } + len--; + } else { + // 4-base + int p = 1 << (h - len); + mint irot = 1, iimag = info.iroot[2]; + for (int s = 0; s < (1 << (len - 2)); s++) { + mint irot2 = irot * irot; + mint irot3 = irot2 * irot; + int offset = s << (h - len + 2); + for (int i = 0; i < p; i++) { + auto a0 = 1ULL * a[i + offset + 0 * p].val(); + auto a1 = 1ULL * a[i + offset + 1 * p].val(); + auto a2 = 1ULL * a[i + offset + 2 * p].val(); + auto a3 = 1ULL * a[i + offset + 3 * p].val(); + + auto a2na3iimag = + 1ULL * + mint((mint::mod() + a2 - a3) * iimag.val()).val(); - for (int ph = h; ph >= 1; ph--) { - int w = 1 << (ph - 1), p = 1 << (h - ph); - mint inow = 1; - for (int s = 0; s < w; s++) { - int offset = s << (h - ph + 1); - for (int i = 0; i < p; i++) { - auto l = a[i + offset]; - auto r = a[i + offset + p]; - a[i + offset] = l + r; - a[i + offset + p] = - (unsigned long long)(mint::mod() + l.val() - r.val()) * - inow.val(); + a[i + offset] = a0 + a1 + a2 + a3; + a[i + offset + 1 * p] = + (a0 + (mint::mod() - a1) + a2na3iimag) * irot.val(); + a[i + offset + 2 * p] = + (a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) * + irot2.val(); + a[i + offset + 3 * p] = + (a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) * + irot3.val(); + } + if (s + 1 != (1 << (len - 2))) + irot *= info.irate3[bsf(~(unsigned int)(s))]; } - inow *= sum_ie[bsf(~(unsigned int)(s))]; + len -= 2; } } } template * = nullptr> -std::vector convolution_naive(const std::vector& a, const std::vector& b) { +std::vector convolution_naive(const std::vector& a, + const std::vector& b) { int n = int(a.size()), m = int(b.size()); std::vector ans(n + m - 1); if (n < m) { @@ -150,7 +224,8 @@ std::vector convolution(std::vector&& a, std::vector&& b) { } template * = nullptr> -std::vector convolution(const std::vector& a, const std::vector& b) { +std::vector convolution(const std::vector& a, + const std::vector& b) { int n = int(a.size()), m = int(b.size()); if (!n || !m) return {}; if (std::min(n, m) <= 60) return convolution_naive(a, b); diff --git a/atcoder/internal_bit.hpp b/atcoder/internal_bit.hpp index d219b0f..ada311a 100644 --- a/atcoder/internal_bit.hpp +++ b/atcoder/internal_bit.hpp @@ -17,6 +17,14 @@ int ceil_pow2(int n) { return x; } +// @param n `1 <= n` +// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` +constexpr int bsf_constexpr(unsigned int n) { + int x = 0; + while (!(n & (1 << x))) x++; + return x; +} + // @param n `1 <= n` // @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0` int bsf(unsigned int n) { diff --git a/test/unittest/convolution_test.cpp b/test/unittest/convolution_test.cpp index cf8f55a..d1648f9 100644 --- a/test/unittest/convolution_test.cpp +++ b/test/unittest/convolution_test.cpp @@ -385,3 +385,47 @@ TEST(ConvolutionTest, Conv18433) { ASSERT_EQ(conv_naive(a, b), convolution(a, b)); } + +TEST(ConvolutionTest, Conv2) { + std::vector empty = {}; + ASSERT_EQ(empty, convolution<2>(empty, empty)); +} + +TEST(ConvolutionTest, Conv257) { + const int MOD = 257; + std::vector a(128), b(129); + for (int i = 0; i < 128; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 129; i++) { + b[i] = randint(0, MOD - 1); + } + + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} + +TEST(ConvolutionTest, Conv2147483647) { + const int MOD = 2147483647; + using mint = static_modint; + std::vector a(1), b(2); + for (int i = 0; i < 1; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 2; i++) { + b[i] = randint(0, MOD - 1); + } + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} + +TEST(ConvolutionTest, Conv2130706433) { + const int MOD = 2130706433; + using mint = static_modint; + std::vector a(1024), b(1024); + for (int i = 0; i < 1024; i++) { + a[i] = randint(0, MOD - 1); + } + for (int i = 0; i < 1024; i++) { + b[i] = randint(0, MOD - 1); + } + ASSERT_EQ(conv_naive(a, b), convolution(a, b)); +} diff --git a/test/unittest/modint_test.cpp b/test/unittest/modint_test.cpp index 81f33cc..222c1bd 100644 --- a/test/unittest/modint_test.cpp +++ b/test/unittest/modint_test.cpp @@ -100,6 +100,29 @@ TEST(ModintTest, Mod1) { ASSERT_EQ(0, mint(true).val()); } +TEST(ModintTest, ModIntMax) { + modint::set_mod(INT32_MAX); + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 100; j++) { + ASSERT_EQ((modint(i) * modint(j)).val(), i * j); + } + } + ASSERT_EQ((modint(1234) + modint(5678)).val(), 1234 + 5678); + ASSERT_EQ((modint(1234) - modint(5678)).val(), INT32_MAX - 5678 + 1234); + ASSERT_EQ((modint(1234) * modint(5678)).val(), 1234 * 5678); + + using mint = static_modint; + for (int i = 0; i < 100; i++) { + for (int j = 0; j < 100; j++) { + ASSERT_EQ((mint(i) * mint(j)).val(), i * j); + } + } + ASSERT_EQ((mint(1234) + mint(5678)).val(), 1234 + 5678); + ASSERT_EQ((mint(1234) - mint(5678)).val(), INT32_MAX - 5678 + 1234); + ASSERT_EQ((mint(1234) * mint(5678)).val(), 1234 * 5678); + ASSERT_EQ((mint(INT32_MAX) + mint(INT32_MAX)).val(), 0); +} + #ifndef _MSC_VER TEST(ModintTest, Int128) { @@ -158,6 +181,13 @@ TEST(ModintTest, Inv) { int x = modint(i).inv().val(); ASSERT_EQ(1, (ll(x) * i) % 1'000'000'008); } + + modint::set_mod(INT32_MAX); + for (int i = 1; i < 100000; i++) { + if (gcd(i, INT32_MAX) != 1) continue; + int x = modint(i).inv().val(); + ASSERT_EQ(1, (ll(x) * i) % INT32_MAX); + } } TEST(ModintTest, ConstUsage) {