Skip to content

Commit

Permalink
Added an if/else block for applying age_batch_norm because there could…
Browse files Browse the repository at this point in the history
… be batches containing a single example in the forward pass, which would break the batch_norm
  • Loading branch information
ChaoPang committed Oct 25, 2024
1 parent 07600fb commit 35dcd9e
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion src/cehrbert/models/hf_models/hf_cehrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,36 @@ def __init__(self, config: CehrBertConfig):
# Initialize weights and apply final processing
self.post_init()

def _apply_age_norm(
self,
age_at_index: torch.FloatTensor,
) -> torch.FloatTensor:
"""
Applies batch normalization to the input age tensor.
If the batch contains more than one example,
standard batch normalization is applied. If the batch size is 1, batch normalization is applied
without updating the running statistics, ensuring that the normalization uses the stored running
mean and variance without modification.
Args:
age_at_index (torch.FloatTensor): A tensor containing the age values to normalize.
The tensor has shape `(batch_size, num_features)` where `batch_size` is the number of samples in the batch.
Returns:
torch.FloatTensor: A tensor with the normalized age values.
"""
if age_at_index.shape[0] > 1:
normalized_age = self.age_batch_norm(age_at_index)
else:
self.age_batch_norm.eval()
# Apply batch norm without updating running stats
with torch.no_grad(): # Prevent tracking gradients, since we don't want to update anything
normalized_age = self.age_batch_norm(age_at_index)
# Optionally, set the layer back to training mode if needed later
self.age_batch_norm.train()
return normalized_age

def forward(
self,
input_ids: torch.LongTensor,
Expand All @@ -369,7 +399,8 @@ def forward(
output_hidden_states: Optional[bool] = None,
classifier_label: Optional[torch.FloatTensor] = None,
) -> CehrBertSequenceClassifierOutput:
normalized_age = self.age_batch_norm(age_at_index)

normalized_age = self._apply_age_norm(age_at_index)

cehrbert_output = self.bert(
input_ids,
Expand Down

0 comments on commit 35dcd9e

Please sign in to comment.