def _apply_age_norm(
    self,
    age_at_index: torch.FloatTensor,
) -> torch.FloatTensor:
    """
    Normalize the age tensor with ``self.age_batch_norm``.

    Batches of a single example can occur in the forward pass and would break
    batch normalization, so they are handled separately:

    * More than one example: the layer is called as-is, in whatever mode
      (train/eval) it is currently in.
    * Otherwise: the layer is switched to eval mode for the call, so the
      stored running mean and variance are used and left unmodified. The
      call runs under ``torch.no_grad()``, so the result carries no gradient.
      Afterwards the layer is put back into the mode it was in *before* the
      call, rather than unconditionally into training mode, so a model that
      is being evaluated stays in eval mode.

    Args:
        age_at_index (torch.FloatTensor): A tensor containing the age values to normalize.
            The tensor has shape `(batch_size, num_features)` where `batch_size`
            is the number of samples in the batch.

    Returns:
        torch.FloatTensor: A tensor with the normalized age values.
    """
    batch_norm = self.age_batch_norm

    if age_at_index.shape[0] > 1:
        return batch_norm(age_at_index)

    # Remember the caller's mode; forcing train() afterwards would silently
    # re-enable running-stat updates on a model that is in eval mode.
    was_training = batch_norm.training
    batch_norm.eval()
    try:
        # Use the stored running statistics without tracking gradients.
        with torch.no_grad():
            return batch_norm(age_at_index)
    finally:
        # Restore the previous mode even if the forward call raised.
        batch_norm.train(was_training)