Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement exact string searching using FM index and LZ index #57

Merged
merged 24 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions common/wavelet_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import itertools

# pylint: disable=too-many-instance-attributes
class WaveletTree:
def __init__(self, t, n, A = None):
t = t[1:]
if A is not None:
self.alphabet = set(A)
else:
self.alphabet = set(t)
A = sorted(list(self.alphabet))
self.n, self.smallest, self.largest = n, A[0], A[-1]
if len(A) == 1:
self.leaf = True
return
self.leaf = False
A_left, A_right = A[:(len(A) + 1) // 2], A[(len(A) + 1) // 2:]
self.zero_indexed, self.one_indexed = set(A_left), set(A_right)
value_array = [1 if c in self.one_indexed else 0 for c in t]
self.prefix_sum = list(itertools.accumulate(value_array, initial = 0))
self.left_indices, self.right_indices = [0], [0]
for i, c in enumerate(t, start = 1):
if c in self.zero_indexed:
self.left_indices.append(i)
else:
self.right_indices.append(i)
left_text = ['#'] + [c for c in t if c in self.zero_indexed]
right_text = ['#'] + [c for c in t if c in self.one_indexed]
self.left = WaveletTree(left_text, len(left_text) - 1, A_left)
self.right = WaveletTree(right_text, len(right_text) - 1, A_right)

def _left_tree_range(self, l, r):
return l - self.prefix_sum[l - 1], r - self.prefix_sum[r]

def _right_tree_range(self, l, r):
return (self.prefix_sum[l - 1] + 1, self.prefix_sum[r])

def rank(self, c, l, r):
if c not in self.alphabet or l > r or l > self.n or r < 1:
return 0
if self.leaf:
return r - l + 1
if c in self.zero_indexed:
new_l, new_r = self._left_tree_range(l, r)
return self.left.rank(c, new_l, new_r)
new_l, new_r = self._right_tree_range(l, r)
return self.right.rank(c, new_l, new_r)

def prefix_rank(self, c, r):
return self.rank(c, 1, r)

def select(self, c, k, l, r):
if c not in self.alphabet or l > r or l > self.n or r < 1 :
return None
if self.leaf:
return k + l - 1 if k <= r - l + 1 else None
if c in self.zero_indexed:
new_l, new_r = self._left_tree_range(l, r)
result = self.left.select(c, k, new_l, new_r)
return self.left_indices[result] if result is not None else None
new_l, new_r = self._right_tree_range(l, r)
result = self.right.select(c, k, new_l, new_r)
return self.right_indices[result] if result is not None else None

def quantile(self, k, l, r):
if k < 1 or k > r - l + 1:
return None
if self.leaf:
return self.smallest if k <= self.n else None
left_num = self.prefix_sum[r] - self.prefix_sum[l-1]
if r - l + 1 - left_num >= k:
new_l, new_r = self._left_tree_range(l, r)
return self.left.quantile(k, new_l, new_r)
new_l, new_r = self._right_tree_range(l, r)
return self.right.quantile(k-r+l-1+left_num, new_l, new_r)

def _does_one_range_end_in_another(self, l, r, i, j):
return (i <= l <= j) or (i <= r <= j)

def _ranges_intersect(self, l, r, i, j):
return (self._does_one_range_end_in_another(l, r, i ,j) or
self._does_one_range_end_in_another(i, j, l, r))

def range_count(self, l, r, x, y):
if l > r or l > self.n or l < 1 or x > y:
return 0
if x <= self.smallest and self.largest <= y:
return r-l+1
if self.leaf or y < self.smallest or x > self.largest:
return 0
l_node, r_node = self.left, self.right
if (self._ranges_intersect(l_node.smallest, l_node.largest, x, y) and
self._ranges_intersect(r_node.smallest, r_node.largest, x, y)):
new_left_l, new_left_r = self._left_tree_range(l, r)
new_right_l, new_right_r = self._right_tree_range(l, r)
return (self.left.range_count(new_left_l, new_left_r, x, y)
+ self.right.range_count(new_right_l, new_right_r, x, y))
if self._ranges_intersect(self.right.smallest, self.right.largest, x, y):
new_l, new_r = self._right_tree_range(l, r)
return self.right.range_count(new_l, new_r, x, y)
new_l, new_r = self._left_tree_range(l, r)
return self.left.range_count(new_l, new_r, x, y)

def range_search(self, l, r, x, y):
if l > r or l > self.n or l < 1 or x > y:
return []
if x <= self.smallest and self.largest <= y:
return list(range(l, r + 1))
if self.leaf or y < self.smallest or x > self.largest:
return []
l_node, r_node = self.left, self.right
if (self._ranges_intersect(l_node.smallest, l_node.largest, x, y)
and self._ranges_intersect(r_node.smallest, r_node.largest, x, y)):
new_left_l, new_left_r = self._left_tree_range(l, r)
new_right_l, new_right_r = self._right_tree_range(l, r)
return (([self.left_indices[x] for x in
self.left.range_search(new_left_l, new_left_r, x, y)]) +
([self.right_indices[x] for x in
self.right.range_search(new_right_l, new_right_r, x, y)]))
if self._ranges_intersect(self.right.smallest, self.right.largest, x, y):
return [
self.right_indices[x]
for x in self.right.range_search(*self._right_tree_range(l, r), x, y)]
return [
self.left_indices[x]
for x in self.left.range_search(*self._left_tree_range(l, r), x, y)]
99 changes: 99 additions & 0 deletions string_indexing/fm_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#pylint: disable=too-few-public-methods
#pylint: disable=invalid-name
class _RankSearcher:
SAMPLE_SIZE = 8

def __init__(self, L, mapper_of_chars, n):
self.L = L
#prepare closest samplings
current_sample = 0
self.closest_sample = [0]
for i in range(1, n+2):
if (abs(current_sample-i) > abs(current_sample + self.SAMPLE_SIZE-i) and
(i + self.SAMPLE_SIZE < n)):
current_sample += self.SAMPLE_SIZE
self.closest_sample.append(current_sample)

#Generate values for occ for given samples O(|A|*n)
self.occ_for_char = { self.L[i]: [0] for i in range(1, n+2)}
for c in mapper_of_chars:
current_value, next_sample = 0, self.SAMPLE_SIZE
for i in range(1, n+2):
if L[i] == c:
current_value += 1
if i == next_sample:
self.occ_for_char[c].append(current_value)
next_sample = next_sample + self.SAMPLE_SIZE

def prefix_rank(self, c, i):
if self.closest_sample[i] < i:
to_add = sum(
1 for c_it in self.L[self.closest_sample[i] + 1:i + 1] if c_it == c)
else:
to_add = sum(
-1 for c_it in self.L[i + 1:self.closest_sample[i] + 1] if c_it == c)
return (self.occ_for_char[c][self.closest_sample[i] // self.SAMPLE_SIZE]
+ to_add)

#pylint: disable=too-few-public-methods
#pylint: disable=invalid-name
class _FMIndex:
def __init__ (self, SA, BWT, text, n, rank_searcher = None):
self.L = BWT
F = '#$' + ''.join(text[SA[i]] for i in range(1, n + 1))
self.n = n
self.SA = SA

#prepare char mapping for F
self.mapper_of_chars = { F[2] : 0}
self.beginnings = [2]
last = F[2]
for i in range(3, n+2):
if F[i] != last:
last = F[i]
self.beginnings.append(i)
self.mapper_of_chars[last] = len(self.beginnings) - 1

self.len_of_alphabet = len(self.mapper_of_chars)
self.rank_searcher = (_RankSearcher(self.L, self.mapper_of_chars, n)
if rank_searcher is None else rank_searcher)

def from_suffix_array_and_bwt(SA, BWT, text, n, rank_searcher = None):
return _FMIndex(SA, BWT, text, n, rank_searcher)

# O(|p|)
def count(FM, p, size):
low, high = _get_range_of_occurrences(FM, p, size)
return max(high - low + 1, 0) if low > -1 else 0

# O(|p| + k) where k is the number or occurances of p in text
def contains(FM, p, l):
low, high = _get_range_of_occurrences(FM, p, l)
yield from sorted([FM.SA[i-1] for i in range(low, high + 1) if low > -1])

def _get_range_of_occurrences(FM, p, size):
if size > FM.n or size == 0:
return -1, -1

if p[-1] not in FM.mapper_of_chars:
return -1, -1

map_idx = FM.mapper_of_chars[p[-1]]
l= FM.beginnings[map_idx]
r = (FM.beginnings[map_idx + 1] - 1
if map_idx != FM.len_of_alphabet - 1 else FM.n + 1)

for c in p[-2:0:-1]:
if c not in FM.mapper_of_chars:
return -1, -1
occurrences_before = FM.rank_searcher.prefix_rank(c, l - 1)
occurrences_after = FM.rank_searcher.prefix_rank(c, r)
if occurrences_before == occurrences_after:
return -1, -1
map_idx = FM.mapper_of_chars[c]
l = FM.beginnings[map_idx] + occurrences_before
r = FM.beginnings[map_idx] + occurrences_after - 1
if r < l:
return -1, -1

return l, r
Loading