Skip to content

Commit

Permalink
Fixed the relative paths in the finetuning runner
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Oct 9, 2024
1 parent 4ba50e2 commit 999a06d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ def compute_metrics(references: Union[List[float], pd.Series], probs: Union[List
def load_pretrained_tokenizer(
model_args,
) -> CehrBertTokenizer:
tokenizer_name_or_path = os.path.expanduser(model_args.tokenizer_name_or_path)
try:
return CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
return CehrBertTokenizer.from_pretrained(tokenizer_name_or_path)
except Exception:
raise ValueError(f"Can not load the pretrained tokenizer from {model_args.tokenizer_name_or_path}")
raise ValueError(f"Can not load the pretrained tokenizer from {tokenizer_name_or_path}")


def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) -> CehrBertPreTrainedModel:
Expand All @@ -81,6 +82,7 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
f"finetune_model_type can be one of the following types {[e.value for e in FineTuneModelType]}"
)
# Try to create a new model based on the base model
model_name_or_path = os.path.expanduser(model_name_or_path)
try:
return finetune_model_cls.from_pretrained(model_name_or_path)
except ValueError:
Expand Down

0 comments on commit 999a06d

Please sign in to comment.