diff --git a/pytorch3dunet/unet3d/trainer.py b/pytorch3dunet/unet3d/trainer.py index 4b59d568..e3a12062 100644 --- a/pytorch3dunet/unet3d/trainer.py +++ b/pytorch3dunet/unet3d/trainer.py @@ -142,7 +142,7 @@ def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterio elif pre_trained is not None: logger.info(f"Logging pre-trained model from '{pre_trained}'...") utils.load_checkpoint(pre_trained, self.model, None) - if 'checkpoint_dir' not in kwargs: + if not self.checkpoint_dir: self.checkpoint_dir = os.path.split(pre_trained)[0] def fit(self):