diff --git a/sample_configs/hf_cehrbert_finetuning_runner_config.yaml b/sample_configs/hf_cehrbert_finetuning_runner_config.yaml index 2653e805..3063d659 100644 --- a/sample_configs/hf_cehrbert_finetuning_runner_config.yaml +++ b/sample_configs/hf_cehrbert_finetuning_runner_config.yaml @@ -15,7 +15,7 @@ min_frequency: 0 # Below is a list of Med-to-CehrBert related arguments att_function_type: "cehrbert" -is_data_in_med: false +is_data_in_meds: false inpatient_att_function_type: "mix" include_auxiliary_token: true include_demographic_prompt: false diff --git a/sample_configs/hf_cehrbert_pretrain_runner_config.yaml b/sample_configs/hf_cehrbert_pretrain_runner_config.yaml index c19c1bad..49202eae 100644 --- a/sample_configs/hf_cehrbert_pretrain_runner_config.yaml +++ b/sample_configs/hf_cehrbert_pretrain_runner_config.yaml @@ -15,7 +15,7 @@ min_frequency: 0 # Below is a list of Med-to-CehrBert related arguments att_function_type: "cehrbert" -is_data_in_med: false +is_data_in_meds: false inpatient_att_function_type: "none" include_auxiliary_token: true include_demographic_prompt: false diff --git a/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml b/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml index e7bbc756..59ca503f 100644 --- a/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml +++ b/sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml @@ -15,7 +15,7 @@ min_frequency: 0 # Below is a list of Med-to-CehrBert related arguments att_function_type: "cehrbert" -is_data_in_med: true +is_data_in_meds: true inpatient_att_function_type: "mix" include_auxiliary_token: true include_demographic_prompt: false diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py index 87c68786..fc687d5b 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py @@ -46,7 +46,7 @@ def 
create_cehrbert_pretraining_dataset( ) # If the data is already in meds, we don't need to sort the sequence anymore - if data_args.is_data_in_med: + if data_args.is_data_in_meds: mapping_functions = [HFTokenizationMapping(concept_tokenizer, True)] else: mapping_functions = [ @@ -89,7 +89,7 @@ def create_cehrbert_finetuning_dataset( batch_size=data_args.preprocessing_batch_size, ) - if data_args.is_data_in_med: + if data_args.is_data_in_meds: mapping_functions = [ HFFineTuningMapping(), HFTokenizationMapping(concept_tokenizer, False), diff --git a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py index 2d041546..8e0d2dfe 100644 --- a/src/cehrbert/runners/hf_cehrbert_finetune_runner.py +++ b/src/cehrbert/runners/hf_cehrbert_finetune_runner.py @@ -106,7 +106,7 @@ def main(): LOG.info("Prepared dataset loaded from disk...") else: # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format - if data_args.is_data_in_med: + if data_args.is_data_in_meds: meds_extension_path = get_meds_extension_path( data_folder=os.path.expanduser(data_args.cohort_folder), dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path), diff --git a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py index bb83ca3e..eb1e6ea5 100644 --- a/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py +++ b/src/cehrbert/runners/hf_cehrbert_pretrain_runner.py @@ -170,7 +170,7 @@ def main(): tokenizer = load_and_create_tokenizer(data_args=data_args, model_args=model_args, dataset=processed_dataset) else: # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format - if data_args.is_data_in_med: + if data_args.is_data_in_meds: meds_extension_path = get_meds_extension_path( data_folder=os.path.expanduser(data_args.data_folder), dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path), diff --git 
a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py index 15766057..3aa5ef39 100644 --- a/src/cehrbert/runners/hf_runner_argument_dataclass.py +++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py @@ -90,9 +90,9 @@ class DataTrainingArguments: "choices": f"choices={[e.value for e in AttType]}", }, ) - is_data_in_med: Optional[bool] = dataclasses.field( + is_data_in_meds: Optional[bool] = dataclasses.field( default=False, - metadata={"help": "The boolean indicator to indicate whether the data is in the MED format"}, + metadata={"help": "The boolean indicator to indicate whether the data is in the MEDS format"}, ) inpatient_att_function_type: Literal[ AttType.CEHR_BERT.value, diff --git a/src/cehrbert/runners/runner_util.py b/src/cehrbert/runners/runner_util.py index 39bced6a..6d80f040 100644 --- a/src/cehrbert/runners/runner_util.py +++ b/src/cehrbert/runners/runner_util.py @@ -1,3 +1,4 @@ +import dataclasses import glob import hashlib import os @@ -215,6 +216,67 @@ def remove_trailing_slashes(path: str) -> str: return path.rstrip("/\\") +def parse_dynamic_arguments( + argument_classes: Tuple[dataclasses.dataclass, ...] = (DataTrainingArguments, ModelArguments, TrainingArguments) +) -> Tuple: + """ + Parses command-line arguments with extended flexibility, allowing for the inclusion of custom argument classes. + + This function utilizes `HfArgumentParser` to parse arguments from command line input, JSON, or YAML files. + By default, it expects `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments`, but it can be extended + with additional argument classes through the `argument_classes` parameter, making it suitable + for various custom setups. + + Parameters: + argument_classes (Tuple[Type]): A tuple of argument classes to be parsed. Defaults to + `(DataTrainingArguments, ModelArguments, TrainingArguments)`. 
Additional argument classes can be specified + for greater flexibility in configuration. + + Returns: + Tuple: A tuple of parsed arguments, one for each argument class provided. The order of the returned tuple + matches the order of the `argument_classes` parameter. + + Raises: + FileNotFoundError: If the specified JSON or YAML file does not exist. + json.JSONDecodeError: If there is an error parsing a JSON file. + yaml.YAMLError: If there is an error parsing a YAML file. + Exception: For other issues that occur during argument parsing. + + Example usage: + - Command-line: `python training_script.py --model_name_or_path bert-base-uncased --do_train` + - JSON file: `python training_script.py config.json` + - YAML file: `python training_script.py config.yaml` + + Flexibility: + The function can be customized to include new argument classes as needed: + + Example with a custom argument class: + ```python + class CustomArguments: + # Define custom arguments here + pass + + + custom_args = parse_dynamic_arguments( + (ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments) + ) + ``` + This example demonstrates how to include additional argument classes + beyond the defaults for a more tailored setup. + """ + parser = HfArgumentParser(argument_classes) + + # Check if input is a JSON or YAML file + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1])) + elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1])) + else: + args = parser.parse_args_into_dataclasses() + + return tuple(args) + + def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, TrainingArguments]: """ Parses command line arguments provided to a script for training a model using the Hugging Face. 
@@ -253,15 +315,9 @@ def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, Training Or using a YAML configuration file: $ python training_script.py config.yaml """ - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our arguments. - model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1])) - elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): - model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1])) - else: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + data_args, model_args, training_args = parse_dynamic_arguments( + (DataTrainingArguments, ModelArguments, TrainingArguments) + ) return data_args, model_args, training_args