Skip to content

Commit

Permalink
Merge branch 'query-generator-fix' of https://github.com/dice-group/d…
Browse files Browse the repository at this point in the history
…ice-embeddings into whale-pipeline-evaluation
  • Loading branch information
sshivam95 committed Nov 30, 2024
2 parents e1de124 + 82d85c4 commit 3a98ce6
Showing 1 changed file with 127 additions and 26 deletions.
153 changes: 127 additions & 26 deletions dicee/query_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from copy import deepcopy
from .static_funcs import save_pickle, load_pickle


class QueryGenerator:
def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = None, rel2id: Dict = None,
def __init__(self, train_path, val_path: str = None, test_path: str = None, ent2id: Dict = None, rel2id: Dict = None,
seed: int = 1,
gen_valid: bool = False,
gen_test: bool = True):
gen_test: bool = True,
mode: str = "test"):

self.train_path = train_path
self.val_path = val_path
Expand All @@ -24,7 +24,9 @@ def __init__(self, train_path, val_path: str, test_path: str, ent2id: Dict = Non

self.max_ans_num = 1e6

self.mode = str
self.mode = mode
# OLD CODE
# self.mode = str
self.ent2id = ent2id
self.rel2id: Dict = rel2id
self.ent_in: Dict = {}
Expand Down Expand Up @@ -72,18 +74,71 @@ def construct_graph(self, paths: List[str]) -> Tuple[Dict, Dict]:
"""
Construct graph from triples
Returns dicts with incoming and outgoing edges
"""
"""
# Mapping from tail entity and a relation to heads.
tail_relation_to_heads = defaultdict(lambda: defaultdict(set))
# Mapping from head and relation to tails.
head_relation_to_tails = defaultdict(lambda: defaultdict(set))

for path in paths:
import shlex

with open(path, "r") as f:
for line in f:
h, r, t = map(str, line.strip().split("\t"))
tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h])
head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t])
line = line.strip()
# Skip empty lines or comments
if not line or line.startswith('#'):
continue

# Create a shlex lexer for the line
lexer = shlex.shlex(line, posix=True)
lexer.whitespace_split = True
# Set the quote characters to only double quotes
lexer.quotes = '"'
lexer.commenters = '' # Disable comment parsing

# Split the line into tokens
tokens = list(lexer)

# Check that the line ends with a period '.'
if tokens[-1] != '.':
continue # Skip malformed lines

# Remove the period
tokens = tokens[:-1]

# Check that we have exactly 3 tokens: h, r, t
if len(tokens) != 3:
continue # Skip malformed lines

h, r, t = tokens

# Check if t is a literal (starts and ends with double quotes)
if t.startswith('"') and t.endswith('"'):
continue # Skip literals

# Strip angle brackets from URIs if present
h = h.strip('<>')
r = r.strip('<>')
t = t.strip('<>')

# Map to IDs
h_id = self.ent2id.get(h)
r_id = self.rel2id.get(r)
t_id = self.ent2id.get(t)

# Skip if any ID is not found
if h_id is None or r_id is None or t_id is None:
continue

# Update the dictionaries
tail_relation_to_heads.setdefault(t_id, {}).setdefault(r_id, set()).add(h_id)
head_relation_to_tails.setdefault(h_id, {}).setdefault(r_id, set()).add(t_id)

# OLD CODE
# h, r, t = map(str, line.strip().split("\t"))
# tail_relation_to_heads[self.ent2id[t]][self.rel2id[r]].add(self.ent2id[h])
# head_relation_to_tails[self.ent2id[h]][self.rel2id[r]].add(self.ent2id[t])

self.ent_in = tail_relation_to_heads
self.ent_out = head_relation_to_tails
Expand Down Expand Up @@ -447,28 +502,74 @@ def generate_queries(self, query_struct:List, gen_num: int, query_type: str):
and getting queries and answers in return
@ TODO: create a class for each single query struct
"""


train_tail_relation_to_heads, train_head_relation_to_tails = self.construct_graph(paths=[self.train_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path])
# ?!
valid_only_ent_in, valid_only_ent_out = self.construct_graph(paths=[self.val_path, self.test_path])

test_tail_relation_to_heads, test_head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path, self.test_path])
# ?!
test_only_ent_in, test_only_ent_out = self.construct_graph(paths=[self.test_path])
self.mode = 'test'
test_queries, test_tp_answers, test_fp_answers, test_fn_answers = self.ground_queries(
query_struct, test_tail_relation_to_heads, test_head_relation_to_tails, val_tail_relation_to_heads,
val_head_relation_to_tails, gen_num, query_type)

if not self.train_path:
raise ValueError("Training path (train_path) is empty. It must be specified.")

if self.mode == "train":
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path])
val_tail_relation_to_heads, val_head_relation_to_tails = tail_relation_to_heads, head_relation_to_tails
elif self.mode == 'valid':
# Check if val_path is not empty
if not self.val_path:
raise ValueError("Validation path (val_path) is empty. It must be specified for 'valid' mode.")

# Use training and validation data
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(paths=[self.train_path, self.val_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(paths=[self.val_path])
elif self.mode == 'test':
# Check if val_path and test_path are not empty
if not self.val_path:
raise ValueError("Validation path (val_path) is empty. It must be specified for 'test' mode.")
if not self.test_path:
raise ValueError("Test path (test_path) is empty. It must be specified for 'test' mode.")

# Use all data for constructing the graph
tail_relation_to_heads, head_relation_to_tails = self.construct_graph(
paths=[self.train_path, self.val_path, self.test_path])
val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(
paths=[self.val_path, self.test_path])
else:
raise ValueError(f"Unknown mode '{self.mode}'. Mode must be 'train', 'valid', or 'test'.")

# Ground the queries using the constructed graphs
queries, tp_answers, fp_answers, fn_answers = self.ground_queries(
query_structure=query_struct,
ent_in=tail_relation_to_heads,
ent_out=head_relation_to_tails,
small_ent_in=val_tail_relation_to_heads,
small_ent_out=val_head_relation_to_tails,
gen_num=gen_num,
query_name=query_type
)
# @TODO: test_queries has keys that are tuple ,e.g. ('e', ('r',))
# Yet, query structure defined as a list ['e', ['r']].
# Fix this inconsistency
print(
f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(test_tp_answers)}")
return test_queries, test_tp_answers, test_fp_answers, test_fn_answers
f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(tp_answers)}")
return queries, tp_answers, fp_answers, fn_answers

# OLD CODE
# train_tail_relation_to_heads, train_head_relation_to_tails = self.construct_graph(paths=[self.train_path])
# val_tail_relation_to_heads, val_head_relation_to_tails = self.construct_graph(
# paths=[self.train_path, self.val_path])
# # ?!
# valid_only_ent_in, valid_only_ent_out = self.construct_graph(paths=[self.val_path, self.test_path])

# test_tail_relation_to_heads, test_head_relation_to_tails = self.construct_graph(
# paths=[self.train_path, self.val_path, self.test_path])
# # ?!
# test_only_ent_in, test_only_ent_out = self.construct_graph(paths=[self.test_path])
# self.mode = 'test'
# test_queries, test_tp_answers, test_fp_answers, test_fn_answers = self.ground_queries(
# query_struct, test_tail_relation_to_heads, test_head_relation_to_tails, val_tail_relation_to_heads,
# val_head_relation_to_tails, gen_num, query_type)
# # @TODO: test_queries has keys that are tuple ,e.g. ('e', ('r',))
# # Yet, query structure defined as a list ['e', ['r']].
# # Fix this inconsistency
# print(
# f"General structure is {query_struct} with name {query_type}. Number of queries generated: {len(test_tp_answers)}")
# return test_queries, test_tp_answers, test_fp_answers, test_fn_answers

def save_queries(self, query_type: str, gen_num: int, save_path: str):
"""
Expand Down

0 comments on commit 3a98ce6

Please sign in to comment.