diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
index 8056572..89d3fb6 100644
--- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
+++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
@@ -246,7 +246,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                 )
 
             # Add artificial time tokens to the patient timeline if timedelta exists
-            if time_delta:
+            if time_delta is not None:
                 # This generates an artificial time token depending on the choice of the time token functions
                 self._update_cehrbert_record(
                     cehrbert_record,
diff --git a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
index bf42cf9..328ec71 100644
--- a/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
+++ b/src/cehrbert/data_generators/hf_data_generator/meds_utils.py
@@ -138,7 +138,7 @@ def convert_one_patient(
     for visit_id, blocks in patient_block_dict.items():
         visit_type = blocks[0].visit_type
         visit_start_datetime = min([b.min_time for b in blocks])
-        visit_end_datetime = max([b.max_time for b in blocks])
+        visit_end_datetime = max([b.get_visit_end_datetime() for b in blocks])
         discharge_facility = (
             next(filter(None, [b.get_discharge_facility() for b in blocks]), None)
             if visit_type in [DEFAULT_INPATIENT_CONCEPT_ID, DEFAULT_ED_CONCEPT_ID]
@@ -195,6 +195,7 @@ def create_dataset_from_meds_reader(
     LOG.info("The inpatient_att_function_type: %s", data_args.inpatient_att_function_type)
     LOG.info("The include_auxiliary_token: %s", data_args.include_auxiliary_token)
     LOG.info("The include_demographic_prompt: %s", data_args.include_demographic_prompt)
+    LOG.info("The meds_exclude_tables: %s", "\n".join(data_args.meds_exclude_tables))
 
     train_dataset = _create_cehrbert_data_from_meds(
         data_args=data_args,
diff --git a/src/cehrbert/data_generators/hf_data_generator/patient_block.py b/src/cehrbert/data_generators/hf_data_generator/patient_block.py
index d174fd1..adf0ee2 100644
--- a/src/cehrbert/data_generators/hf_data_generator/patient_block.py
+++ b/src/cehrbert/data_generators/hf_data_generator/patient_block.py
@@ -65,7 +65,7 @@ def __init__(
         Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
         """
         self.visit_id = visit_id
-        self.events = events
+        self.events = sorted(events, key=lambda e: [e.time, e.code])
         self.min_time = events[0].time
         self.max_time = events[-1].time
         self.conversion = conversion
@@ -179,6 +179,12 @@ def _convert_event(self, event) -> List[Event]:
             )
         ]
 
+    def get_visit_end_datetime(self) -> datetime:
+        for e in self.events:
+            if hasattr(e, "end"):
+                return getattr(e, "end")
+        return self.max_time
+
     def get_meds_events(self) -> Iterable[Event]:
         """
         Retrieves all medication events for the visit, converting each raw event if necessary.
@@ -258,7 +264,7 @@ def omop_meds_generate_demographics_and_patient_blocks(
             if patient_block.min_time.date() <= current_date <= patient_block.max_time.date():
                 patient_block.events.extend(unlinked_event_mapping.pop(current_date_str, []))
                 # Need to sort the events if we insert new events to the patient block
-                patient_block.events = sorted(patient_block.events, key=lambda _: _.time)
+                patient_block.events = sorted(patient_block.events, key=lambda _: [_.time, _.code])
                 break
 
     max_visit_id = max(patient_block_mapping.keys()) + 1 if len(patient_block_mapping) > 0 else 1
diff --git a/src/cehrbert/models/hf_models/hf_cehrbert.py b/src/cehrbert/models/hf_models/hf_cehrbert.py
index fc6fa0e..e5e0885 100644
--- a/src/cehrbert/models/hf_models/hf_cehrbert.py
+++ b/src/cehrbert/models/hf_models/hf_cehrbert.py
@@ -461,6 +461,36 @@ def __init__(self, config: CehrBertConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _apply_age_norm(
+        self,
+        age_at_index: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """
+        Applies batch normalization to the input age tensor.
+
+        If the batch contains more than one example,
+        standard batch normalization is applied. If the batch size is 1, batch normalization is applied
+        without updating the running statistics, ensuring that the normalization uses the stored running
+        mean and variance without modification.
+
+        Args:
+            age_at_index (torch.FloatTensor): A tensor containing the age values to normalize.
+                The tensor has shape `(batch_size, num_features)` where `batch_size` is the number of samples in the batch.
+
+        Returns:
+            torch.FloatTensor: A tensor with the normalized age values.
+        """
+        if age_at_index.shape[0] > 1:
+            normalized_age = self.age_batch_norm(age_at_index)
+        else:
+            self.age_batch_norm.eval()
+            # Apply batch norm without updating running stats
+            with torch.no_grad():  # Prevent tracking gradients, since we don't want to update anything
+                normalized_age = self.age_batch_norm(age_at_index)
+            # Optionally, set the layer back to training mode if needed later
+            self.age_batch_norm.train()
+        return normalized_age
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -476,7 +506,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         classifier_label: Optional[torch.FloatTensor] = None,
     ) -> CehrBertSequenceClassifierOutput:
-        normalized_age = self.age_batch_norm(age_at_index)
+        normalized_age = self._apply_age_norm(age_at_index)
 
         cehrbert_output = self.bert(
             input_ids,
diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py
index 1d6abe6..1576605 100644
--- a/src/cehrbert/runners/hf_runner_argument_dataclass.py
+++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py
@@ -111,7 +111,7 @@ class DataTrainingArguments:
         },
     )
     meds_exclude_tables: Optional[List[str]] = dataclasses.field(
-        default=list,
+        default_factory=list,
         metadata={"help": "The tables to exclude in the conversion e.g. measurement"},
     )
     # TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire
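The `if time_delta is not None:` change in `hf_dataset_mapping.py` matters because a delta of zero is falsy in Python, so two events sharing a timestamp would silently skip the artificial time token. A minimal sketch of the difference (the `timedelta` value here is illustrative; in the mapping `time_delta` may be a plain number, but the truthiness issue is the same):

```python
from datetime import timedelta

time_delta = timedelta(0)  # two events recorded at the same timestamp

if time_delta:
    print("truthy check: time token emitted")         # never reached, timedelta(0) is falsy
if time_delta is not None:
    print("explicit None check: time token emitted")  # reached
```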
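Sorting by `[e.time, e.code]` instead of `e.time` alone in `patient_block.py` makes the event order deterministic when several events share a timestamp. A small sketch with hypothetical events (`SimpleNamespace` stands in for the MEDS event objects):

```python
from datetime import datetime
from types import SimpleNamespace

# Two hypothetical events with the same timestamp; only the (time, code) key
# yields the same order regardless of how the input happened to be ordered.
events = [
    SimpleNamespace(time=datetime(2024, 1, 1, 8, 0), code="LAB//B"),
    SimpleNamespace(time=datetime(2024, 1, 1, 8, 0), code="LAB//A"),
]
print([e.code for e in sorted(events, key=lambda e: e.time)])            # input order kept: ['LAB//B', 'LAB//A']
print([e.code for e in sorted(events, key=lambda e: [e.time, e.code])])  # always ['LAB//A', 'LAB//B']
```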
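The `_apply_age_norm` helper in `hf_cehrbert.py` works around the fact that `nn.BatchNorm1d` cannot compute batch statistics from a single sample and raises an error in training mode for a batch of size 1. A standalone sketch of the same pattern, under the assumption of a simple 1-feature layer (the module and shapes below are illustrative, not the model's actual configuration):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the model's age normalization layer.
age_batch_norm = nn.BatchNorm1d(1)

def apply_age_norm(age_at_index: torch.Tensor) -> torch.Tensor:
    if age_at_index.shape[0] > 1:
        # Normal path: batch statistics are well defined, running stats update as usual.
        return age_batch_norm(age_at_index)
    # A single-sample batch cannot provide batch statistics in training mode,
    # so normalize with the stored running mean/variance and leave them untouched.
    age_batch_norm.eval()
    with torch.no_grad():
        normalized = age_batch_norm(age_at_index)
    age_batch_norm.train()
    return normalized

print(apply_age_norm(torch.randn(4, 1)).shape)  # torch.Size([4, 1]) — batch path
print(apply_age_norm(torch.randn(1, 1)).shape)  # torch.Size([1, 1]) — running-stats path
```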
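In `hf_runner_argument_dataclass.py`, `dataclasses.field(default=list)` stores the builtin `list` type object as the field's default rather than producing an empty list, while `default_factory=list` calls `list()` for each new instance. With the old default, the new `"\n".join(data_args.meds_exclude_tables)` log line would also raise a `TypeError`, since the type object is not an iterable of strings. A quick sketch of the difference (class names here are made up for illustration):

```python
import dataclasses
from typing import List, Optional

@dataclasses.dataclass
class BrokenArgs:
    # The default is the `list` type itself, not an empty list.
    meds_exclude_tables: Optional[List[str]] = dataclasses.field(default=list)

@dataclasses.dataclass
class FixedArgs:
    # list() is called per instance, giving each one its own fresh empty list.
    meds_exclude_tables: Optional[List[str]] = dataclasses.field(default_factory=list)

print(BrokenArgs().meds_exclude_tables)  # <class 'list'>
print(FixedArgs().meds_exclude_tables)   # []
```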