Skip to content

Commit

Permalink
Refactoring adaptive swa test for the sake of increasing the test time
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Nov 29, 2023
1 parent ae84a2a commit cbf3a51
Showing 1 changed file with 3 additions and 63 deletions.
66 changes: 3 additions & 63 deletions tests/test_adaptive_swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test_k_vs_all_lowest(self):
assert aswa_report["Test"]["MRR"] > swa_report["Test"]["MRR"]
assert aswa_report["Test"]["H@1"] > swa_report["Test"]["H@1"]

"""
@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all_low(self):
args = Namespace()
Expand Down Expand Up @@ -106,66 +108,4 @@ def test_k_vs_all_mid(self):
assert aswa_report["Val"]["MRR"] > swa_report["Val"]["MRR"]
assert aswa_report["Test"]["MRR"] > swa_report["Test"]["MRR"]
assert 0.88 > aswa_report["Test"]["MRR"] > swa_report["Test"]["MRR"] > 0.75

@pytest.mark.filterwarnings('ignore::UserWarning')
def test_k_vs_all_high(self):
args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "KvsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 500
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.adaptive_swa = True
aswa_report = Execute(args).start()

args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "KvsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 500
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.stochastic_weight_avg = True
swa_report = Execute(args).start()
assert aswa_report["Test"]["MRR"] > swa_report["Test"]["MRR"]
assert aswa_report["Test"]["H@1"] > swa_report["Test"]["H@1"]

def test_1_vs_all_high(self):
args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "1vsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 120
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.adaptive_swa = True
aswa_report = Execute(args).start()

args = Namespace()
args.model = 'Keci'
args.p = 0
args.q = 1
args.scoring_technique = "1vsAll"
args.dataset_dir = "KGs/UMLS"
args.trainer = "PL"
args.num_epochs = 120
args.lr = 0.1
args.embedding_dim = 32
args.batch_size = 1024
args.stochastic_weight_avg = True
swa_report = Execute(args).start()
assert aswa_report["Test"]["MRR"] >= swa_report["Test"]["MRR"]
assert aswa_report["Test"]["H@1"] >= swa_report["Test"]["H@1"]
"""

0 comments on commit cbf3a51

Please sign in to comment.