Skip to content

Commit

Permalink
Regression test for ASWA vs SWA
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Nov 21, 2023
1 parent 4e8cd74 commit 067148a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dicee/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def on_train_epoch_end(self, trainer, model):
self.val_aswa= val_running_model
torch.save(model.state_dict(), f=f"{self.path}/trainer_checkpoint_main.pt")
self.sample_counter = 1
print(f"Hard Update: MRR: {self.val_aswa}")
print(f" Hard Update: MRR: {self.val_aswa:.4f}")
else:
# Load ensemble
ensemble_state_dict = torch.load(f"{self.path}/trainer_checkpoint_main.pt", torch.device(model.device))
Expand All @@ -351,7 +351,7 @@ def on_train_epoch_end(self, trainer, model):
self.val_aswa = mrr_updated_ensemble_model
torch.save(ensemble_state_dict, f=f"{self.path}/trainer_checkpoint_main.pt")
self.sample_counter += 1
print(f" Soft Update: MRR: {self.val_aswa} | |ASWA|:{self.sample_counter}")
print(f" Soft Update: MRR: {self.val_aswa:.4f} | |ASWA|:{self.sample_counter}")
else:
print(" No update")

Expand Down
38 changes: 38 additions & 0 deletions tests/test_adaptive_swa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dicee.executer import Execute
import pytest
from dicee.config import Namespace

class TestASWA:
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all(self):
args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "KvsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 200
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.adaptive_swa = True
aswa_report = Execute(args).start()

args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "KvsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 200
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.stochastic_weight_avg = True
swa_report = Execute(args).start()

assert aswa_report["Val"]["MRR"]>swa_report["Val"]["MRR"]
assert aswa_report["Test"]["MRR"]>swa_report["Test"]["MRR"]
assert 0.88 > aswa_report["Test"]["MRR"]>swa_report["Test"]["MRR"] >0.75

0 comments on commit 067148a

Please sign in to comment.