Skip to content

Commit

Permalink
changed the field from is_data_in_med to is_data_in_meds in DataTrain…
Browse files Browse the repository at this point in the history
…ingArguments; added a parse_dynamic_arguments to extend the argument parser more flexibly (#72)
  • Loading branch information
ChaoPang authored Nov 1, 2024
1 parent 4fba805 commit 0443b66
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 18 deletions.
2 changes: 1 addition & 1 deletion sample_configs/hf_cehrbert_finetuning_runner_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ min_frequency: 0

# Below is a list of Med-to-CehrBert related arguments
att_function_type: "cehrbert"
is_data_in_med: false
is_data_in_meds: false
inpatient_att_function_type: "mix"
include_auxiliary_token: true
include_demographic_prompt: false
Expand Down
2 changes: 1 addition & 1 deletion sample_configs/hf_cehrbert_pretrain_runner_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ min_frequency: 0

# Below is a list of Med-to-CehrBert related arguments
att_function_type: "cehrbert"
is_data_in_med: false
is_data_in_meds: false
inpatient_att_function_type: "none"
include_auxiliary_token: true
include_demographic_prompt: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ min_frequency: 0

# Below is a list of Med-to-CehrBert related arguments
att_function_type: "cehrbert"
is_data_in_med: true
is_data_in_meds: true
inpatient_att_function_type: "mix"
include_auxiliary_token: true
include_demographic_prompt: false
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_cehrbert_pretraining_dataset(
)

# If the data is already in meds, we don't need to sort the sequence anymore
if data_args.is_data_in_med:
if data_args.is_data_in_meds:
mapping_functions = [HFTokenizationMapping(concept_tokenizer, True)]
else:
mapping_functions = [
Expand Down Expand Up @@ -89,7 +89,7 @@ def create_cehrbert_finetuning_dataset(
batch_size=data_args.preprocessing_batch_size,
)

if data_args.is_data_in_med:
if data_args.is_data_in_meds:
mapping_functions = [
HFFineTuningMapping(),
HFTokenizationMapping(concept_tokenizer, False),
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_cehrbert_finetune_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def main():
LOG.info("Prepared dataset loaded from disk...")
else:
# If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
if data_args.is_data_in_med:
if data_args.is_data_in_meds:
meds_extension_path = get_meds_extension_path(
data_folder=os.path.expanduser(data_args.cohort_folder),
dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_cehrbert_pretrain_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def main():
tokenizer = load_and_create_tokenizer(data_args=data_args, model_args=model_args, dataset=processed_dataset)
else:
# If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
if data_args.is_data_in_med:
if data_args.is_data_in_meds:
meds_extension_path = get_meds_extension_path(
data_folder=os.path.expanduser(data_args.data_folder),
dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
Expand Down
4 changes: 2 additions & 2 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ class DataTrainingArguments:
"choices": f"choices={[e.value for e in AttType]}",
},
)
is_data_in_med: Optional[bool] = dataclasses.field(
is_data_in_meds: Optional[bool] = dataclasses.field(
default=False,
metadata={"help": "The boolean indicator to indicate whether the data is in the MED format"},
metadata={"help": "The boolean indicator to indicate whether the data is in the MEDS format"},
)
inpatient_att_function_type: Literal[
AttType.CEHR_BERT.value,
Expand Down
74 changes: 65 additions & 9 deletions src/cehrbert/runners/runner_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import glob
import hashlib
import os
Expand Down Expand Up @@ -215,6 +216,67 @@ def remove_trailing_slashes(path: str) -> str:
return path.rstrip("/\\")


def parse_dynamic_arguments(
argument_classes: Tuple[dataclasses.dataclass, ...] = (DataTrainingArguments, ModelArguments, TrainingArguments)
) -> Tuple:
"""
Parses command-line arguments with extended flexibility, allowing for the inclusion of custom argument classes.
This function utilizes `HfArgumentParser` to parse arguments from command line input, JSON, or YAML files.
By default, it expects `ModelArguments`, `DataTrainingArguments`, and `TrainingArguments`, but it can be extended
with additional argument classes through the `argument_classes` parameter, making it suitable
for various custom setups.
Parameters:
argument_classes (Tuple[Type]): A tuple of argument classes to be parsed. Defaults to
`(ModelArguments, DataTrainingArguments, TrainingArguments)`. Additional argument classes can be specified
for greater flexibility in configuration.
Returns:
Tuple: A tuple of parsed arguments, one for each argument class provided. The order of the returned tuple
matches the order of the `argument_classes` parameter.
Raises:
FileNotFoundError: If the specified JSON or YAML file does not exist.
json.JSONDecodeError: If there is an error parsing a JSON file.
yaml.YAMLError: If there is an error parsing a YAML file.
Exception: For other issues that occur during argument parsing.
Example usage:
- Command-line: `python training_script.py --model_name_or_path bert-base-uncased --do_train`
- JSON file: `python training_script.py config.json`
- YAML file: `python training_script.py config.yaml`
Flexibility:
The function can be customized to include new argument classes as needed:
Example with a custom argument class:
```python
class CustomArguments:
# Define custom arguments here
pass
custom_args = parse_extended_args(
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
)
```
This example demonstrates how to include additional argument classes
beyond the defaults for a more tailored setup.
"""
parser = HfArgumentParser(argument_classes)

# Check if input is a JSON or YAML file
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1]))
else:
args = parser.parse_args_into_dataclasses()

return tuple(args)


def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, TrainingArguments]:
"""
Parses command line arguments provided to a script for training a model using the Hugging Face.
Expand Down Expand Up @@ -253,15 +315,9 @@ def parse_runner_args() -> Tuple[DataTrainingArguments, ModelArguments, Training
Or using a YAML configuration file:
$ python training_script.py config.yaml
"""
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.expanduser(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args = parser.parse_yaml_file(yaml_file=os.path.expanduser(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
data_args, model_args, training_args = parse_dynamic_arguments(
(DataTrainingArguments, ModelArguments, TrainingArguments)
)
return data_args, model_args, training_args


Expand Down

0 comments on commit 0443b66

Please sign in to comment.