Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Meds conversion time token fix #71

Merged
merged 4 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
)

# Add artificial time tokens to the patient timeline if timedelta exists
if time_delta:
if time_delta is not None:
# This generates an artificial time token depending on the choice of the time token functions
self._update_cehrbert_record(
cehrbert_record,
Expand Down
3 changes: 2 additions & 1 deletion src/cehrbert/data_generators/hf_data_generator/meds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def convert_one_patient(
for visit_id, blocks in patient_block_dict.items():
visit_type = blocks[0].visit_type
visit_start_datetime = min([b.min_time for b in blocks])
visit_end_datetime = max([b.max_time for b in blocks])
visit_end_datetime = max([b.get_visit_end_datetime() for b in blocks])
discharge_facility = (
next(filter(None, [b.get_discharge_facility() for b in blocks]), None)
if visit_type in [DEFAULT_INPATIENT_CONCEPT_ID, DEFAULT_ED_CONCEPT_ID]
Expand Down Expand Up @@ -195,6 +195,7 @@ def create_dataset_from_meds_reader(
LOG.info("The inpatient_att_function_type: %s", data_args.inpatient_att_function_type)
LOG.info("The include_auxiliary_token: %s", data_args.include_auxiliary_token)
LOG.info("The include_demographic_prompt: %s", data_args.include_demographic_prompt)
LOG.info("The meds_exclude_tables: %s", "\n".join(data_args.meds_exclude_tables))

train_dataset = _create_cehrbert_data_from_meds(
data_args=data_args,
Expand Down
10 changes: 8 additions & 2 deletions src/cehrbert/data_generators/hf_data_generator/patient_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
"""
self.visit_id = visit_id
self.events = events
self.events = sorted(events, key=lambda e: [e.time, e.code])
self.min_time = events[0].time
self.max_time = events[-1].time
self.conversion = conversion
Expand Down Expand Up @@ -179,6 +179,12 @@ def _convert_event(self, event) -> List[Event]:
)
]

def get_visit_end_datetime(self) -> datetime:
    """
    Return the end datetime of the visit represented by this block.

    The events are scanned in order and the first non-null ``end`` attribute found
    on an event is returned. Events that do not carry an ``end`` attribute, or whose
    ``end`` is ``None``, are skipped so that the caller never receives ``None``
    (which could not be compared against other datetimes). If no event provides
    an ``end`` value, the block's ``max_time`` is used as a fallback.

    Returns:
        datetime: the visit end datetime.
    """
    for e in self.events:
        # getattr with a default covers both "attribute missing" and "attribute is None"
        end = getattr(e, "end", None)
        if end is not None:
            return end
    return self.max_time

def get_meds_events(self) -> Iterable[Event]:
"""
Retrieves all medication events for the visit, converting each raw event if necessary.
Expand Down Expand Up @@ -258,7 +264,7 @@ def omop_meds_generate_demographics_and_patient_blocks(
if patient_block.min_time.date() <= current_date <= patient_block.max_time.date():
patient_block.events.extend(unlinked_event_mapping.pop(current_date_str, []))
# Need to sort the events if we insert new events to the patient block
patient_block.events = sorted(patient_block.events, key=lambda _: _.time)
patient_block.events = sorted(patient_block.events, key=lambda _: [_.time, _.code])
break

max_visit_id = max(patient_block_mapping.keys()) + 1 if len(patient_block_mapping) > 0 else 1
Expand Down
32 changes: 31 additions & 1 deletion src/cehrbert/models/hf_models/hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,36 @@ def __init__(self, config: CehrBertConfig):
# Initialize weights and apply final processing
self.post_init()

def _apply_age_norm(
self,
age_at_index: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Applies batch normalization to the input age tensor.

If the batch contains more than one example,
standard batch normalization is applied. If the batch size is 1, batch normalization is applied
without updating the running statistics, ensuring that the normalization uses the stored running
mean and variance without modification.

Args:
age_at_index (torch.FloatTensor): A tensor containing the age values to normalize.
The tensor has shape `(batch_size, num_features)` where `batch_size` is the number of samples in the batch.

Returns:
torch.FloatTensor: A tensor with the normalized age values.
"""
if age_at_index.shape[0] > 1:
normalized_age = self.age_batch_norm(age_at_index)
else:
self.age_batch_norm.eval()
# Apply batch norm without updating running stats
with torch.no_grad(): # Prevent tracking gradients, since we don't want to update anything
normalized_age = self.age_batch_norm(age_at_index)
# Optionally, set the layer back to training mode if needed later
self.age_batch_norm.train()
return normalized_age

def forward(
self,
input_ids: torch.LongTensor,
Expand All @@ -476,7 +506,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
classifier_label: Optional[torch.FloatTensor] = None,
) -> CehrBertSequenceClassifierOutput:
normalized_age = self.age_batch_norm(age_at_index)
normalized_age = self._apply_age_norm(age_at_index)

cehrbert_output = self.bert(
input_ids,
Expand Down
2 changes: 1 addition & 1 deletion src/cehrbert/runners/hf_runner_argument_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class DataTrainingArguments:
},
)
meds_exclude_tables: Optional[List[str]] = dataclasses.field(
default=list,
default_factory=list,
metadata={"help": "The tables to exclude in the conversion e.g. measurement"},
)
# TODO: Python 3.9/10 do not support dynamic unpacking, we have to manually provide the entire
Expand Down
Loading