
Commit

Make sure task-specific checkpoints are cleared when replaced. (#838)
sleepinyourhat authored Jul 12, 2019
1 parent c713766 commit 8a1b3c0
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions jiant/trainer.py
@@ -992,18 +992,22 @@ def _description_from_metrics(self, metrics):
         """ format some metrics as a string """
         return ", ".join(["%s: %.4f" % (name, value) for name, value in metrics.items()])
 
-    def _unmark_previous_best(self, phase, val_pass, task=""):
+    def _unmark_previous_best(self, phase, val_pass, task_dir_name=""):
         marked_best = glob.glob(
-            os.path.join(self._serialization_dir, task, "*_state_{}_val_*.best.th".format(phase))
+            os.path.join(
+                self._serialization_dir, task_dir_name, "*_state_{}_val_*.best.th".format(phase)
+            )
         )
         for file in marked_best:
             # Skip the just-written checkpoint.
             if "_{}.".format(val_pass) not in file:
                 os.rename(file, re.sub("%s$" % (".best.th"), ".th", file))
 
-    def _delete_old_checkpoints(self, phase, val_pass, task=""):
+    def _delete_old_checkpoints(self, phase, val_pass, task_dir_name=""):
         candidates = glob.glob(
-            os.path.join(self._serialization_dir + task, "*_state_{}_val_*.th".format(phase))
+            os.path.join(
+                self._serialization_dir, task_dir_name, "*_state_{}_val_*.th".format(phase)
+            )
         )
         for file in candidates:
             # Skip the best, because we'll need it.
@@ -1032,16 +1036,16 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
         else:
             best_str = ""
 
-        task_directory = ""
+        task_dir_name = ""
 
         if phase == "target_train":
             # We only pass in one task at a time during target train phase.
             assert len(tasks) == 1
-            task_directory = tasks[0].name
+            task_dir_name = tasks[0].name
 
         model_path = os.path.join(
             self._serialization_dir,
-            task_directory,
+            task_dir_name,
             "model_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
         )
 
@@ -1073,15 +1077,15 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
             task_states,
             os.path.join(
                 self._serialization_dir,
-                task_directory,
+                task_dir_name,
                 "task_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
             ),
         )
         torch.save(
             metric_states,
             os.path.join(
                 self._serialization_dir,
-                task_directory,
+                task_dir_name,
                 "metric_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
             ),
         )
@@ -1090,15 +1094,15 @@ def _save_checkpoint(self, training_state, phase="pretrain", new_best=False, tas
             training_state,
             os.path.join(
                 self._serialization_dir,
-                task_directory,
+                task_dir_name,
                 "training_state_{}_val_{}{}.th".format(phase, val_pass, best_str),
             ),
         )
         if new_best:
-            self._unmark_previous_best(phase, val_pass, task_directory)
+            self._unmark_previous_best(phase, val_pass, task_dir_name)
 
         if not self._keep_all_checkpoints:
-            self._delete_old_checkpoints(phase, val_pass)
+            self._delete_old_checkpoints(phase, val_pass, task_dir_name)
 
         log.info("Saved checkpoints to %s", self._serialization_dir)

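The substance of the fix is twofold, both visible in the diff above: `_delete_old_checkpoints` built its glob root by string concatenation (`self._serialization_dir + task`), which drops the path separator, and `_save_checkpoint` called it without the task directory at all, so stale target-train checkpoints written under a task subdirectory were never matched for deletion. A minimal sketch of the path difference, with illustrative directory and task names (not jiant's actual values):

```python
import os

# Illustrative values only; jiant's real serialization_dir comes from config.
serialization_dir = "/tmp/exp/run1"
task_dir_name = "sst"
pattern = "*_state_target_train_val_*.th"

# Old behavior: concatenation drops the separator, so the glob points at a
# directory that does not exist and matches nothing.
old = os.path.join(serialization_dir + task_dir_name, pattern)
print(old)  # /tmp/exp/run1sst/*_state_target_train_val_*.th

# New behavior: os.path.join inserts the separator, pointing the glob at the
# task subdirectory where the checkpoints were actually written.
new = os.path.join(serialization_dir, task_dir_name, pattern)
print(new)  # /tmp/exp/run1/sst/*_state_target_train_val_*.th
```

With the old pattern, glob.glob matched nothing, so every validation pass left another full set of task-specific checkpoints on disk.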

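The renamed `task_dir_name` parameter also reaches `_unmark_previous_best`, which demotes a superseded best checkpoint by renaming it rather than deleting it, so the normal cleanup can remove it later. A sketch of that suffix rewrite, using a hypothetical filename in the shape the trainer produces:

```python
import re

# Hypothetical filename in the shape jiant writes for a best checkpoint.
fname = "model_state_target_train_val_3.best.th"

# Same rewrite as _unmark_previous_best: strip the ".best" marker so the
# file becomes an ordinary (deletable) checkpoint.
renamed = re.sub("%s$" % ".best.th", ".th", fname)
print(renamed)  # model_state_target_train_val_3.th
```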