import pytest

from dicee.config import Namespace
from dicee.executer import Execute


class TestASWA:
    """Compare a run with ``adaptive_swa`` against a run with ``stochastic_weight_avg``.

    Both runs use the same Keci / KvsAll configuration on ``KGs/UMLS``; the
    only difference is which weight-averaging flag is switched on.
    """

    # Settings shared by both training runs. Keeping them in one place makes
    # the single intended difference between the runs explicit.
    _SHARED_CONFIG = {
        "model": "Keci",
        "p": 0,
        "q": 1,
        "scoring_technique": "KvsAll",
        "dataset_dir": "KGs/UMLS",
        "trainer": "PL",
        "num_epochs": 200,
        "lr": 0.1,
        "embedding_dim": 32,
        "batch_size": 1024,
    }

    # Window the test MRR values are required to fall into (exclusive bounds).
    _MIN_TEST_MRR = 0.75
    _MAX_TEST_MRR = 0.88

    @staticmethod
    def _train(**overrides):
        """Run one experiment with the shared config plus ``overrides``.

        Args:
            **overrides: Extra attributes to set on the configuration, e.g.
                ``adaptive_swa=True``.

        Returns:
            Whatever ``Execute(args).start()`` returns; the test indexes it as
            ``report["Val"]["MRR"]`` and ``report["Test"]["MRR"]``.
        """
        args = Namespace()
        # Attributes are assigned after construction (as the original test
        # did) rather than passed to the constructor.
        for name, value in {**TestASWA._SHARED_CONFIG, **overrides}.items():
            setattr(args, name, value)
        return Execute(args).start()

    @pytest.mark.filterwarnings('ignore::UserWarning')
    def test_k_vs_all(self):
        """The ``adaptive_swa`` run must beat the ``stochastic_weight_avg`` run."""
        aswa_report = self._train(adaptive_swa=True)
        swa_report = self._train(stochastic_weight_avg=True)

        aswa_val, swa_val = aswa_report["Val"]["MRR"], swa_report["Val"]["MRR"]
        aswa_test, swa_test = aswa_report["Test"]["MRR"], swa_report["Test"]["MRR"]

        assert aswa_val > swa_val, (
            f"Val MRR: ASWA {aswa_val} is not above SWA {swa_val}"
        )
        assert aswa_test > swa_test, (
            f"Test MRR: ASWA {aswa_test} is not above SWA {swa_test}"
        )
        # Together with the ordering above, these two bounds are equivalent to
        # the chained check: MAX > aswa_test > swa_test > MIN.
        assert swa_test > self._MIN_TEST_MRR, (
            f"Test MRR: SWA {swa_test} is not above {self._MIN_TEST_MRR}"
        )
        assert aswa_test < self._MAX_TEST_MRR, (
            f"Test MRR: ASWA {aswa_test} is not below {self._MAX_TEST_MRR}"
        )