Skip to content

Commit

Permalink
removed the absolute model and tokenizer paths from pretraining and …
Browse files Browse the repository at this point in the history
…finetuning (#50)
  • Loading branch information
ChaoPang authored Sep 6, 2024
1 parent 8511dde commit 3d09645
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 13 deletions.
6 changes: 2 additions & 4 deletions src/cehrbert/runners/hf_cehrbert_finetune_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def load_pretrained_model_and_tokenizer(
) -> Tuple[CehrBertPreTrainedModel, CehrBertTokenizer]:
# Try to load the pretrained tokenizer
try:
tokenizer_abspath = os.path.abspath(model_args.tokenizer_name_or_path)
tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_abspath)
tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
except Exception:
raise ValueError(f"Can not load the pretrained tokenizer from {model_args.tokenizer_name_or_path}")

Expand All @@ -82,8 +81,7 @@ def load_pretrained_model_and_tokenizer(

# Try to load the pretrained model
try:
model_abspath = os.path.abspath(model_args.model_name_or_path)
model = finetune_model_cls.from_pretrained(model_abspath)
model = finetune_model_cls.from_pretrained(model_args.model_name_or_path)
except Exception as e:
LOG.warning(e)
model_config = CehrBertConfig(
Expand Down
15 changes: 6 additions & 9 deletions src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from cehrbert.models.hf_models.config import CehrBertConfig
from cehrbert.models.hf_models.hf_cehrbert import CehrBertForPreTraining
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer

from .hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments
from .runner_util import (
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, ModelArguments
from cehrbert.runners.runner_util import (
generate_prepared_ds_path,
get_last_hf_checkpoint,
get_meds_extension_path,
Expand Down Expand Up @@ -54,22 +53,21 @@ def load_and_create_tokenizer(
tokenizer = load_and_create_tokenizer(data_args, model_args, dataset)
"""
# Try to load the pretrained tokenizer
tokenizer_abspath = os.path.abspath(model_args.tokenizer_name_or_path)
try:
tokenizer = CehrBertTokenizer.from_pretrained(tokenizer_abspath)
tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
except (OSError, RuntimeError, FileNotFoundError, json.JSONDecodeError) as e:
LOG.warning(
"Failed to load the tokenizer from %s with the error "
"\n%s\nTried to create the tokenizer, however the dataset is not provided.",
tokenizer_abspath,
model_args.tokenizer_name_or_path,
e,
)
if dataset is None:
raise e
tokenizer = CehrBertTokenizer.train_tokenizer(
dataset, feature_names=["concept_ids"], concept_name_mapping={}, data_args=data_args
)
tokenizer.save_pretrained(tokenizer_abspath)
tokenizer.save_pretrained(model_args.tokenizer_name_or_path)

return tokenizer

Expand All @@ -95,8 +93,7 @@ def load_and_create_model(model_args: ModelArguments, tokenizer: CehrBertTokeniz
model = load_and_create_model(model_args, tokenizer)
"""
try:
model_abspath = os.path.abspath(model_args.model_name_or_path)
model_config = AutoConfig.from_pretrained(model_abspath)
model_config = AutoConfig.from_pretrained(model_args.model_name_or_path)
except (OSError, ValueError, FileNotFoundError, json.JSONDecodeError) as e:
LOG.warning(e)
model_config = CehrBertConfig(
Expand Down

0 comments on commit 3d09645

Please sign in to comment.