-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy pathbignum_addpow2_compare.cpp
executable file
·110 lines (82 loc) · 3.22 KB
/
bignum_addpow2_compare.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// {{{ strings/polynomial_hash }}}
#include <algorithm>
#include <cassert>
#include <iostream>
#include <ostream>
#include <vector>
/**
 * Stores non-negative integers as binary trees of bits so that "add 2^pow"
 * and "compare" can be done while sharing structure between values.
 *
 * A value is identified by an index into `nums`. Nodes are never modified
 * after creation; every operation returns the index of a (possibly new) node
 * instead of changing its input, so unchanged subtrees are shared.
 *
 * Layout: a node of width w > 1 is the concatenation of `left` (the more
 * significant w/2 bits) and `right` (the less significant w/2 bits).
 * Index 0 is the single bit 0 and index 1 is the single bit 1. Every other
 * node is built by concatenating two nodes of equal width, so widths are
 * powers of two.
 *
 * Each node also stores a hash of its bit string (strings/polynomial_hash).
 * less_than uses it to detect equal halves, so comparisons are correct only
 * up to hash collisions.
 */
struct bignum_addpow2_compare {
    // NOTE(review): the template arguments follow the project's polynomial_hash;
    // their exact meaning is defined there, not here.
    using hash_t = polynomial_hash<modnum<int(1e9 + 7)>, 3, 2>;

    /// One node of the bit tree.
    struct binary_string {
        hash_t hash;           // hash of the bit string; hash.N is used as its bit length
        int ct_trailing_ones;  // consecutive 1 bits at the least-significant end
        int left, right;       // child indices (high half, low half); -1 for single-bit leaves

        /// Single-bit leaf holding `bit`.
        binary_string(bool bit) : hash(bit), ct_trailing_ones(bit), left(-1), right(-1) {}
        binary_string(hash_t _hash, int _ct_trailing_ones, int _left, int _right)
            : hash(_hash), ct_trailing_ones(_ct_trailing_ones), left(_left), right(_right) {}

        int length() const { return hash.N; }
        bool is_all_ones() const { return hash.N == ct_trailing_ones; }
    };

    /// Node storage; indices into this vector identify values.
    std::vector<binary_string> nums = { binary_string(0), binary_string(1) };
    /// Index of the value 1 (the single-bit leaf 1).
    const int one = 1;
    /// zero[k] is the index of an all-zero node of width 2^k, grown lazily.
    std::vector<int> zero = { 0 };

    /// Appends the node "x followed by y" (x becomes the high half, y the low
    /// half) and returns its index. All callers in this struct pass nodes of
    /// equal width, which keeps every width a power of two.
    int concatenate(int x, int y) {
        nums.emplace_back(
            hash_t::concatenate(nums[x].hash, nums[y].hash),
            // Trailing ones come from the low half and continue into the high
            // half only when the low half is entirely ones.
            nums[y].ct_trailing_ones + (nums[y].is_all_ones() ? nums[x].ct_trailing_ones : 0),
            x,
            y
        );
        return int(nums.size()) - 1;
    }

    /// Returns the index of an all-zero node of width `bit_width`.
    /// `bit_width` must be a positive power of two.
    int get_zero_of_width(int bit_width) {
        // __builtin_ctz(0) is undefined, and for a non-power-of-two width the
        // lookup below would silently return a node of the wrong width.
        assert(bit_width > 0 && (bit_width & (bit_width - 1)) == 0);
        const int index = __builtin_ctz(bit_width);
        while (index >= int(zero.size()))
            zero.push_back(concatenate(zero.back(), zero.back()));
        return zero[index];
    }

    /// Number of consecutive 1 bits in x starting at bit position `pow`
    /// (position 0 = least significant). These are the bits a carry from adding
    /// 2^pow would clear. Returns 0 when pow is at or beyond the width of x.
    int carry_count(int x, int pow) const {
        const int len = nums[x].length();
        if (len <= pow)
            return 0;
        if (nums[x].is_all_ones())
            return len - pow;
        if (len == 1)
            return nums[x].ct_trailing_ones;  // a leaf that is not all ones: 0
        const int half = len / 2;
        int res = carry_count(nums[x].right, pow);
        // If the run reaches the top of the low half, continue into the high half.
        if (pow + res >= half)
            res += carry_count(nums[x].left, pow + res - half);
        return res;
    }

    /// Returns a node equal to x with bit positions [L, R) flipped
    /// (position 0 = least significant). Requires 0 <= L < R <= width of x.
    int invert_range(int x, int L, int R) {
        const int len = nums[x].length();
        assert(0 <= L && L < R && R <= len);
        if (len == 1)
            return x ^ 1;  // leaves 0 and 1 are each other's complement
        // Flipping a node that is entirely ones yields the shared all-zero node.
        if (0 == L && R == len && nums[x].is_all_ones())
            return get_zero_of_width(len);
        const int half = len / 2;
        const int right = L < half ? invert_range(nums[x].right, L, std::min(R, half)) : nums[x].right;
        const int left = R > half ? invert_range(nums[x].left, std::max(0, L - half), R - half) : nums[x].left;
        return concatenate(left, right);
    }

    /// Returns the index of the value x + 2^pow. If the result would not fit
    /// in x's width, x is first widened with leading zeros, doubling its width
    /// as many times as needed.
    int add_pow2(int x, int pow) {
        const int carries = carry_count(x, pow);
        // Bits pow .. pow+carries-1 are ones. Bit pow+carries is zero or lies
        // beyond the current width; widen until that bit exists.
        while (nums[x].length() <= pow + carries)
            x = concatenate(get_zero_of_width(nums[x].length()), x);
        // Adding 2^pow flips exactly bits pow .. pow+carries.
        return invert_range(x, pow, pow + carries + 1);
    }

    /// Returns whether value x is less than value y.
    ///
    /// NOTE: nodes of different widths are ordered by width alone. That is
    /// correct only when every compared node of width w > 1 has a 1 bit in its
    /// upper half. Values built from index 0 or `one` via add_pow2 keep this
    /// property: add_pow2 widens only while the bit it sets would not fit, so
    /// that bit lands in the upper half, and values never decrease. Other
    /// nodes, e.g. get_zero_of_width(w) for w > 1, do not keep it.
    /// Equal halves are detected by hash, so results hold up to hash collisions.
    bool less_than(int x, int y) const {
        if (x == y)
            return false;  // same node, equal values
        const int len = nums[x].length();
        if (len != nums[y].length())
            return len < nums[y].length();
        if (len == 1)
            return nums[x].ct_trailing_ones < nums[y].ct_trailing_ones;
        // The high halves decide, unless they are equal; then the low halves decide.
        if (nums[nums[x].left].hash == nums[nums[y].left].hash)
            return less_than(nums[x].right, nums[y].right);
        return less_than(nums[x].left, nums[y].left);
    }

    /// Writes the bits of x to `os` (default std::cout), most significant
    /// first. Leading zeros up to the node's width are included.
    void print_bits(int x, std::ostream& os = std::cout) const {
        if (nums[x].length() == 1) {
            os << nums[x].ct_trailing_ones;
        } else {
            print_bits(nums[x].left, os);
            print_bits(nums[x].right, os);
        }
    }
};