Skip to content

Commit

Permalink
Merge pull request #120 from atcoder/fft-optimize
Browse files Browse the repository at this point in the history
Fft optimize
  • Loading branch information
yosupo06 authored Jul 19, 2021
2 parents db08263 + f841a34 commit 89d5d0a
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 68 deletions.
211 changes: 143 additions & 68 deletions atcoder/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,95 +14,169 @@ namespace atcoder {

namespace internal {

// Tables of roots of unity modulo mint::mod(), derived from the generator g.
// Everything is filled in once by the constructor.
template <class mint,
          int g = internal::primitive_root<mint::mod()>,
          internal::is_static_modint_t<mint>* = nullptr>
struct fft_info {
    // Exponent of the largest power of two dividing mod - 1.
    static constexpr int rank2 = bsf_constexpr(mint::mod() - 1);

    // For every 0 <= k <= rank2: root[k]^(2^k) == 1 and
    // root[k] * iroot[k] == 1.
    std::array<mint, rank2 + 1> root;
    std::array<mint, rank2 + 1> iroot;

    // rate2[k]  = root[k + 2]  * iroot[2] * ... * iroot[k + 1]
    // irate2[k] = iroot[k + 2] * root[2]  * ... * root[k + 1]
    // (empty arrays when rank2 < 2).
    std::array<mint, std::max(0, rank2 - 2 + 1)> rate2;
    std::array<mint, std::max(0, rank2 - 2 + 1)> irate2;

    // rate3[k]  = root[k + 3]  * iroot[3] * ... * iroot[k + 2]
    // irate3[k] = iroot[k + 3] * root[3]  * ... * root[k + 2]
    // (empty arrays when rank2 < 3).
    std::array<mint, std::max(0, rank2 - 3 + 1)> rate3;
    std::array<mint, std::max(0, rank2 - 3 + 1)> irate3;

    fft_info() {
        // Seed the highest level, then obtain each lower level by squaring
        // the level above it.
        root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
        iroot[rank2] = root[rank2].inv();
        for (int k = rank2; k > 0; k--) {
            root[k - 1] = root[k] * root[k];
            iroot[k - 1] = iroot[k] * iroot[k];
        }

        {
            // acc / iacc hold the running products of the inverse / forward
            // roots consumed so far, starting at level 2.
            mint acc = 1, iacc = 1;
            for (int k = 2; k <= rank2; k++) {
                rate2[k - 2] = root[k] * acc;
                irate2[k - 2] = iroot[k] * iacc;
                acc *= iroot[k];
                iacc *= root[k];
            }
        }
        {
            // Same construction, starting at level 3.
            mint acc = 1, iacc = 1;
            for (int k = 3; k <= rank2; k++) {
                rate3[k - 3] = root[k] * acc;
                irate3[k - 3] = iroot[k] * iacc;
                acc *= iroot[k];
                iacc *= root[k];
            }
        }
    }
};

// Applies the butterfly passes to `a` in place (the vector is taken by
// non-const reference and every result is written back into it).
//
// NOTE(review): this span comes from a rendered diff and appears to interleave
// two implementations: one built around the lazily-initialised `sum_e` table
// and the `ph` loop, and one built around `fft_info` and the `len` loop. The
// braces do not balance as shown, so read the two variants separately —
// TODO confirm against the repository file.
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
// presumably the base-2 logarithm of the (power-of-two) length — confirm
// against ceil_pow2, whose body is not visible here
int h = internal::ceil_pow2(n);

// --- `sum_e` variant: table is filled on the first call only ---
static bool first = true;
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_e[i] = es[i] * now;
now *= ies[i];
}
}
// One radix-2 pass per value of ph: w blocks, each split into halves of
// length p.
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * now;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
// --- `fft_info` variant starts here ---
// Function-local static: the tables are constructed once per `mint` type.
static const fft_info<mint> info;

int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len < h) {
if (h - len == 1) {
// Exactly one level remains: a single radix-2 pass (len advances by 1).
int p = 1 << (h - len - 1);
mint rot = 1;
for (int s = 0; s < (1 << len); s++) {
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * rot;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
// Advance the rotation for the next block; bsf(~s) is the index of the
// lowest clear bit of s. No update is made after the last block.
if (s + 1 != (1 << len))
rot *= info.rate2[bsf(~(unsigned int)(s))];
}
len++;
} else {
// 4-base
// Two levels are handled at once (len advances by 2); p is the length of
// each of the four quarters of a block.
int p = 1 << (h - len - 2);
// imag is info.root[2], i.e. an element whose 4th power is 1.
mint rot = 1, imag = info.root[2];
for (int s = 0; s < (1 << len); s++) {
mint rot2 = rot * rot;
mint rot3 = rot2 * rot;
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
// mod2 = mod^2 in 64-bit unsigned arithmetic. a1..a3 are unreduced
// products of two residues, so adding mod2 before a subtraction keeps
// the intermediate value non-negative.
auto mod2 = 1ULL * mint::mod() * mint::mod();
auto a0 = 1ULL * a[i + offset].val();
auto a1 = 1ULL * a[i + offset + p].val() * rot.val();
auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val();
auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val();
// (a1 - a3) is reduced through mint first, then multiplied by imag.
auto a1na3imag =
1ULL * mint(a1 + mod2 - a3).val() * imag.val();
auto na2 = mod2 - a2;
// The 64-bit sums are assigned to mint elements; presumably the
// conversion reduces them modulo mint::mod() — confirm against mint.
a[i + offset] = a0 + a2 + a1 + a3;
a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
}
// Same block-to-block rotation update as above, using the rate3 table.
if (s + 1 != (1 << len))
rot *= info.rate3[bsf(~(unsigned int)(s))];
}
// NOTE(review): the next statement uses `now`/`sum_e` and so appears to
// belong to the `sum_e` variant — TODO confirm.
now *= sum_e[bsf(~(unsigned int)(s))];
len += 2;
}
}
}

// Applies the inverse butterfly passes to `a` in place. No scaling by the
// length is performed in the code visible here.
//
// NOTE(review): this span comes from a rendered diff and appears to interleave
// two implementations: one built around the lazily-initialised `sum_ie` table
// and the `ph` loop, and one built around `fft_info` and the `len` loop. The
// braces do not balance as shown, so read the two variants separately —
// TODO confirm against the repository file.
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
// presumably the base-2 logarithm of the (power-of-two) length — confirm
// against ceil_pow2, whose body is not visible here
int h = internal::ceil_pow2(n);

// --- `sum_ie` variant: table is filled on the first call only ---
static bool first = true;
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now;
now *= es[i];
}
}
// --- `fft_info` variant starts here ---
// Function-local static: the tables are constructed once per `mint` type.
static const fft_info<mint> info;

int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len) {
if (len == 1) {
// Exactly one level remains: a single radix-2 pass (len decreases by 1).
int p = 1 << (h - len);
mint irot = 1;
for (int s = 0; s < (1 << (len - 1)); s++) {
int offset = s << (h - len + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
// mod is added so that l - r stays non-negative; the difference is
// widened to 64 bits before it is multiplied by the rotation.
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
irot.val();
;
}
// Advance the rotation for the next block; bsf(~s) is the index of the
// lowest clear bit of s. No update is made after the last block.
if (s + 1 != (1 << (len - 1)))
irot *= info.irate2[bsf(~(unsigned int)(s))];
}
len--;
} else {
// 4-base
// Two levels are handled at once (len decreases by 2); p is the length of
// each of the four quarters of a block.
int p = 1 << (h - len);
// iimag is info.iroot[2], the inverse of info.root[2].
mint irot = 1, iimag = info.iroot[2];
for (int s = 0; s < (1 << (len - 2)); s++) {
mint irot2 = irot * irot;
mint irot3 = irot2 * irot;
int offset = s << (h - len + 2);
for (int i = 0; i < p; i++) {
// The four inputs are read as plain residues in 64-bit unsigned form.
auto a0 = 1ULL * a[i + offset + 0 * p].val();
auto a1 = 1ULL * a[i + offset + 1 * p].val();
auto a2 = 1ULL * a[i + offset + 2 * p].val();
auto a3 = 1ULL * a[i + offset + 3 * p].val();

// (a2 - a3) * iimag, reduced through mint back to a residue.
auto a2na3iimag =
1ULL *
mint((mint::mod() + a2 - a3) * iimag.val()).val();

// NOTE(review): the `ph` loop below (through the `inow.val();` statement)
// appears to belong to the `sum_ie` variant — TODO confirm.
for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
inow.val();
// Radix-4 outputs of the `fft_info` variant: each subtraction is written
// as "+ (mod - x)" so every intermediate value stays non-negative, and
// the rotation is applied after the additions.
a[i + offset] = a0 + a1 + a2 + a3;
a[i + offset + 1 * p] =
(a0 + (mint::mod() - a1) + a2na3iimag) * irot.val();
a[i + offset + 2 * p] =
(a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) *
irot2.val();
a[i + offset + 3 * p] =
(a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) *
irot3.val();
}
// Same block-to-block rotation update as above, using the irate3 table.
if (s + 1 != (1 << (len - 2)))
irot *= info.irate3[bsf(~(unsigned int)(s))];
}
// NOTE(review): the next statement uses `inow`/`sum_ie` and so appears to
// belong to the `sum_ie` variant — TODO confirm.
inow *= sum_ie[bsf(~(unsigned int)(s))];
len -= 2;
}
}
}

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
std::vector<mint> convolution_naive(const std::vector<mint>& a,
const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
std::vector<mint> ans(n + m - 1);
if (n < m) {
Expand Down Expand Up @@ -150,7 +224,8 @@ std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
}

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
std::vector<mint> convolution(const std::vector<mint>& a,
const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) return convolution_naive(a, b);
Expand Down
8 changes: 8 additions & 0 deletions atcoder/internal_bit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ int ceil_pow2(int n) {
return x;
}

// Index of the lowest set bit of `n`, usable in constant expressions.
//
// @param n `1 <= n` (for `n == 0` no bit is ever found, so the precondition
//          must hold)
// @return minimum non-negative `x` s.t. `(n & (1U << x)) != 0`
constexpr int bsf_constexpr(unsigned int n) {
    int x = 0;
    // The mask is built from an unsigned 1 so that probing bit 31 never
    // left-shifts a signed int into its sign bit, and so that both operands
    // of `&` are unsigned.
    while (!(n & (1U << x))) x++;
    return x;
}

// @param n `1 <= n`
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
int bsf(unsigned int n) {
Expand Down
44 changes: 44 additions & 0 deletions test/unittest/convolution_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,47 @@ TEST(ConvolutionTest, Conv18433) {

ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
}

// With modulus 2, convolving two empty sequences must give an empty sequence.
TEST(ConvolutionTest, Conv2) {
    std::vector<ll> no_terms = {};
    ASSERT_EQ(no_terms, convolution<2>(no_terms, no_terms));
}

// Random operands of 128 and 129 terms modulo 257 (= 2^8 + 1); the result
// is compared against the naive reference.
TEST(ConvolutionTest, Conv257) {
    const int MOD = 257;
    std::vector<ll> a(128), b(129);
    // Fill `a` completely before `b` so the random draws happen in the same
    // order as an index loop over each vector.
    for (auto& coef : a) coef = randint(0, MOD - 1);
    for (auto& coef : b) coef = randint(0, MOD - 1);

    ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
}

// Tiny operands (1 and 2 terms) modulo 2147483647 (= 2^31 - 1), checked
// against the naive reference.
TEST(ConvolutionTest, Conv2147483647) {
    const int MOD = 2147483647;
    using mint = static_modint<MOD>;
    std::vector<mint> a(1), b(2);
    // `a` is filled before `b`, one random draw per element.
    for (auto& coef : a) coef = randint(0, MOD - 1);
    for (auto& coef : b) coef = randint(0, MOD - 1);
    ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}

// Two random 1024-term operands modulo 2130706433 (= 127 * 2^24 + 1),
// checked against the naive reference.
TEST(ConvolutionTest, Conv2130706433) {
    const int MOD = 2130706433;
    using mint = static_modint<MOD>;
    std::vector<mint> a(1024), b(1024);
    // `a` is filled before `b`, one random draw per element.
    for (auto& coef : a) coef = randint(0, MOD - 1);
    for (auto& coef : b) coef = randint(0, MOD - 1);
    ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}
30 changes: 30 additions & 0 deletions test/unittest/modint_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,29 @@ TEST(ModintTest, Mod1) {
ASSERT_EQ(0, mint(true).val());
}

// Arithmetic with the modulus INT32_MAX: first through `modint` after
// `set_mod(INT32_MAX)`, then through `static_modint<INT32_MAX>`.
TEST(ModintTest, ModIntMax) {
    modint::set_mod(INT32_MAX);
    // Products of operands below 100 are under 10^4, so the expected value
    // is the plain integer product.
    for (int x = 0; x < 100; x++) {
        for (int y = 0; y < 100; y++) {
            ASSERT_EQ((modint(x) * modint(y)).val(), x * y);
        }
    }
    ASSERT_EQ((modint(1234) + modint(5678)).val(), 1234 + 5678);
    // 1234 - 5678 is negative, so the expected residue wraps around.
    ASSERT_EQ((modint(1234) - modint(5678)).val(), INT32_MAX - 5678 + 1234);
    ASSERT_EQ((modint(1234) * modint(5678)).val(), 1234 * 5678);

    using mint = static_modint<INT32_MAX>;
    for (int x = 0; x < 100; x++) {
        for (int y = 0; y < 100; y++) {
            ASSERT_EQ((mint(x) * mint(y)).val(), x * y);
        }
    }
    ASSERT_EQ((mint(1234) + mint(5678)).val(), 1234 + 5678);
    ASSERT_EQ((mint(1234) - mint(5678)).val(), INT32_MAX - 5678 + 1234);
    ASSERT_EQ((mint(1234) * mint(5678)).val(), 1234 * 5678);
    // Two operands equal to the modulus must sum to 0.
    ASSERT_EQ((mint(INT32_MAX) + mint(INT32_MAX)).val(), 0);
}

#ifndef _MSC_VER

TEST(ModintTest, Int128) {
Expand Down Expand Up @@ -158,6 +181,13 @@ TEST(ModintTest, Inv) {
int x = modint(i).inv().val();
ASSERT_EQ(1, (ll(x) * i) % 1'000'000'008);
}

modint::set_mod(INT32_MAX);
for (int i = 1; i < 100000; i++) {
if (gcd(i, INT32_MAX) != 1) continue;
int x = modint(i).inv().val();
ASSERT_EQ(1, (ll(x) * i) % INT32_MAX);
}
}

TEST(ModintTest, ConstUsage) {
Expand Down

0 comments on commit 89d5d0a

Please sign in to comment.