diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
--- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
+++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -33,7 +33,25 @@
 def get_meds_to_cehrbert_conversion_cls(
-    meds_to_cehrbert_conversion_type: MedsToCehrBertConversionType,
+    meds_to_cehrbert_conversion_type: Union[MedsToCehrBertConversionType, str],
 ) -> MedsToCehrBertConversion:
+    """
+    Instantiate the conversion rule whose class name matches the requested type.
+
+    Args:
+        meds_to_cehrbert_conversion_type: A MedsToCehrBertConversionType member (matched by its
+            ``name``) or the class name given as a plain string.
+
+    Returns:
+        A new instance of the matching MedsToCehrBertConversion subclass.
+
+    Raises:
+        RuntimeError: If no direct subclass of MedsToCehrBertConversion has that name.
+    """
+    # Resolve the lookup name once so enum members and plain strings share the same loop below.
+    if isinstance(meds_to_cehrbert_conversion_type, MedsToCehrBertConversionType):
+        conversion_name = meds_to_cehrbert_conversion_type.name
+    else:
+        conversion_name = meds_to_cehrbert_conversion_type
     for cls in MedsToCehrBertConversion.__subclasses__():
-        if meds_to_cehrbert_conversion_type.name == cls.__name__:
+        if conversion_name == cls.__name__:
             return cls()
     raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")
diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py
index 073384b..055fe41 100644
--- a/src/cehrbert/runners/hf_runner_argument_dataclass.py
+++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py
@@ -7,4 +7,4 @@ from cehrbert_data.decorators.patient_event_decorator import AttType
-from ..data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
+from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules import (
     MedsToBertMimic4,
     MedsToCehrBertConversion,
 )
@@ -111,12 +111,14 @@ class DataTrainingArguments:
     )
     # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire
     # list right now.
-    meds_to_cehrbert_conversion_type: Literal[MedsToBertMimic4.__name__] = dataclasses.field(
-        default=MedsToBertMimic4,
-        metadata={
-            "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
-            "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
-        },
+    meds_to_cehrbert_conversion_type: Literal[MedsToCehrBertConversionType[MedsToBertMimic4.__name__]] = (
+        dataclasses.field(
+            default=MedsToCehrBertConversionType[MedsToBertMimic4.__name__],
+            metadata={
+                "help": "The MEDS to CEHRBERT conversion type e.g. MedsToBertMimic4",
+                "choices": f"choices={[e for e in MedsToCehrBertConversionType.__members__]}",
+            },
+        )
     )
     include_auxiliary_token: Optional[bool] = dataclasses.field(
         default=False,