
Commit dcf9715

Implemented TruncatedOnlineStatistics to collect truncated stats in a distributed way (#57)

* added patient_id to the pretraining data for debugging

* added RunningStatistics to remove the extreme outliers when calculating the running mean and std

* added the lab value lower/upper bounds to the tokenizer so we can bound the extreme values during tokenization

* we only add a new value to the running stat if the value is between the lower and upper bounds

* renamed RunningStatistics to TruncatedOnlineStatistics, and ExcludingOnlineStatistics to TruncatedOfflineStatistics; added unit tests for these new stats utilities (a minimal sketch of the truncated-statistics idea appears below, just before the file diffs)

* added a tqdm progress bar for aggregating the lab statistics

* handle the case in TruncatedOnlineStatistics where filtered_data contains zero elements

* added lower_bound and upper_bound from the lab stats to the cehrbert tokenizer

* fixed the masked MSE loss for cehrbert when value_prediction is enabled

* updated the cehr-bert architecture for predicting concepts with values

* fixed a bug: the normalized value should be bounded by a multiple of the standard deviation, since the normalized value is assumed to follow the standard normal distribution
ChaoPang authored Sep 13, 2024
1 parent 9ecaca5 commit dcf9715
Showing 10 changed files with 519 additions and 16 deletions.
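The TruncatedOnlineStatistics class itself lives in src/cehrbert/utils/stat_utils.py and is not expanded in this view. As a rough illustration of the technique the bullets above describe (buffer values offline until a capacity is reached, then switch to an online Welford-style update that only accepts values within mean ± value_outlier_std * std), a minimal sketch might look like the following; the class name, method names, and fallback behavior are assumptions, not the repository's actual API:

```python
# Illustrative sketch only: the real TruncatedOnlineStatistics lives in
# src/cehrbert/utils/stat_utils.py; names and fallbacks here are assumptions.
import statistics
from typing import List


class TruncatedOnlineStatsSketch:
    def __init__(self, capacity: int = 100, value_outlier_std: float = 3.0):
        self.capacity = capacity
        self.value_outlier_std = value_outlier_std
        self.buffer: List[float] = []  # offline phase: raw values
        self.count = 0  # online phase: Welford accumulators
        self.running_mean = 0.0
        self.m2 = 0.0

    def add(self, value: float) -> None:
        if self.count == 0:
            # Offline phase: buffer until there are enough values to estimate bounds.
            self.buffer.append(value)
            if len(self.buffer) >= self.capacity:
                self._seed_from_buffer()
        elif abs(value - self.running_mean) <= self.value_outlier_std * self.standard_deviation():
            # Online phase: fold the value in only if it lies within
            # mean +/- value_outlier_std * std (Welford's update).
            self.count += 1
            delta = value - self.running_mean
            self.running_mean += delta / self.count
            self.m2 += delta * (value - self.running_mean)

    def _seed_from_buffer(self) -> None:
        mean = statistics.fmean(self.buffer)
        std = statistics.pstdev(self.buffer)
        filtered = [v for v in self.buffer if abs(v - mean) <= self.value_outlier_std * std]
        # Guard against filtered_data containing zero elements.
        filtered = filtered or self.buffer
        self.count = len(filtered)
        self.running_mean = statistics.fmean(filtered)
        self.m2 = sum((v - self.running_mean) ** 2 for v in filtered)

    def mean(self) -> float:
        if self.count:
            return self.running_mean
        return statistics.fmean(self.buffer) if self.buffer else 0.0

    def standard_deviation(self) -> float:
        if self.count > 1:
            return (self.m2 / self.count) ** 0.5
        return statistics.pstdev(self.buffer) if len(self.buffer) > 1 else 0.0
```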
```diff
@@ -12,6 +12,7 @@
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 
 CEHRBERT_COLUMNS = [
+    "person_id",
     "concept_ids",
     "ages",
     "dates",
```
```diff
@@ -115,7 +115,7 @@ def __init__(self, data_args: DataTrainingArguments, is_pretraining: bool = True
 
     def remove_columns(self):
         if self._is_pretraining:
-            return ["visits", "patient_id", "birth_datetime", "index_date"]
+            return ["visits", "birth_datetime", "index_date"]
         else:
             return [
                 "visits",
```
15 changes: 8 additions & 7 deletions src/cehrbert/models/hf_models/hf_cehrbert.py
```diff
@@ -96,21 +96,23 @@ def forward(
 
         merged = torch.where(
             concept_value_masks.to(torch.bool),
-            concept_embeddings_with_val,
+            gelu_new(concept_embeddings_with_val),
             concept_embeddings,
         )
 
         return merged
 
 
 class ConceptValuePredictionLayer(nn.Module):
-    def __init__(self, embedding_size):
+    def __init__(self, embedding_size, layer_norm_eps):
         super(ConceptValuePredictionLayer, self).__init__()
         self.embedding_size = embedding_size
         self.concept_value_decoder_layer = nn.Sequential(
             nn.Linear(embedding_size, embedding_size // 2),
             gelu_new,
+            nn.LayerNorm(embedding_size // 2, eps=layer_norm_eps),
             nn.Linear(embedding_size // 2, 1),
+            gelu_new,
         )
 
     def forward(self, hidden_states: Optional[torch.FloatTensor]):
@@ -259,7 +261,7 @@ def __init__(self, config: CehrBertConfig):
 
         self.bert = CehrBert(config)
         if self.config.include_value_prediction:
-            self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size)
+            self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size, config.layer_norm_eps)
         self.cls = BertOnlyMLMHead(config)
 
         # Initialize weights and apply final processing
@@ -320,10 +322,9 @@ def forward(
         if self.config.include_value_prediction:
             mlm_masks = labels != -100
             predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
-            num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6
-            values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
-            masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items
-            total_loss += torch.mean(masked_mse)
+            total_loss += torch.mean(
+                (predicted_values.squeeze(-1) - concept_values) ** 2 * concept_value_masks * mlm_masks
+            )
 
         return CehrBertModelOutput(
             loss=total_loss,
```
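The loss change in the last hunk replaces a per-sequence normalized MSE (squared errors summed per sequence, divided by that sequence's count of valued tokens, then averaged) with a single mean over every position in the batch, counting only tokens that both carry a numeric value and were masked for MLM. Below is a self-contained sketch of the new term; the tensor names mirror the diff, but the [batch, seq_len] shapes and toy inputs are assumptions rather than the model code itself.

```python
import torch


def masked_value_mse(
    predicted_values: torch.Tensor,  # [batch, seq_len, 1], decoder output
    concept_values: torch.Tensor,  # [batch, seq_len], normalized lab values
    concept_value_masks: torch.Tensor,  # [batch, seq_len], 1 where a value exists
    labels: torch.Tensor,  # [batch, seq_len], -100 at unmasked MLM positions
) -> torch.Tensor:
    mlm_masks = (labels != -100).float()
    squared_error = (predicted_values.squeeze(-1) - concept_values) ** 2
    # Only positions that carry a value AND were masked for MLM contribute;
    # the mean is then taken over all positions in the batch.
    return torch.mean(squared_error * concept_value_masks.float() * mlm_masks)


# Toy usage:
preds = torch.randn(2, 4, 1)
values = torch.randn(2, 4)
value_masks = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
labels = torch.tensor([[5, -100, 7, -100], [-100, 9, -100, -100]])
print(masked_value_mse(preds, values, value_masks, labels))
```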
27 changes: 23 additions & 4 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
```diff
@@ -11,6 +11,7 @@
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.trainers import WordLevelTrainer
+from tqdm import tqdm
 from transformers.tokenization_utils_base import PushToHubMixin
 
 from cehrbert.models.hf_models.tokenization_utils import agg_helper, agg_statistics, map_statistics
@@ -64,6 +65,9 @@ def __init__(
                 "unit": lab_stat["unit"],
                 "mean": lab_stat["mean"],
                 "std": lab_stat["std"],
+                "value_outlier_std": lab_stat["value_outlier_std"],
+                "lower_bound": lab_stat["lower_bound"],
+                "upper_bound": lab_stat["upper_bound"],
             }
             for lab_stat in lab_stats
         }
@@ -296,16 +300,22 @@ def batched_generator():
 
         tokenizer.train_from_iterator(generator, trainer=trainer)
 
+        map_statistics_partial = partial(
+            map_statistics,
+            capacity=data_args.offline_stats_capacity,
+            value_outlier_std=data_args.value_outlier_std,
+        )
+
         if data_args.streaming:
             parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics),
+                partial(agg_helper, map_func=map_statistics_partial),
                 batched=True,
                 batch_size=data_args.preprocessing_batch_size,
                 remove_columns=dataset.column_names,
             )
         else:
             parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics),
+                partial(agg_helper, map_func=map_statistics_partial),
                 batched=True,
                 batch_size=data_args.preprocessing_batch_size,
                 remove_columns=dataset.column_names,
@@ -314,19 +324,23 @@
                 new_fingerprint="invalid",
             )
         current = None
-        for stat in parts:
+        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
             fixed_stat = pickle.loads(stat["data"])
             if current is None:
                 current = fixed_stat
             else:
                 current = agg_statistics(current, fixed_stat)
 
         lab_stats = [
             {
                 "concept_id": concept_id,
                 "unit": unit,
                 "mean": online_stats.mean(),
                 "std": online_stats.standard_deviation(),
+                "count": online_stats.count,
+                "value_outlier_std": data_args.value_outlier_std,
+                "lower_bound": online_stats.mean() - data_args.value_outlier_std * online_stats.standard_deviation(),
+                "upper_bound": online_stats.mean() + data_args.value_outlier_std * online_stats.standard_deviation(),
             }
             for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items()
         ]
@@ -342,8 +356,13 @@ def normalize(self, concept_id, concept_value) -> float:
             mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"]
             std = self._lab_stat_mapping[concept_id]["std"]
             if std > 0:
+                value_outlier_std = self._lab_stat_mapping[concept_id]["value_outlier_std"]
                 normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"]
+                # Clip the value between the lower and upper bounds of the corresponding lab
+                normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value))
             else:
-                normalized_value = mean_
+                # If there is not a valid standard deviation,
+                # we just set the normalized value to the mean of the standard normal
+                normalized_value = 0.0
             return normalized_value
         return concept_value
```
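Distilled to a free function, the revised normalize() logic amounts to the sketch below. The real method reads mean, std, and value_outlier_std from the tokenizer's per-concept lab-stat mapping and returns the raw value for concepts missing from it; this standalone form is only illustrative.

```python
def normalize_lab_value(value: float, mean: float, std: float, value_outlier_std: float) -> float:
    if std > 0:
        z = (value - mean) / std
        # z is assumed to follow the standard normal, so clip it to
        # [-value_outlier_std, +value_outlier_std].
        return max(-value_outlier_std, min(value_outlier_std, z))
    # Without a valid standard deviation, fall back to the standard normal's mean.
    return 0.0


assert normalize_lab_value(250.0, 100.0, 20.0, 3.0) == 3.0  # z = 7.5, clipped to 3
assert normalize_lab_value(110.0, 100.0, 20.0, 3.0) == 0.5  # within bounds, unchanged
```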
10 changes: 6 additions & 4 deletions src/cehrbert/models/hf_models/tokenization_utils.py
```diff
@@ -1,9 +1,10 @@
 import collections
 import json
 import pickle
+from functools import partial
 from typing import Any, Dict
 
-from femr.stat_utils import OnlineStatistics
+from cehrbert.utils.stat_utils import TruncatedOnlineStatistics
 
 
 def load_json_file(json_file):
@@ -21,13 +22,14 @@ def agg_helper(*args, map_func):
     return {"data": [pickle.dumps(result)]}
 
 
-def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]:
+def map_statistics(batch: Dict[str, Any], capacity=100, value_outlier_std=2.0) -> Dict[str, Any]:
     if "units" in batch:
         concept_value_units = batch["units"]
     else:
         concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]]
 
-    numeric_stats_by_lab = collections.defaultdict(OnlineStatistics)
+    numeric_stats_by_lab = collections.defaultdict(
+        partial(TruncatedOnlineStatistics, capacity=capacity, value_outlier_std=value_outlier_std)
+    )
     for concept_ids, concept_values, concept_value_indicators, units in zip(
         batch["concept_ids"],
         batch["concept_values"],
```
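A side note on the idiom in the last hunk: collections.defaultdict needs a zero-argument factory, and functools.partial is what lets a constructor with keyword arguments serve as one. A self-contained sketch with a stand-in class (RunningMean and the example key are illustrative, not repository code):

```python
import collections
from functools import partial


class RunningMean:
    # Stand-in for TruncatedOnlineStatistics; just tracks a plain mean.
    def __init__(self, capacity: int = 100, value_outlier_std: float = 3.0):
        self.capacity = capacity
        self.value_outlier_std = value_outlier_std
        self.count, self.total = 0, 0.0

    def add(self, value: float) -> None:
        self.count += 1
        self.total += value

    def mean(self) -> float:
        return self.total / self.count if self.count else 0.0


stats_by_lab = collections.defaultdict(
    partial(RunningMean, capacity=50, value_outlier_std=2.0)
)
# Unknown (concept_id, unit) keys transparently get a fresh, configured instance.
stats_by_lab[("concept_a", "mmHg")].add(120.0)
stats_by_lab[("concept_a", "mmHg")].add(130.0)
print(stats_by_lab[("concept_a", "mmHg")].mean())  # 125.0
```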
10 changes: 10 additions & 0 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
```diff
@@ -156,6 +156,16 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Indicates whether to randomly shuffle the records that have the same rank"},
     )
+    offline_stats_capacity: Optional[int] = dataclasses.field(
+        default=100,
+        metadata={
+            "help": "The minimum number of lab values to collect for the truncated offline statistics before switching to the online statistics calculation"
+        },
+    )
+    value_outlier_std: Optional[float] = dataclasses.field(
+        default=3.0,
+        metadata={"help": "The number of standard deviations from the mean beyond which lab values are excluded as outliers"},
+    )
 
 
 @dataclasses.dataclass
```
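These two fields surface as ordinary command-line flags via HfArgumentParser. A sketch with a pared-down dataclass, so the example does not depend on the rest of DataTrainingArguments (MiniArgs is a stand-in, not the repository's class):

```python
import dataclasses
from typing import Optional

from transformers import HfArgumentParser


@dataclasses.dataclass
class MiniArgs:
    offline_stats_capacity: Optional[int] = dataclasses.field(default=100)
    value_outlier_std: Optional[float] = dataclasses.field(default=3.0)


# Parse the flags as they would appear on a training command line.
(args,) = HfArgumentParser(MiniArgs).parse_args_into_dataclasses(
    args=["--offline_stats_capacity", "200", "--value_outlier_std", "2.5"]
)
print(args.offline_stats_capacity, args.value_outlier_std)  # 200 2.5
```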
