Skip to content

Commit

Permalink
Implement exact string searching using FM index and LZ index (#57)
Browse files Browse the repository at this point in the history
* Added fm_index.py
* Added wavelet_tree.py
* Added lz_index.py
* Implement tests
* Added test/test_wavelet_tree.py
* Added optimal range_search function
* Added NaiveRangeSearcher
* Added text
  • Loading branch information
prolik123 authored Jul 8, 2024
1 parent e12db81 commit 7e12217
Show file tree
Hide file tree
Showing 30 changed files with 3,054 additions and 4 deletions.
126 changes: 126 additions & 0 deletions common/wavelet_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import itertools

# pylint: disable=too-many-instance-attributes
class WaveletTree:
    """Wavelet tree over a text with 1-based positions.

    Each node is responsible for a contiguous slice of the sorted alphabet.
    An internal node routes the lower half of that slice (rounded up) to its
    left child and the rest to its right child, and stores a prefix-sum array
    counting how many symbols of the text went right.  A node whose alphabet
    has a single symbol is a leaf.

    All positions taken and returned by the public methods are 1-based and
    every range ``[l, r]`` is inclusive.
    """

    def __init__(self, t, n, A=None):
        """Build the (sub)tree for text ``t``.

        Args:
            t: sequence whose element 0 is a placeholder that is discarded;
                the actual symbols are ``t[1:]``.
            n: number of symbols stored in this node; expected to equal
                ``len(t) - 1`` (not verified here).
            A: optional iterable with the alphabet of this node.  When
                omitted, the set of symbols occurring in ``t`` is used.
                ``t`` is expected to contain only symbols from ``A``.
        """
        t = t[1:]
        self.alphabet = set(t) if A is None else set(A)
        # Always sort: smallest/largest and the split below rely on order,
        # also when the caller supplies the alphabet.
        A = sorted(self.alphabet)
        self.n = n
        if not A:
            # Empty alphabet: keep a queryable, empty leaf instead of
            # failing on A[0].
            self.smallest = self.largest = None
            self.leaf = True
            return
        self.smallest, self.largest = A[0], A[-1]
        self.leaf = len(A) == 1
        if self.leaf:
            return
        half = (len(A) + 1) // 2
        A_left, A_right = A[:half], A[half:]
        self.zero_indexed, self.one_indexed = set(A_left), set(A_right)
        bits = [1 if c in self.one_indexed else 0 for c in t]
        # prefix_sum[i] = number of symbols among t[1..i] sent to the right.
        self.prefix_sum = list(itertools.accumulate(bits, initial=0))
        # left_indices[j] / right_indices[j]: position in this node of the
        # j-th symbol of the left / right child (index 0 is a placeholder).
        self.left_indices, self.right_indices = [0], [0]
        for i, c in enumerate(t, start=1):
            if c in self.zero_indexed:
                self.left_indices.append(i)
            else:
                self.right_indices.append(i)
        left_text = ['#'] + [c for c in t if c in self.zero_indexed]
        right_text = ['#'] + [c for c in t if c in self.one_indexed]
        self.left = WaveletTree(left_text, len(left_text) - 1, A_left)
        self.right = WaveletTree(right_text, len(right_text) - 1, A_right)

    def _clip(self, l, r):
        """Clamp ``[l, r]`` to the positions stored in this node."""
        return max(l, 1), min(r, self.n)

    def _left_tree_range(self, l, r):
        """Translate ``[l, r]`` of this node into left-child positions."""
        return l - self.prefix_sum[l - 1], r - self.prefix_sum[r]

    def _right_tree_range(self, l, r):
        """Translate ``[l, r]`` of this node into right-child positions."""
        return (self.prefix_sum[l - 1] + 1, self.prefix_sum[r])

    def rank(self, c, l, r):
        """Return the number of occurrences of ``c`` in positions ``[l, r]``.

        The range is clamped to ``[1, n]``; an empty range or a symbol
        outside the alphabet yields 0.
        """
        l, r = self._clip(l, r)
        if c not in self.alphabet or l > r:
            return 0
        if self.leaf:
            return r - l + 1
        if c in self.zero_indexed:
            return self.left.rank(c, *self._left_tree_range(l, r))
        return self.right.rank(c, *self._right_tree_range(l, r))

    def prefix_rank(self, c, r):
        """Return the number of occurrences of ``c`` in positions ``[1, r]``."""
        return self.rank(c, 1, r)

    def select(self, c, k, l, r):
        """Return the position of the ``k``-th occurrence of ``c`` in ``[l, r]``.

        The range is clamped to ``[1, n]``.  Returns ``None`` when ``k < 1``,
        when ``c`` is not in the alphabet, or when the range holds fewer
        than ``k`` occurrences.
        """
        l, r = self._clip(l, r)
        if c not in self.alphabet or l > r or k < 1:
            return None
        if self.leaf:
            return k + l - 1 if k <= r - l + 1 else None
        if c in self.zero_indexed:
            child, child_to_here = self.left, self.left_indices
            child_range = self._left_tree_range(l, r)
        else:
            child, child_to_here = self.right, self.right_indices
            child_range = self._right_tree_range(l, r)
        found = child.select(c, k, *child_range)
        return child_to_here[found] if found is not None else None

    def quantile(self, k, l, r):
        """Return the ``k``-th smallest symbol among positions ``[l, r]``.

        The range is clamped to ``[1, n]``.  Returns ``None`` when ``k`` is
        not between 1 and the length of the clamped range.
        """
        l, r = self._clip(l, r)
        if k < 1 or k > r - l + 1:
            return None
        if self.leaf:
            return self.smallest
        went_right = self.prefix_sum[r] - self.prefix_sum[l - 1]
        went_left = r - l + 1 - went_right
        if k <= went_left:
            return self.left.quantile(k, *self._left_tree_range(l, r))
        # Skip the went_left smaller symbols that live in the left child.
        return self.right.quantile(
            k - went_left, *self._right_tree_range(l, r))

    def _does_one_range_end_in_another(self, l, r, i, j):
        """Tell whether an endpoint of ``[l, r]`` lies inside ``[i, j]``."""
        return (i <= l <= j) or (i <= r <= j)

    def _ranges_intersect(self, l, r, i, j):
        """Tell whether ``[l, r]`` and ``[i, j]`` share at least one value."""
        return (self._does_one_range_end_in_another(l, r, i, j) or
                self._does_one_range_end_in_another(i, j, l, r))

    def range_count(self, l, r, x, y):
        """Count positions in ``[l, r]`` whose symbol lies in ``[x, y]``.

        ``r`` is clamped to ``n``.  Returns 0 for an empty range, for
        ``l < 1`` and for ``x > y``.
        """
        r = min(r, self.n)
        if not self.alphabet or l > r or l < 1 or x > y:
            return 0
        if x <= self.smallest and self.largest <= y:
            return r - l + 1
        if self.leaf or y < self.smallest or x > self.largest:
            return 0
        total = 0
        # Descend only into children whose alphabet slice meets [x, y].
        if self._ranges_intersect(
                self.left.smallest, self.left.largest, x, y):
            total += self.left.range_count(
                *self._left_tree_range(l, r), x, y)
        if self._ranges_intersect(
                self.right.smallest, self.right.largest, x, y):
            total += self.right.range_count(
                *self._right_tree_range(l, r), x, y)
        return total

    def range_search(self, l, r, x, y):
        """List positions in ``[l, r]`` whose symbol lies in ``[x, y]``.

        ``r`` is clamped to ``n``.  Returns an empty list for an empty
        range, for ``l < 1`` and for ``x > y``.  Positions found through the
        left child precede those found through the right child, so the
        result is not necessarily sorted.
        """
        r = min(r, self.n)
        if not self.alphabet or l > r or l < 1 or x > y:
            return []
        if x <= self.smallest and self.largest <= y:
            return list(range(l, r + 1))
        if self.leaf or y < self.smallest or x > self.largest:
            return []
        hits = []
        if self._ranges_intersect(
                self.left.smallest, self.left.largest, x, y):
            hits.extend(
                self.left_indices[pos]
                for pos in self.left.range_search(
                    *self._left_tree_range(l, r), x, y))
        if self._ranges_intersect(
                self.right.smallest, self.right.largest, x, y):
            hits.extend(
                self.right_indices[pos]
                for pos in self.right.range_search(
                    *self._right_tree_range(l, r), x, y))
        return hits
99 changes: 99 additions & 0 deletions string_indexing/fm_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#pylint: disable=too-few-public-methods
#pylint: disable=invalid-name
class _RankSearcher:
    """Sampled occurrence table answering prefix-rank queries on ``L``.

    Occurrence counts are stored only at positions that are multiples of
    ``SAMPLE_SIZE``; a query starts from the sample assigned to its index
    and corrects the count by scanning the symbols in between.
    """

    SAMPLE_SIZE = 8

    def __init__(self, L, mapper_of_chars, n):
        """Precompute the sample assignment and the sampled counts.

        Args:
            L: indexable sequence of symbols; positions ``1..n+1`` are read.
            mapper_of_chars: iterable of the symbols to build counts for.
            n: positions ``1..n+1`` of ``L`` are covered.
        """
        self.L = L
        step = self.SAMPLE_SIZE

        # closest_sample[i]: the sampled position used to answer queries
        # at index i.  Advancing to the next sample additionally requires
        # pos + step < n.
        anchor = 0
        anchors = [0]
        for pos in range(1, n + 2):
            next_is_nearer = abs(anchor - pos) > abs(anchor + step - pos)
            if next_is_nearer and pos + step < n:
                anchor += step
            anchors.append(anchor)
        self.closest_sample = anchors

        # occ_for_char[c][j]: occurrences of c in L[1 .. j * step].
        self.occ_for_char = {self.L[pos]: [0] for pos in range(1, n + 2)}
        for symbol in mapper_of_chars:
            seen = 0
            for pos in range(1, n + 2):
                if L[pos] == symbol:
                    seen += 1
                if pos % step == 0:
                    self.occ_for_char[symbol].append(seen)

    def prefix_rank(self, c, i):
        """Return the number of occurrences of ``c`` in ``L[1..i]``."""
        anchor = self.closest_sample[i]
        if anchor < i:
            # Sample lies before i: add the occurrences after the sample.
            window, sign = self.L[anchor + 1:i + 1], 1
        else:
            # Sample lies at or after i: remove the occurrences past i.
            window, sign = self.L[i + 1:anchor + 1], -1
        correction = sign * sum(1 for symbol in window if symbol == c)
        sampled = self.occ_for_char[c][anchor // self.SAMPLE_SIZE]
        return sampled + correction

#pylint: disable=too-few-public-methods
#pylint: disable=invalid-name
class _FMIndex:
    """Container for the tables used by the FM-index search functions.

    Attributes:
        L: the BWT sequence passed in.
        SA: the suffix array passed in.
        n: the length passed in.
        mapper_of_chars: maps each symbol of the first column to its rank
            among the distinct symbols, in order of first appearance.
        beginnings: for each such symbol, the index in the first column
            where its run starts (the first run starts at index 2).
        len_of_alphabet: number of distinct symbols in ``mapper_of_chars``.
        rank_searcher: object providing ``prefix_rank(c, i)`` over ``L``.
    """

    def __init__(self, SA, BWT, text, n, rank_searcher=None):
        """Build the first-column tables and attach a rank structure.

        Args:
            SA: suffix array; entries ``SA[1..n]`` are read here.
            BWT: Burrows-Wheeler transform, stored as ``L``.
            text: indexed text; ``text[SA[i]]`` is read for ``i`` in
                ``1..n``.
            n: number of suffixes considered; ``0`` yields empty tables.
            rank_searcher: optional rank structure; a ``_RankSearcher``
                over ``BWT`` is created when omitted.
        """
        self.L = BWT
        self.n = n
        self.SA = SA

        # First column: two placeholder symbols, then the leading symbol
        # of every suffix SA[1..n].
        F = '#$' + ''.join(text[SA[i]] for i in range(1, n + 1))

        # Record where each run of equal symbols starts in F[2..n+1].
        # Starting the loop at 2 (rather than seeding from F[2]) keeps the
        # tables empty, instead of raising IndexError, when n == 0.
        self.mapper_of_chars = {}
        self.beginnings = []
        for i in range(2, n + 2):
            if i == 2 or F[i] != F[i - 1]:
                self.mapper_of_chars[F[i]] = len(self.beginnings)
                self.beginnings.append(i)

        self.len_of_alphabet = len(self.mapper_of_chars)
        self.rank_searcher = (
            _RankSearcher(self.L, self.mapper_of_chars, n)
            if rank_searcher is None else rank_searcher)

def from_suffix_array_and_bwt(SA, BWT, text, n, rank_searcher = None):
    """Create an ``_FMIndex`` from a suffix array and its BWT.

    All arguments are forwarded unchanged to the ``_FMIndex`` constructor.
    """
    index = _FMIndex(
        SA=SA, BWT=BWT, text=text, n=n, rank_searcher=rank_searcher)
    return index

# O(|p|) rank queries
def count(FM, p, size):
    """Return the number of occurrences of pattern ``p`` indexed by ``FM``.

    ``size`` is the pattern length handed to the range computation.
    """
    first, last = _get_range_of_occurrences(FM, p, size)
    if first <= -1:
        # The range computation reported "no match".
        return 0
    return max(last - first + 1, 0)

# O(|p|) rank queries plus sorting the k reported positions, where k is the
# number of occurrences of p in text
def contains(FM, p, l):
    """Yield the text positions where ``p`` occurs, in increasing order.

    ``l`` is the pattern length handed to the range computation.
    """
    first, last = _get_range_of_occurrences(FM, p, l)
    if first <= -1:
        # The range computation reported "no match": yield nothing.
        return
    positions = [FM.SA[row - 1] for row in range(first, last + 1)]
    positions.sort()
    yield from positions

def _get_range_of_occurrences(FM, p, size):
    """Backward-search ``p`` and return the matching first-column range.

    The last symbol of ``p`` selects the starting run; the remaining
    symbols are then processed right to left, skipping ``p[0]``.

    Returns:
        A pair ``(l, r)`` of inclusive first-column indices, or
        ``(-1, -1)`` when ``size`` is 0 or exceeds ``FM.n``, when a symbol
        is unknown to ``FM``, or when the range becomes empty.
    """
    not_found = (-1, -1)
    if size > FM.n or size == 0:
        return not_found

    final_symbol = p[-1]
    if final_symbol not in FM.mapper_of_chars:
        return not_found

    # Initial range: the whole run of the pattern's last symbol.
    bucket = FM.mapper_of_chars[final_symbol]
    l = FM.beginnings[bucket]
    if bucket != FM.len_of_alphabet - 1:
        r = FM.beginnings[bucket + 1] - 1
    else:
        # The last run extends to the end of the first column.
        r = FM.n + 1

    for c in reversed(p[1:-1]):
        if c not in FM.mapper_of_chars:
            return not_found
        below = FM.rank_searcher.prefix_rank(c, l - 1)
        up_to = FM.rank_searcher.prefix_rank(c, r)
        if below == up_to:
            # No occurrence of c inside the current range of L.
            return not_found
        run_start = FM.beginnings[FM.mapper_of_chars[c]]
        l, r = run_start + below, run_start + up_to - 1
        if r < l:
            return not_found

    return l, r
Loading

0 comments on commit 7e12217

Please sign in to comment.