-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement exact string searching using FM index and LZ index (#57)
* Added fm_index.py * Added wavelet_tree.py * Added lz_index.py * Implement tests * Added test/test_wavelet_tree.py * Added optimal range_search function * Added NaiveRangeSearcher * Added text
- Loading branch information
Showing
30 changed files
with
3,054 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import itertools | ||
|
||
# pylint: disable=too-many-instance-attributes | ||
class WaveletTree:
    """Wavelet tree over a 1-indexed text.

    Each internal node splits its alphabet into a lower half (left child) and
    an upper half (right child) and stores, for every prefix of its text, how
    many characters went to the right child. A node whose alphabet has a
    single symbol is a leaf.

    All positions taken and returned by the public methods are 1-based.
    A range that starts before position 1 or is empty yields the "nothing
    found" value of the method (0, None or []); a right end past ``n`` is
    clamped to ``n``.
    """

    def __init__(self, t, n, A=None):
        """Build the (sub)tree for text ``t``.

        Args:
            t: sequence whose element 0 is a placeholder; the real text is
                ``t[1:]``.
            n: number of real characters in ``t``.
            A: optional alphabet for this node, indexable and ordered from
                smallest to largest symbol. When omitted, the sorted set of
                characters of the text is used (the text must then be
                non-empty, since the smallest and largest symbols are read).
        """
        t = t[1:]  # drop the placeholder stored at index 0
        if A is not None:
            self.alphabet = set(A)
        else:
            self.alphabet = set(t)
            A = sorted(self.alphabet)
        self.n, self.smallest, self.largest = n, A[0], A[-1]
        self.leaf = len(A) == 1
        if self.leaf:
            return

        # Lower half (rounded up) goes left, the rest goes right.
        middle = (len(A) + 1) // 2
        A_left, A_right = A[:middle], A[middle:]
        self.zero_indexed, self.one_indexed = set(A_left), set(A_right)

        # prefix_sum[i] = how many of the first i characters belong to the
        # right child.
        went_right = [1 if c in self.one_indexed else 0 for c in t]
        self.prefix_sum = list(itertools.accumulate(went_right, initial=0))

        # left_indices[j] / right_indices[j] = position in this node of the
        # j-th character of the corresponding child (index 0 is padding).
        self.left_indices, self.right_indices = [0], [0]
        for i, c in enumerate(t, start=1):
            if c in self.zero_indexed:
                self.left_indices.append(i)
            else:
                self.right_indices.append(i)

        left_text = ['#'] + [c for c in t if c in self.zero_indexed]
        right_text = ['#'] + [c for c in t if c in self.one_indexed]
        self.left = WaveletTree(left_text, len(left_text) - 1, A_left)
        self.right = WaveletTree(right_text, len(right_text) - 1, A_right)

    def _clip(self, l, r):
        """Return ``(l, min(r, n))``, or None if that range is unusable.

        A range is unusable when it starts before position 1 or is empty
        after clamping (which also covers ``l > n`` and ``r < 1``).
        """
        r = min(r, self.n)
        if l < 1 or l > r:
            return None
        return l, r

    def _left_tree_range(self, l, r):
        """Map positions l..r of this node to positions in the left child."""
        return l - self.prefix_sum[l - 1], r - self.prefix_sum[r]

    def _right_tree_range(self, l, r):
        """Map positions l..r of this node to positions in the right child."""
        return (self.prefix_sum[l - 1] + 1, self.prefix_sum[r])

    def rank(self, c, l, r):
        """Return the number of occurrences of ``c`` at positions l..r."""
        bounds = self._clip(l, r)
        if bounds is None or c not in self.alphabet:
            return 0
        l, r = bounds
        if self.leaf:
            # Every character of a leaf is the leaf's single symbol.
            return r - l + 1
        if c in self.zero_indexed:
            return self.left.rank(c, *self._left_tree_range(l, r))
        return self.right.rank(c, *self._right_tree_range(l, r))

    def prefix_rank(self, c, r):
        """Return the number of occurrences of ``c`` at positions 1..r."""
        return self.rank(c, 1, r)

    def select(self, c, k, l, r):
        """Return the position of the k-th ``c`` within l..r, or None."""
        bounds = self._clip(l, r)
        if bounds is None or k < 1 or c not in self.alphabet:
            return None
        l, r = bounds
        if self.leaf:
            return l + k - 1 if k <= r - l + 1 else None
        if c in self.zero_indexed:
            child, back_map = self.left, self.left_indices
            child_range = self._left_tree_range(l, r)
        else:
            child, back_map = self.right, self.right_indices
            child_range = self._right_tree_range(l, r)
        found = child.select(c, k, *child_range)
        # Translate the child's position back into this node's positions.
        return back_map[found] if found is not None else None

    def quantile(self, k, l, r):
        """Return the k-th smallest character among positions l..r, or None."""
        bounds = self._clip(l, r)
        if bounds is None:
            return None
        l, r = bounds
        if k < 1 or k > r - l + 1:
            return None
        if self.leaf:
            return self.smallest if k <= self.n else None
        in_right = self.prefix_sum[r] - self.prefix_sum[l - 1]
        in_left = r - l + 1 - in_right
        if in_left >= k:
            return self.left.quantile(k, *self._left_tree_range(l, r))
        # Skip the in_left smaller characters that live in the left child.
        return self.right.quantile(k - in_left, *self._right_tree_range(l, r))

    def _does_one_range_end_in_another(self, l, r, i, j):
        """Tell whether l or r lies inside the closed range i..j."""
        return (i <= l <= j) or (i <= r <= j)

    def _ranges_intersect(self, l, r, i, j):
        """Tell whether the closed ranges l..r and i..j share a value."""
        return (self._does_one_range_end_in_another(l, r, i, j) or
                self._does_one_range_end_in_another(i, j, l, r))

    def range_count(self, l, r, x, y):
        """Count positions in l..r whose character c satisfies x <= c <= y."""
        bounds = self._clip(l, r)
        if bounds is None or x > y:
            return 0
        l, r = bounds
        if x <= self.smallest and self.largest <= y:
            # The whole node alphabet is inside the query: everything counts.
            return r - l + 1
        if self.leaf or y < self.smallest or x > self.largest:
            return 0
        total = 0
        if self._ranges_intersect(
                self.left.smallest, self.left.largest, x, y):
            total += self.left.range_count(
                *self._left_tree_range(l, r), x, y)
        if self._ranges_intersect(
                self.right.smallest, self.right.largest, x, y):
            total += self.right.range_count(
                *self._right_tree_range(l, r), x, y)
        return total

    def range_search(self, l, r, x, y):
        """List positions in l..r whose character c satisfies x <= c <= y.

        Positions found through the left child come before those found
        through the right child, so the list is not necessarily sorted.
        """
        bounds = self._clip(l, r)
        if bounds is None or x > y:
            return []
        l, r = bounds
        if x <= self.smallest and self.largest <= y:
            return list(range(l, r + 1))
        if self.leaf or y < self.smallest or x > self.largest:
            return []
        hits = []
        if self._ranges_intersect(
                self.left.smallest, self.left.largest, x, y):
            in_left = self.left.range_search(
                *self._left_tree_range(l, r), x, y)
            hits += [self.left_indices[pos] for pos in in_left]
        if self._ranges_intersect(
                self.right.smallest, self.right.largest, x, y):
            in_right = self.right.range_search(
                *self._right_tree_range(l, r), x, y)
            hits += [self.right_indices[pos] for pos in in_right]
        return hits
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#pylint: disable=too-few-public-methods | ||
#pylint: disable=invalid-name | ||
class _RankSearcher: | ||
SAMPLE_SIZE = 8 | ||
|
||
def __init__(self, L, mapper_of_chars, n): | ||
self.L = L | ||
#prepare closest samplings | ||
current_sample = 0 | ||
self.closest_sample = [0] | ||
for i in range(1, n+2): | ||
if (abs(current_sample-i) > abs(current_sample + self.SAMPLE_SIZE-i) and | ||
(i + self.SAMPLE_SIZE < n)): | ||
current_sample += self.SAMPLE_SIZE | ||
self.closest_sample.append(current_sample) | ||
|
||
#Generate values for occ for given samples O(|A|*n) | ||
self.occ_for_char = { self.L[i]: [0] for i in range(1, n+2)} | ||
for c in mapper_of_chars: | ||
current_value, next_sample = 0, self.SAMPLE_SIZE | ||
for i in range(1, n+2): | ||
if L[i] == c: | ||
current_value += 1 | ||
if i == next_sample: | ||
self.occ_for_char[c].append(current_value) | ||
next_sample = next_sample + self.SAMPLE_SIZE | ||
|
||
def prefix_rank(self, c, i): | ||
if self.closest_sample[i] < i: | ||
to_add = sum( | ||
1 for c_it in self.L[self.closest_sample[i] + 1:i + 1] if c_it == c) | ||
else: | ||
to_add = sum( | ||
-1 for c_it in self.L[i + 1:self.closest_sample[i] + 1] if c_it == c) | ||
return (self.occ_for_char[c][self.closest_sample[i] // self.SAMPLE_SIZE] | ||
+ to_add) | ||
|
||
#pylint: disable=too-few-public-methods | ||
#pylint: disable=invalid-name | ||
class _FMIndex: | ||
def __init__ (self, SA, BWT, text, n, rank_searcher = None): | ||
self.L = BWT | ||
F = '#$' + ''.join(text[SA[i]] for i in range(1, n + 1)) | ||
self.n = n | ||
self.SA = SA | ||
|
||
#prepare char mapping for F | ||
self.mapper_of_chars = { F[2] : 0} | ||
self.beginnings = [2] | ||
last = F[2] | ||
for i in range(3, n+2): | ||
if F[i] != last: | ||
last = F[i] | ||
self.beginnings.append(i) | ||
self.mapper_of_chars[last] = len(self.beginnings) - 1 | ||
|
||
self.len_of_alphabet = len(self.mapper_of_chars) | ||
self.rank_searcher = (_RankSearcher(self.L, self.mapper_of_chars, n) | ||
if rank_searcher is None else rank_searcher) | ||
|
||
def from_suffix_array_and_bwt(SA, BWT, text, n, rank_searcher=None):
    """Create an ``_FMIndex`` from a precomputed suffix array and BWT.

    All arguments are forwarded unchanged to ``_FMIndex``.
    """
    return _FMIndex(SA, BWT, text, n, rank_searcher=rank_searcher)
|
||
# O(|p|) | ||
def count(FM, p, size):
    """Return the number of occurrences of pattern ``p`` in the indexed text.

    Args:
        FM: index built by ``from_suffix_array_and_bwt``.
        p: pattern; element 0 is a placeholder and is not matched.
        size: pattern length, not counting the placeholder.
    """
    low, high = _get_range_of_occurrences(FM, p, size)
    if low <= -1:
        # (-1, -1) is the "no match" marker of the range search.
        return 0
    return max(high - low + 1, 0)
|
||
# O(|p|) rank queries plus O(k log k) for sorting, where k is the number of
# occurrences of p in the text
def contains(FM, p, l):
    """Yield the starting positions of pattern ``p`` in increasing order.

    Args:
        FM: index built by ``from_suffix_array_and_bwt``.
        p: pattern; element 0 is a placeholder and is not matched.
        l: pattern length, not counting the placeholder.
    """
    low, high = _get_range_of_occurrences(FM, p, l)
    if low <= -1:
        # (-1, -1) is the "no match" marker of the range search.
        return
    # Row i of the index corresponds to entry i - 1 of the suffix array.
    yield from sorted(FM.SA[i - 1] for i in range(low, high + 1))
|
||
def _get_range_of_occurrences(FM, p, size): | ||
if size > FM.n or size == 0: | ||
return -1, -1 | ||
|
||
if p[-1] not in FM.mapper_of_chars: | ||
return -1, -1 | ||
|
||
map_idx = FM.mapper_of_chars[p[-1]] | ||
l= FM.beginnings[map_idx] | ||
r = (FM.beginnings[map_idx + 1] - 1 | ||
if map_idx != FM.len_of_alphabet - 1 else FM.n + 1) | ||
|
||
for c in p[-2:0:-1]: | ||
if c not in FM.mapper_of_chars: | ||
return -1, -1 | ||
occurrences_before = FM.rank_searcher.prefix_rank(c, l - 1) | ||
occurrences_after = FM.rank_searcher.prefix_rank(c, r) | ||
if occurrences_before == occurrences_after: | ||
return -1, -1 | ||
map_idx = FM.mapper_of_chars[c] | ||
l = FM.beginnings[map_idx] + occurrences_before | ||
r = FM.beginnings[map_idx] + occurrences_after - 1 | ||
if r < l: | ||
return -1, -1 | ||
|
||
return l, r |
Oops, something went wrong.