Skip to content

Commit

Permalink
Fix: BPE
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Nov 30, 2023
1 parent c27337d commit cef4626
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
24 changes: 20 additions & 4 deletions dicee/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,22 @@ def get_aswa_state_dict(self, model):
return ensemble_state_dict

def decide(self, running_model_state_dict, ensemble_state_dict, val_running_model, mrr_updated_ensemble_model):
"""
Hard Update
Soft Update
Rejection
Parameters
----------
running_model_state_dict
ensemble_state_dict
val_running_model
mrr_updated_ensemble_model
Returns
-------
"""
if val_running_model > mrr_updated_ensemble_model and val_running_model > self.val_aswa:
"""Hard Update """
torch.save(running_model_state_dict, f=f"{self.path}/aswa.pt")
Expand Down Expand Up @@ -273,17 +289,17 @@ def on_train_epoch_end(self, trainer, model):
self.val_aswa = val_running_model
return True
else:

# (1) Load ASWA ensemble
ensemble_state_dict = self.get_aswa_state_dict(model)

# Evaluate
# (2) Evaluate (1) on the validation data.
ensemble = type(model)(model.args)
ensemble.load_state_dict(ensemble_state_dict)
mrr_updated_ensemble_model = trainer.evaluator.eval(dataset=trainer.dataset, trained_model=ensemble,
form_of_labelling=trainer.form_of_labelling,
during_training=True)["Val"]["MRR"]
print(f"MRR Running {val_running_model:.4f} | MRR ASWA: {self.val_aswa:.4f} |ASWA|:{sum(self.alphas)}")

# print(f"| MRR Running {val_running_model:.4f} | MRR ASWA: {self.val_aswa:.4f} |ASWA|:{sum(self.alphas)}")
# (3) Update or not
self.decide(model.state_dict(), ensemble_state_dict, val_running_model, mrr_updated_ensemble_model)


Expand Down
20 changes: 18 additions & 2 deletions dicee/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import List, Any, Tuple, Union, Dict
import pytorch_lightning
import numpy as np
Expand Down Expand Up @@ -477,10 +476,27 @@ def get_sentence_representation(self, x: torch.LongTensor):
tail_emb = self.token_embeddings(t)
return head_ent_emb, rel_emb, tail_emb

def get_bpe_head_and_relation_representation(self, x: torch.LongTensor) -> Tuple[
        torch.FloatTensor, torch.FloatTensor]:
    """
    Embed the token sequences of head entities and relations and rescale them to unit norm.

    Parameters
    ----------
    x
        Index tensor read as ``x[:, 0, :]`` (head-entity token ids) and
        ``x[:, 1, :]`` (relation token ids), i.e. three dimensions with at
        least two entries along the second one.

    Returns
    -------
    Tuple of ``(head_ent_emb, rel_emb)``. Each is the output of
    ``self.token_embeddings`` for the corresponding slice, L2-normalised
    jointly over dimensions 1 and 2, so that every item along dimension 0
    has unit norm when its token embeddings are viewed as one flat vector.
    """
    h, r = x[:, 0, :], x[:, 1, :]
    # N, T, D
    head_ent_emb = self.token_embeddings(h)
    # N, T, D
    rel_emb = self.token_embeddings(r)
    # The sequence of sub-word embeddings that represents one head entity (or
    # relation) is normalised to unit length: the L2 norm of its T x D matrix,
    # taken as a single row vector, is 1. F.normalize clamps the denominator
    # with a small eps, so an all-zero matrix stays zero rather than becoming NaN.
    head_ent_emb = F.normalize(head_ent_emb, p=2, dim=(1, 2))
    rel_emb = F.normalize(rel_emb, p=2, dim=(1, 2))
    return head_ent_emb, rel_emb

def get_embeddings(self) -> Tuple[np.ndarray, np.ndarray]:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_inductive_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_inductive(self):
args.optim = 'Adam'
args.num_epochs = 500
args.batch_size = 1024
args.lr = 0.001
args.lr = 0.1
args.input_dropout_rate = 0.0
args.hidden_dropout_rate = 0.0
args.feature_map_dropout_rate = 0.0
Expand All @@ -29,7 +29,7 @@ def test_inductive(self):
result = Execute(args).start()
assert result['Train']['MRR'] >= 0.88
assert result['Val']['MRR'] >= 0.78
assert result['Test']['MRR'] >= 0.78
assert result['Test']['MRR'] >= 0.77
pre_trained_kge = KGE(path=result['path_experiment_folder'])
assert (pre_trained_kge.predict(h="alga", r="isa", t="entity", logits=False) >
pre_trained_kge.predict(h="Demir", r="loves", t="Embeddings", logits=False))

0 comments on commit cef4626

Please sign in to comment.