WIP: BTE kvsall and usage within KGE class
Demirrr committed Nov 2, 2023
1 parent 1248ae9 commit cf82885
Showing 5 changed files with 87 additions and 40 deletions.
9 changes: 8 additions & 1 deletion dicee/abstracts.py
@@ -303,7 +303,14 @@ def get_entity_embeddings(self, items: List[str]):
Returns
---------
"""
return self.model.entity_embeddings(torch.LongTensor([self.entity_to_idx[i] for i in items]))
if self.configs["byte_pair_encoding"]:
    # Encode each item into subword token ids and pad every sequence to max_length_subword_tokens.
    t_encode = self.enc.encode_batch(items)
    max_length = self.configs["max_length_subword_tokens"]
    for i in range(len(t_encode)):
        if len(t_encode[i]) < max_length:
            t_encode[i].extend([self.dummy_id for _ in range(max_length - len(t_encode[i]))])
    return self.model.token_embeddings(torch.LongTensor(t_encode)).flatten(1)
else:
    return self.model.entity_embeddings(torch.LongTensor([self.entity_to_idx[i] for i in items]))

def get_relation_embeddings(self, items: List[str]):
"""
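A minimal sketch of the padding step this branch relies on, not part of the commit: all names (max_length, dummy_id, the example strings, the Embedding sizes) are stand-ins for the attributes used above, assuming enc.encode_batch returns one list of token ids per input string.

import torch

max_length = 4    # stands in for configs["max_length_subword_tokens"]
dummy_id = 0      # stands in for self.dummy_id (padding token)
encoded = [[17, 5], [3, 9, 12]]   # e.g. enc.encode_batch(["carcinoma", "amino_acid"]) -> ragged token ids

# Pad each subword sequence to the fixed length expected by the token embedding layer.
for ids in encoded:
    if len(ids) < max_length:
        ids.extend([dummy_id] * (max_length - len(ids)))

batch = torch.LongTensor(encoded)               # shape: (num_items, max_length)
token_embeddings = torch.nn.Embedding(32, 8)    # stands in for self.model.token_embeddings
flat = token_embeddings(batch).flatten(1)       # shape: (num_items, max_length * 8)
print(flat.shape)                               # torch.Size([2, 32])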
6 changes: 3 additions & 3 deletions dicee/analyse_experiments.py
@@ -79,11 +79,12 @@ def save_experiment(self, x):

def to_df(self):
return pd.DataFrame(
dict(model_name=self.model_name, # pq=self.pq,
dict(model_name=self.model_name,
byte_pair_encoding=self.byte_pair_encoding,
path_dataset_folder=self.path_dataset_folder,
train_mrr=self.train_mrr, train_h1=self.train_h1,
train_h3=self.train_h3, train_h10=self.train_h10,
num_epochs=self.num_epochs,
#full_storage_path=self.full_storage_path,
val_mrr=self.val_mrr, val_h1=self.val_h1,
val_h3=self.val_h3, val_h10=self.val_h10,
@@ -92,8 +93,7 @@ def to_df(self):
runtime=self.runtime,
params=self.num_params,
callbacks=self.callbacks,
# normalization=self.normalization,
# embeddingdim=self.embedding_dim
embeddingdim=self.embedding_dim,
scoring_technique=self.scoring_technique
)
)
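For orientation, a hedged sketch of the row layout to_df now emits, with purely hypothetical values and only a subset of the columns; the point is that byte_pair_encoding and embeddingdim are reported alongside the usual metrics.

import pandas as pd

# Hypothetical single-experiment row mirroring (part of) the dict assembled in to_df.
df = pd.DataFrame(dict(model_name=["Keci"],
                       byte_pair_encoding=[True],
                       path_dataset_folder=["KGs/Countries-S3"],
                       train_mrr=[0.99], val_mrr=[0.62],
                       num_epochs=[2000],
                       runtime=[120.0], params=[50000],
                       embeddingdim=[32],
                       scoring_technique=["KvsAll"]))
print(df.T)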
100 changes: 70 additions & 30 deletions dicee/knowledge_graph_embeddings.py
@@ -52,7 +52,8 @@ def eval_lp_performance(self, dataset=List[Tuple[str, str, str]], filtered=True)
return evaluate_lp(model=self.model, triple_idx=idx_dataset, num_entities=len(self.entity_to_idx),
er_vocab=None, re_vocab=None)

def predict_missing_head_entity(self, relation: Union[List[str], str], tail_entity: Union[List[str], str]) -> Tuple:
def predict_missing_head_entity(self, relation: Union[List[str], str], tail_entity: Union[List[str], str],
within=None) -> Tuple:
"""
Given a relation and a tail entity, return the top-k ranked head entities.
@@ -92,10 +93,10 @@ def predict_missing_head_entity(self, relation: Union[List[str], str], tail_entity
x = torch.stack((head_entity,
relation.repeat(self.num_entities, ),
tail_entity.repeat(self.num_entities, )), dim=1)
return self.model.forward(x)
return self.model(x)

def predict_missing_relations(self, head_entity: Union[List[str], str],
tail_entity: Union[List[str], str]) -> Tuple:
tail_entity: Union[List[str], str], within=None) -> Tuple:
"""
Given a head entity and a tail entity, return top k ranked relations.
@@ -139,7 +140,7 @@ def predict_missing_relations(self, head_entity: Union[List[str], str],
return self.model(x)

def predict_missing_tail_entity(self, head_entity: Union[List[str], str],
relation: Union[List[str], str]) -> torch.FloatTensor:
relation: Union[List[str], str], within: List[str] = None) -> torch.FloatTensor:
"""
Given a head entity and a relation, return the top-k ranked tail entities.
@@ -161,26 +162,62 @@ def predict_missing_tail_entity(self, head_entity: Union[List[str], str],
scores
"""
tail_entity = torch.arange(0, len(self.entity_to_idx))
if within is not None:
h_encode = self.enc.encode(head_entity[0])
r_encode = self.enc.encode(relation[0])
t_encode = self.enc.encode_batch(within)
length = self.configs["max_length_subword_tokens"]

if isinstance(head_entity, list):
head_entity = torch.LongTensor([self.entity_to_idx[i] for i in head_entity])
else:
head_entity = torch.LongTensor([self.entity_to_idx[head_entity]])
if isinstance(relation, list):
relation = torch.LongTensor([self.relation_to_idx[i] for i in relation])
num_entities = len(within)
if len(h_encode) != length:
h_encode.extend([self.dummy_id for _ in range(length - len(h_encode))])

if len(r_encode) != length:
r_encode.extend([self.dummy_id for _ in range(length - len(r_encode))])

# Pad every candidate's subword sequence in `within` to the maximum token length.
for i in range(len(t_encode)):
    if len(t_encode[i]) < length:
        t_encode[i].extend([self.dummy_id for _ in range(length - len(t_encode[i]))])

h_encode = torch.LongTensor(h_encode).unsqueeze(0)
r_encode = torch.LongTensor(r_encode).unsqueeze(0)
t_encode = torch.LongTensor(t_encode)

x = torch.stack((torch.repeat_interleave(input=h_encode, repeats=num_entities, dim=0),
torch.repeat_interleave(input=r_encode, repeats=num_entities, dim=0),
t_encode), dim=1)
else:
relation = torch.LongTensor([self.relation_to_idx[relation]])
tail_entity = torch.arange(0, len(self.entity_to_idx))

x = torch.stack((head_entity.repeat(self.num_entities, ),
relation.repeat(self.num_entities, ),
tail_entity), dim=1)
return self.model.forward(x)
if isinstance(head_entity, list):
head_entity = torch.LongTensor([self.entity_to_idx[i] for i in head_entity])
else:
head_entity = torch.LongTensor([self.entity_to_idx[head_entity]])
if isinstance(relation, list):
relation = torch.LongTensor([self.relation_to_idx[i] for i in relation])
else:
relation = torch.LongTensor([self.relation_to_idx[relation]])

x = torch.stack((head_entity.repeat(self.num_entities, ),
relation.repeat(self.num_entities, ),
tail_entity), dim=1)
return self.model(x)
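A small shape-level sketch of how the `within` branch assembles its input batch; the tensors below are dummies (four subword slots, two candidate tails), and only the repeat_interleave/stack pattern mirrors the code above.

import torch

num_candidates = 2
h_encode = torch.LongTensor([[7, 1, 0, 0]])   # padded head-entity token ids, shape (1, max_length)
r_encode = torch.LongTensor([[9, 0, 0, 0]])   # padded relation token ids, shape (1, max_length)
t_encode = torch.LongTensor([[4, 2, 0, 0],
                             [5, 0, 0, 0]])   # one padded row per candidate in `within`

x = torch.stack((torch.repeat_interleave(h_encode, repeats=num_candidates, dim=0),
                 torch.repeat_interleave(r_encode, repeats=num_candidates, dim=0),
                 t_encode), dim=1)
print(x.shape)  # torch.Size([2, 3, 4]): (candidates, triple positions, subword length)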

def predict(self, *, h: Union[List[str], str] = None, r: Union[List[str], str] = None,
t: Union[List[str], str] = None) -> torch.FloatTensor:
t: Union[List[str], str] = None, within=None, logits=True) -> torch.FloatTensor:
"""
Predict missing triples with the trained model.
Parameters
----------
logits: if True, return raw scores; otherwise apply a sigmoid.
h: head entity (or entities).
r: relation (or relations).
t: tail entity (or entities).
within: optional list of candidate entities to restrict the prediction to.
Returns
-------
"""
# (1) Sanity checking.
if h is not None:
@@ -198,25 +235,29 @@ def predict(self, *, h: Union[List[str], str] = None, r: Union[List[str], str] =
assert r is not None
assert t is not None
# ? r, t
scores = self.predict_missing_head_entity(r, t)
scores = self.predict_missing_head_entity(r, t, within)
# (3) Predict missing relation given a head entity and a tail entity.
elif r is None:
assert h is not None
assert t is not None
# h ? t
scores = self.predict_missing_relations(h, t)
scores = self.predict_missing_relations(h, t, within)
# (4) Predict missing tail entity given a head entity and a relation
elif t is None:
assert h is not None
assert r is not None
# h r ?
scores = self.predict_missing_tail_entity(h, r)
scores = self.predict_missing_tail_entity(h, r, within)
else:
scores = self.triple_score(h, r, t, logits=True)

if logits:
return scores
else:
scores = self.triple_score(h, r, t)
return torch.sigmoid(scores)
return torch.sigmoid(scores)

def predict_topk(self, *, h: List[str] = None, r: List[str] = None, t: List[str] = None,
topk: int = 10):
topk: int = 10, within: List[str] = None):
"""
Predict missing item in a given triple.
@@ -259,7 +300,7 @@ def predict_topk(self, *, h: List[str] = None, r: List[str] = None, t: List[str]
assert r is not None
assert t is not None
# ? r, t
scores = self.predict_missing_head_entity(r, t).flatten()
scores = self.predict_missing_head_entity(r, t, within=within).flatten()
if self.apply_semantic_constraint:
# filter the scores
for th, i in enumerate(r):
@@ -274,7 +315,7 @@ def predict_topk(self, *, h: List[str] = None, r: List[str] = None, t: List[str]
assert h is not None
assert t is not None
# h ? t
scores = self.predict_missing_relations(h, t).flatten()
scores = self.predict_missing_relations(h, t, within=within).flatten()
sort_scores, sort_idxs = torch.topk(scores, topk)
return [(self.idx_to_relations[idx_top_entity], scores.item()) for idx_top_entity, scores in
zip(sort_idxs.tolist(), torch.sigmoid(sort_scores))]
@@ -284,7 +325,7 @@ def predict_topk(self, *, h: List[str] = None, r: List[str] = None, t: List[str]
assert h is not None
assert r is not None
# h r ?
scores = self.predict_missing_tail_entity(h, r).flatten()
scores = self.predict_missing_tail_entity(h, r, within=within).flatten()
if self.apply_semantic_constraint:
# filter the scores
for th, i in enumerate(r):
@@ -360,11 +401,10 @@ def triple_score(self, h: Union[List[str], str] = None, r: Union[List[str], str]
raise NotImplementedError()
else:
with torch.no_grad():
out = self.model(x)
if logits:
return out
return self.model(x)
else:
return torch.sigmoid(out)
return torch.sigmoid(self.model(x))

def t_norm(self, tens_1: torch.Tensor, tens_2: torch.Tensor, tnorm: str = 'min') -> torch.Tensor:
if 'min' in tnorm:
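A hedged usage sketch of the new within/logits arguments on a loaded KGE. The import path, experiment folder, and entity/relation labels are placeholders (only the call pattern follows the diff above), and the `within` branch assumes a model trained with byte-pair encoding since it tokenises its string arguments.

from dicee import KGE

# Placeholder path to a saved experiment folder.
model = KGE(path="Experiments/2023-11-02_12-00-00")

# Raw scores (logits=True is the default) for tail prediction restricted to two candidates.
scores = model.predict(h=["brazil"], r=["locatedin"], within=["south_america", "europe"])

# Calibrated probability for a fully specified triple.
prob = model.predict(h=["brazil"], r=["locatedin"], t=["south_america"], logits=False)

# Top-k ranking over the restricted candidate set.
top = model.predict_topk(h=["brazil"], r=["locatedin"], topk=2, within=["south_america", "europe"])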
8 changes: 4 additions & 4 deletions main.py
@@ -10,7 +10,7 @@ def get_default_arguments(description=None):
parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))
# Default Trainer param https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#methods
# Knowledge graph related arguments
parser.add_argument("--dataset_dir", type=str, default="KGs/NELL-995-h100",
parser.add_argument("--dataset_dir", type=str, default="KGs/Countries-S3",
help="The path of a folder containing train.txt, and/or valid.txt and/or test.txt"
",e.g., KGs/UMLS")
parser.add_argument("--sparql_endpoint", type=str, default=None,
@@ -45,10 +45,10 @@ def get_default_arguments(description=None):
choices=['Adam', 'SGD'])
parser.add_argument('--embedding_dim', type=int, default=32,
help='Number of dimensions for an embedding vector. ')
parser.add_argument("--num_epochs", type=int, default=256, help='Number of epochs for training. ')
parser.add_argument("--num_epochs", type=int, default=2000, help='Number of epochs for training. ')
parser.add_argument('--batch_size', type=int, default=1024,
help='Mini batch size. If None, automatic batch finder is applied')
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument('--callbacks', type=json.loads,
default={},
help='{"PPE":{ "last_percent_to_consider": 10}}'
@@ -102,7 +102,7 @@ def get_default_arguments(description=None):
parser.add_argument('--pykeen_model_kwargs', type=json.loads, default={})
# WIP
parser.add_argument("--byte_pair_encoding",
action="store_true",
action="store_false",
help="Currently only avail. for KGE implemented within dice-embeddings.")
if description is None:
return parser.parse_args()
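The flipped action on --byte_pair_encoding is worth spelling out: with argparse, action="store_false" makes the destination default to True and the flag turn it off, the mirror image of store_true. A minimal, self-contained illustration of that standard-library behaviour (hypothetical parser, not the project's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--byte_pair_encoding", action="store_false")

print(parser.parse_args([]))                         # Namespace(byte_pair_encoding=True)
print(parser.parse_args(["--byte_pair_encoding"]))   # Namespace(byte_pair_encoding=False)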
4 changes: 2 additions & 2 deletions tests/test_inductive_regression.py
@@ -31,5 +31,5 @@ def test_inductive(self):
assert result['Val']['MRR'] >= 0.78
assert result['Test']['MRR'] >= 0.78
pre_trained_kge = KGE(path=result['path_experiment_folder'])
assert pre_trained_kge.predict(h="alga", r="isa", t="entity") >= 0.55
assert pre_trained_kge.predict(h="Demir", r="loves", t="Embeddings") >= 0.49
assert (pre_trained_kge.predict(h="alga", r="isa", t="entity", logits=False) >
pre_trained_kge.predict(h="Demir", r="loves", t="Embeddings", logits=False))
