Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gradio #199

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions dicee/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,22 @@ def get_aswa_state_dict(self, model):
return ensemble_state_dict

def decide(self, running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model):
"""
Hard Update
Soft Update
Rejection

Parameters
----------
running_model_state_dict
ensemble_state_dict
val_running_model
mrr_updated_ensemble_model

Returns
-------

"""
if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
"""Hard Update """
torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
Expand Down Expand Up @@ -273,17 +289,17 @@ def on_train_epoch_end(self, trainer, model):
self.val_aswa = val_running_model
return True
else:

# (1) Load ASWA ensemble
ensemble_state_dict = self.get_aswa_state_dict(model)

# Evaluate
# (2) Evaluate (1) on the validation data.
ensemble = type(model)(model.args)
ensemble.load_state_dict(ensemble_state_dict)
mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
form_of_labelling=trainer.form_of_labelling,
during_training=True)["Val"]["MRR"]
print(f"MRR Running {val_running_model:.4f} | MRR ASWA: {self.val_aswa:.4f} |ASWA|:{sum(self.alphas)}")

# print(f"| MRR Running {val_running_model:.4f} | MRR ASWA: {self.val_aswa:.4f} |ASWA|:{sum(self.alphas)}")
# (3) Update or not
self.decide(model.state_dict(), ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)


Expand Down
10 changes: 5 additions & 5 deletions dicee/knowledge_graph_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,11 +1050,11 @@ def predict(str_subject: str, str_predicate: str, str_object: str, random_exampl

gr.Interface(
fn=predict,
inputs=[gr.inputs.Textbox(lines=1, placeholder=None, label='Subject'),
gr.inputs.Textbox(lines=1, placeholder=None, label='Predicate'),
gr.inputs.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"],
outputs=[gr.outputs.Textbox(label='Input Triple'),
gr.outputs.Dataframe(label='Outputs', type='pandas')],
inputs=[gr.Textbox(lines=1, placeholder=None, label='Subject'),
gr.Textbox(lines=1, placeholder=None, label='Predicate'),
gr.Textbox(lines=1, placeholder=None, label='Object'), "checkbox"],
outputs=[gr.Textbox(label='Input Triple'),
gr.Dataframe(label='Outputs', type='pandas')],
title=f'{self.name} Deployment',
description='1. Enter a triple to compute its score,\n'
'2. Enter a subject and predicate pair to obtain most likely top ten entities or\n'
Expand Down
20 changes: 18 additions & 2 deletions dicee/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import List, Any, Tuple, Union, Dict
import pytorch_lightning
import numpy as np
Expand Down Expand Up @@ -477,10 +476,27 @@ def get_sentence_representation(self, x: torch.LongTensor):
tail_emb = self.token_embeddings(t)
return head_ent_emb, rel_emb, tail_emb

def get_bpe_head_and_relation_representation(self, x: torch.LongTensor) -> Tuple[
    torch.FloatTensor, torch.FloatTensor]:
    """Return L2-normalized token-embedding blocks for head entities and relations.

    Parameters
    ----------
    x : torch.LongTensor
        Batch of BPE token-index triples of shape (N, 3, T); column 0 holds the
        head-entity sub-token ids and column 1 the relation sub-token ids.
        # NOTE(review): the middle dimension is indexed at 0 and 1 only here —
        # presumably index 2 carries the tail; confirm against callers.

    Returns
    -------
    Tuple[torch.FloatTensor, torch.FloatTensor]
        (head_ent_emb, rel_emb), each of shape (N, T, D), where every example's
        full (T, D) block is scaled to unit L2 (Frobenius) norm.
    """
    h, r = x[:, 0, :], x[:, 1, :]
    # Look up sub-token embeddings: (N, T) ids -> (N, T, D) vectors.
    head_ent_emb = self.token_embeddings(h)
    rel_emb = self.token_embeddings(r)
    # A sequence of sub-token embeddings representing one entity/relation is
    # normalized as a whole: dim=(1, 2) makes each (T, D) sub-matrix have
    # L2 norm 1, rather than normalizing each row independently.
    head_ent_emb = F.normalize(head_ent_emb, p=2, dim=(1, 2))
    rel_emb = F.normalize(rel_emb, p=2, dim=(1, 2))
    return head_ent_emb, rel_emb

def get_embeddings(self) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
14 changes: 7 additions & 7 deletions dicee/static_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,22 +456,22 @@ def deploy_tail_entity_prediction(pre_trained_kge, str_subject, str_predicate, t
if pre_trained_kge.model.name == 'Shallom':
print('Tail entity prediction is not available for Shallom')
raise NotImplementedError
scores, entity = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k)
return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame({'Entity': entity, 'Score': scores})
str_entity_scores = pre_trained_kge.predict_topk(h=[str_subject], r=[str_predicate], topk=top_k)

return f'( {str_subject}, {str_predicate}, ? )', pd.DataFrame(str_entity_scores,columns=["entity","score"])


def deploy_head_entity_prediction(pre_trained_kge, str_object, str_predicate, top_k):
    """Predict the top-k most likely head entities for the pattern (?, predicate, object).

    Parameters
    ----------
    pre_trained_kge : a trained KGE wrapper exposing `.model.name` and `.predict_topk`.
    str_object : str — object entity of the query triple.
    str_predicate : str — predicate of the query triple.
    top_k : int — number of candidate head entities to return.

    Returns
    -------
    tuple
        A human-readable query string '( ?, predicate, object )' and a
        pandas DataFrame with columns ["entity", "score"].

    Raises
    ------
    NotImplementedError
        Shallom scores (h, t) pairs over relations only, so head-entity
        prediction is unsupported for it.
    """
    if pre_trained_kge.model.name == 'Shallom':
        print('Head entity prediction is not available for Shallom')
        raise NotImplementedError
    # predict_topk returns an iterable of (entity, score) pairs — removed the
    # superseded pre-change call that unpacked two parallel sequences.
    str_entity_scores = pre_trained_kge.predict_topk(t=[str_object], r=[str_predicate], topk=top_k)
    return f'( ?, {str_predicate}, {str_object} )', pd.DataFrame(str_entity_scores, columns=["entity", "score"])


def deploy_relation_prediction(pre_trained_kge, str_subject, str_object, top_k):
    """Predict the top-k most likely relations for the pattern (subject, ?, object).

    Parameters
    ----------
    pre_trained_kge : a trained KGE wrapper exposing `.predict_topk`.
    str_subject : str — subject entity of the query triple.
    str_object : str — object entity of the query triple.
    top_k : int — number of candidate relations to return.

    Returns
    -------
    tuple
        A human-readable query string '( subject, ?, object )' and a
        pandas DataFrame with columns ["relation", "score"].
    """
    # predict_topk returns an iterable of (relation, score) pairs — removed the
    # superseded pre-change call that unpacked two parallel sequences.
    str_relation_scores = pre_trained_kge.predict_topk(h=[str_subject], t=[str_object], topk=top_k)
    return f'( {str_subject}, ?, {str_object} )', pd.DataFrame(str_relation_scores, columns=["relation", "score"])


@timeit
Expand Down
4 changes: 2 additions & 2 deletions tests/test_inductive_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_inductive(self):
args.optim = 'Adam'
args.num_epochs = 500
args.batch_size = 1024
args.lr = 0.001
args.lr = 0.1
args.input_dropout_rate = 0.0
args.hidden_dropout_rate = 0.0
args.feature_map_dropout_rate = 0.0
Expand All @@ -29,7 +29,7 @@ def test_inductive(self):
result = Execute(args).start()
assert result['Train']['MRR'] >= 0.88
assert result['Val']['MRR'] >= 0.78
assert result['Test']['MRR'] >= 0.78
assert result['Test']['MRR'] >= 0.77
pre_trained_kge = KGE(path=result['path_experiment_folder'])
assert (pre_trained_kge.predict(h="alga", r="isa", t="entity", logits=False) >
pre_trained_kge.predict(h="Demir", r="loves", t="Embeddings", logits=False))
Loading