diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py index 9d6f492..bbc86b3 100644 --- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py +++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py @@ -30,6 +30,7 @@ "Visit/61", "NUCC/315D00000X", ] +ED_VISIT_TYPE_CODES = ["VISIT/ER"] DISCHARGE_FACILITY_TYPES = [ "8536", "8863", @@ -238,7 +239,11 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: # We assume the first measurement to be the visit type of the current visit visit_type = visit["visit_type"] - is_inpatient = visit_type in INPATIENT_VISIT_TYPES or visit_type in INPATIENT_VISIT_TYPE_CODES + is_er_or_inpatient = ( + visit_type in INPATIENT_VISIT_TYPES + or visit_type in INPATIENT_VISIT_TYPE_CODES + or visit_type in ED_VISIT_TYPE_CODES + ) # Add artificial time tokens to the patient timeline if timedelta exists if time_delta: @@ -283,32 +288,31 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: # Add a medical token to the patient timeline # If this is an inpatient visit, we use the event time stamps to calculate age and date # because the patient can stay in the hospital for a period of time. 
- if is_inpatient: + if is_er_or_inpatient: # Calculate age using the event time stamp age = relativedelta(e["time"], birth_datetime).years # Calculate the week number since the epoch time date = (e["time"] - datetime.datetime(year=1970, month=1, day=1)).days // 7 + # Calculate the time diff in days w.r.t the previous measurement + meas_time_diff = (e["time"] - date_cursor).days + # Update the date_cursor if the time diff between two neighboring measurements is greater than or + # equal to 1 day + if meas_time_diff > 0: + date_cursor = e["time"] + if self._inpatient_time_token_function: + # This generates an artificial time token depending on the choice of the time token functions + self._update_cehrbert_record( + cehrbert_record, + code=f"i-{self._inpatient_time_token_function(meas_time_diff)}", + visit_concept_order=i + 1, + visit_segment=visit_segment, + visit_concept_id=visit_type, + ) else: # For outpatient visits, we use the visit time stamp to calculate age and time because we assume # the outpatient visits start and end on the same day pass - # Calculate the time diff in days w.r.t the previous measurement - meas_time_diff = (e["time"] - date_cursor).days - # Update the date_cursor if the time diff between two neighboring measurements is greater than and - # equal to 1 day - if meas_time_diff > 0: - date_cursor = e["time"] - if self._inpatient_time_token_function: - # This generates an artificial time token depending on the choice of the time token functions - self._update_cehrbert_record( - cehrbert_record, - code=f"i-{self._inpatient_time_token_function(meas_time_diff)}", - visit_concept_order=i + 1, - visit_segment=visit_segment, - visit_concept_id=visit_type, - ) - # If numeric_value exists, this is a concept/value tuple, we indicate this using a concept_value_mask numeric_value = e.get("numeric_value", None) # The unit might be populated with a None value @@ -338,7 +342,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]: 
mlm_skip_value=concept_value_mask, ) - if is_inpatient: + if is_er_or_inpatient: # If visit_end_datetime is populated for the inpatient visit, we update the date_cursor visit_end_datetime = visit.get("visit_end_datetime", None) if visit_end_datetime: