WIP: ASWA GPU exp
Caglar Demir committed Nov 21, 2023
1 parent 508e2eb commit 5ce7e2c
Showing 2 changed files with 38 additions and 20 deletions.
52 changes: 32 additions & 20 deletions dicee/analyse_experiments.py
@@ -19,7 +19,7 @@ def __init__(self):
         self.batch_size = []
         self.lr = []
         self.byte_pair_encoding = []
-
+        self.aswa = []
         self.path_dataset_folder = []
         self.full_storage_path = []
         self.pq = []
@@ -50,6 +50,7 @@ def save_experiment(self, x):
         self.lr.append(x['lr'])

         self.byte_pair_encoding.append(x["byte_pair_encoding"])
+        self.aswa.append(x["adaptive_swa"])
         self.path_dataset_folder.append(x['dataset_dir'])
         self.pq.append((x['p'], x['q']))
         self.runtime.append(x['Runtime'])
@@ -79,17 +80,24 @@ def save_experiment(self, x):

     def to_df(self):
         return pd.DataFrame(
-            dict(model_name=self.model_name,
+            dict(model=self.model_name,
                  byte_pair_encoding=self.byte_pair_encoding,
-                 path_dataset_folder=self.path_dataset_folder,
-                 train_mrr=self.train_mrr, train_h1=self.train_h1,
-                 train_h3=self.train_h3, train_h10=self.train_h10,
+                 aswa=self.aswa,
+                 Dataset=self.path_dataset_folder,
+                 trainMRR=self.train_mrr,
+                 trainH1=self.train_h1,
+                 trainH3=self.train_h3,
+                 trainH10=self.train_h10,
                  num_epochs=self.num_epochs,
-                 #full_storage_path=self.full_storage_path,
-                 val_mrr=self.val_mrr, val_h1=self.val_h1,
-                 val_h3=self.val_h3, val_h10=self.val_h10,
-                 test_mrr=self.test_mrr, test_h1=self.test_h1,
-                 test_h3=self.test_h3, test_h10=self.test_h10,
+                 full_storage_path=self.full_storage_path,
+                 valMRR=self.val_mrr,
+                 valH1=self.val_h1,
+                 valH3=self.val_h3,
+                 valH10=self.val_h10,
+                 testMRR=self.test_mrr,
+                 testH1=self.test_h1,
+                 testH3=self.test_h3,
+                 testH10=self.test_h10,
                  runtime=self.runtime,
                  params=self.num_params,
                  callbacks=self.callbacks,
@@ -111,27 +119,31 @@ def analyse(args):
                      ['model', 'dataset_dir', 'embedding_dim',
                       'normalization', 'num_epochs', 'batch_size', 'lr',
                       'callbacks',
-                      'scoring_technique',
+                      'adaptive_swa',
+                      "scoring_technique",
+                      "byte_pair_encoding",
                       'dataset_dir', 'p', 'q']}
-        with open(f'{full_path}/report.json', 'r') as f:
-            report = json.load(f)
-        report = {i: report[i] for i in ['Runtime', 'NumParam']}
-        with open(f'{full_path}/eval_report.json', 'r') as f:
-            eval_report = json.load(f)
+        try:
+            with open(f'{full_path}/report.json', 'r') as f:
+                report = json.load(f)
+            report = {i: report[i] for i in ['Runtime', 'NumParam']}
+            with open(f'{full_path}/eval_report.json', 'r') as f:
+                eval_report = json.load(f)
+        except (FileNotFoundError, KeyError):
+            print(f"Reports not found in {full_path}")
+            continue
         config.update(eval_report)
         config.update(report)
+        print(config)
         experiments.append(config)

     counter = Experiment()

     for i in experiments:
         counter.save_experiment(i)

     df = counter.to_df()
-    df.sort_values(by=['test_mrr'], ascending=False, inplace=True)
+    df.sort_values(by=['testMRR'], ascending=False, inplace=True)
     pd.set_option("display.precision", 3)
     # print(df)
     print(df.to_latex(index=False, float_format="%.3f"))
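Note: the try/except guard above makes the collection loop tolerant of half-finished runs, which matters for a WIP experiment like this one. A minimal standalone sketch of the same pattern; the helper name collect_experiments and the flat experiments-directory layout are illustrative assumptions, not part of the diff:

    import json
    import os

    def collect_experiments(experiments_dir):
        # Gather (report, eval_report) pairs, skipping runs that have not
        # finished writing their result files yet.
        experiments = []
        for run in os.listdir(experiments_dir):
            full_path = os.path.join(experiments_dir, run)
            try:
                with open(f'{full_path}/report.json', 'r') as f:
                    report = json.load(f)
                report = {k: report[k] for k in ['Runtime', 'NumParam']}
                with open(f'{full_path}/eval_report.json', 'r') as f:
                    eval_report = json.load(f)
            except (FileNotFoundError, NotADirectoryError, KeyError):
                # An in-progress run has no reports yet; skip it instead of crashing.
                continue
            experiments.append({**report, **eval_report})
        return experiments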
6 changes: 6 additions & 0 deletions dicee/callbacks.py
@@ -205,6 +205,7 @@ def __init__(self, num_epochs, path):
         self.entered_good_regions = None
         self.alphas = None
         self.val_aswa = -1
+        self.num_rejects = 0

     def on_fit_end(self, trainer, model):
         """
@@ -267,6 +268,9 @@ def on_train_epoch_end(self, trainer, model):
             self.initial_eval_setting = trainer.evaluator.args.eval_model
             trainer.evaluator.args.eval_model = "val"

+        if self.num_rejects >= int(self.num_epochs * 0.25):
+            return True
+
         val_running_model = self.compute_mrr(trainer, model)
         # self.val_aswa is initialized as -1
         if val_running_model > self.val_aswa:
@@ -298,6 +302,8 @@ def on_train_epoch_end(self, trainer, model):
print(f" Soft Update: MRR: {self.val_aswa:.4f} | |ASWA|:{self.sample_counter}")
else:
print(" No update")
self.num_rejects+=1


class FPPE(AbstractPPECallback):
"""
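For context, the ASWA callback (adaptive stochastic weight averaging) only folds the running model into the weight average when doing so improves validation MRR, and with this commit every "No update" epoch increments num_rejects. Once rejects cover a quarter of the epoch budget, on_train_epoch_end returns early and the validation sweep is skipped for the rest of training. A simplified sketch of just that accept/reject bookkeeping, assuming a toy class that omits the hard- and soft-update branches of the real callback:

    class RejectGuard:
        # Toy stand-in for the reject bookkeeping added to the ASWA callback.
        def __init__(self, num_epochs):
            self.num_epochs = num_epochs
            self.num_rejects = 0   # epochs that ended in "No update"
            self.val_aswa = -1.0   # best validation MRR of the averaged model so far

        def on_train_epoch_end(self, val_mrr):
            # Same guard as the diff: after 25% of the epoch budget has been
            # rejected, stop evaluating and updating altogether.
            if self.num_rejects >= int(self.num_epochs * 0.25):
                return True
            if val_mrr > self.val_aswa:
                self.val_aswa = val_mrr  # accept: the average improves
            else:
                self.num_rejects += 1    # reject: count toward the 25% cap
            return False

With num_epochs=100, for example, the guard trips after 25 rejected epochs, saving the per-epoch cost of computing validation MRR on a GPU run.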
