Implemented TruncatedOnlineStatistics to collect truncated stats in a distributed way #57

Merged · 11 commits · Sep 13, 2024
@@ -12,6 +12,7 @@
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments

CEHRBERT_COLUMNS = [
+   "person_id",
    "concept_ids",
    "ages",
    "dates",
@@ -115,7 +115,7 @@ def __init__(self, data_args: DataTrainingArguments, is_pretraining: bool = True

    def remove_columns(self):
        if self._is_pretraining:
-           return ["visits", "patient_id", "birth_datetime", "index_date"]
+           return ["visits", "birth_datetime", "index_date"]
        else:
            return [
                "visits",
15 changes: 8 additions & 7 deletions src/cehrbert/models/hf_models/hf_cehrbert.py
@@ -96,21 +96,23 @@ def forward(

        merged = torch.where(
            concept_value_masks.to(torch.bool),
-           concept_embeddings_with_val,
+           gelu_new(concept_embeddings_with_val),
            concept_embeddings,
        )

        return merged


class ConceptValuePredictionLayer(nn.Module):
-   def __init__(self, embedding_size):
+   def __init__(self, embedding_size, layer_norm_eps):
        super(ConceptValuePredictionLayer, self).__init__()
        self.embedding_size = embedding_size
        self.concept_value_decoder_layer = nn.Sequential(
            nn.Linear(embedding_size, embedding_size // 2),
            gelu_new,
+           nn.LayerNorm(embedding_size // 2, eps=layer_norm_eps),
            nn.Linear(embedding_size // 2, 1),
            gelu_new,
        )

    def forward(self, hidden_states: Optional[torch.FloatTensor]):
@@ -259,7 +261,7 @@ def __init__(self, config: CehrBertConfig):

        self.bert = CehrBert(config)
        if self.config.include_value_prediction:
-           self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size)
+           self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size, config.layer_norm_eps)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
@@ -320,10 +322,9 @@ def forward(
        if self.config.include_value_prediction:
            mlm_masks = labels != -100
            predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
-           num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6
-           values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
-           masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items
-           total_loss += torch.mean(masked_mse)
+           total_loss += torch.mean(
+               (predicted_values.squeeze(-1) - concept_values) ** 2 * concept_value_masks * mlm_masks
+           )

        return CehrBertModelOutput(
            loss=total_loss,
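Taken together, the hf_cehrbert.py changes do three things: pass the concatenated concept-plus-value embedding through gelu_new before merging, add a LayerNorm inside the value-prediction head, and collapse the per-sequence averaged MSE into a single global mean over masked value positions. A standalone sketch of the updated head and loss term (shapes are illustrative; nn.GELU stands in for gelu_new, which in the repo is a function, not a module):

```python
# Sketch only: mirrors the diff above, not the repo's exact code.
import torch
import torch.nn as nn


class ConceptValuePredictionLayer(nn.Module):
    """Value-prediction head with the newly added LayerNorm between the two linears."""

    def __init__(self, embedding_size: int, layer_norm_eps: float = 1e-12):
        super().__init__()
        self.concept_value_decoder_layer = nn.Sequential(
            nn.Linear(embedding_size, embedding_size // 2),
            nn.GELU(),  # stand-in for gelu_new
            nn.LayerNorm(embedding_size // 2, eps=layer_norm_eps),
            nn.Linear(embedding_size // 2, 1),
            nn.GELU(),
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        return self.concept_value_decoder_layer(hidden_states)


def masked_value_loss(predicted_values, concept_values, concept_value_masks, labels):
    # New formulation: one global mean over (squared error * value mask * MLM mask).
    mlm_masks = (labels != -100).float()
    squared_error = (predicted_values.squeeze(-1) - concept_values) ** 2
    return torch.mean(squared_error * concept_value_masks * mlm_masks)


# Toy shapes: batch of 2 sequences, 5 tokens, hidden size 16.
layer = ConceptValuePredictionLayer(16)
hidden = torch.randn(2, 5, 16)
loss = masked_value_loss(
    layer(hidden),
    concept_values=torch.randn(2, 5),
    concept_value_masks=torch.randint(0, 2, (2, 5)).float(),
    labels=torch.randint(-100, 10, (2, 5)),
)
print(loss)
```

One consequence of the new loss: the old version divided each sequence's summed error by its own count of valued tokens before averaging, while the new version takes one mean over every token position, so sequences with many masked lab values now contribute proportionally more.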
27 changes: 23 additions & 4 deletions src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
@@ -11,6 +11,7 @@
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.trainers import WordLevelTrainer
+from tqdm import tqdm
from transformers.tokenization_utils_base import PushToHubMixin

from cehrbert.models.hf_models.tokenization_utils import agg_helper, agg_statistics, map_statistics
@@ -64,6 +65,9 @@ def __init__(
"unit": lab_stat["unit"],
"mean": lab_stat["mean"],
"std": lab_stat["std"],
"value_outlier_std": lab_stat["value_outlier_std"],
"lower_bound": lab_stat["lower_bound"],
"upper_bound": lab_stat["upper_bound"],
}
for lab_stat in lab_stats
}
@@ -296,16 +300,22 @@ def batched_generator():

        tokenizer.train_from_iterator(generator, trainer=trainer)

+       map_statistics_partial = partial(
+           map_statistics,
+           capacity=data_args.offline_stats_capacity,
+           value_outlier_std=data_args.value_outlier_std,
+       )
+
        if data_args.streaming:
            parts = dataset.map(
-               partial(agg_helper, map_func=map_statistics),
+               partial(agg_helper, map_func=map_statistics_partial),
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
                remove_columns=dataset.column_names,
            )
        else:
            parts = dataset.map(
-               partial(agg_helper, map_func=map_statistics),
+               partial(agg_helper, map_func=map_statistics_partial),
                batched=True,
                batch_size=data_args.preprocessing_batch_size,
                remove_columns=dataset.column_names,
@@ -314,19 +324,23 @@ def batched_generator():
                new_fingerprint="invalid",
            )
        current = None
-       for stat in parts:
+       for stat in tqdm(parts, desc="Aggregating the lab statistics"):
            fixed_stat = pickle.loads(stat["data"])
            if current is None:
                current = fixed_stat
            else:
                current = agg_statistics(current, fixed_stat)

        lab_stats = [
            {
                "concept_id": concept_id,
                "unit": unit,
                "mean": online_stats.mean(),
                "std": online_stats.standard_deviation(),
                "count": online_stats.count,
+               "value_outlier_std": data_args.value_outlier_std,
+               "lower_bound": online_stats.mean() - data_args.value_outlier_std * online_stats.standard_deviation(),
+               "upper_bound": online_stats.mean() + data_args.value_outlier_std * online_stats.standard_deviation(),
            }
            for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items()
        ]
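The loop above is the reduce step of a map/reduce over dataset shards: each agg_helper call pickles a shard's partial statistics, and the driver unpickles and folds them pairwise with agg_statistics. A toy illustration of that fold, using a collections.Counter as a stand-in for the real numeric_stats_by_lab mapping (the Counter and agg_counts are assumptions, not the repo's data structures):

```python
# Toy fold: each mapped shard returns pickled partial statistics;
# the driver unpickles and merges them pairwise.
import collections
import pickle


def agg_counts(current: collections.Counter, fixed: collections.Counter) -> collections.Counter:
    current.update(fixed)  # merging is associative, so shard order does not matter
    return current


parts = [
    {"data": pickle.dumps(collections.Counter(a=1, b=2))},
    {"data": pickle.dumps(collections.Counter(a=3))},
]

current = None
for stat in parts:
    fixed_stat = pickle.loads(stat["data"])
    current = fixed_stat if current is None else agg_counts(current, fixed_stat)

print(current)  # Counter({'a': 4, 'b': 2})
```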
@@ -342,8 +356,13 @@ def normalize(self, concept_id, concept_value) -> float:
            mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"]
            std = self._lab_stat_mapping[concept_id]["std"]
            if std > 0:
+               value_outlier_std = self._lab_stat_mapping[concept_id]["value_outlier_std"]
                normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"]
+               # Clip the value between the lower and upper bounds of the corresponding lab
+               normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value))
            else:
-               normalized_value = mean_
+               # If there is no valid standard deviation,
+               # we just set the normalized value to the mean of the standard normal
+               normalized_value = 0.0
            return normalized_value
        return concept_value
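A quick worked example of the new clipping behavior in normalize, with made-up statistics (mean 100, std 10, value_outlier_std 3.0); the helper below is a simplified stand-in for the method, not the repo's code:

```python
# Hypothetical numbers to illustrate the clipping; not taken from the repo.
lab_stat = {"mean": 100.0, "std": 10.0, "value_outlier_std": 3.0}


def normalize(value: float, stat: dict) -> float:
    if stat["std"] > 0:
        z = (value - stat["mean"]) / stat["std"]
        # Clip to [-value_outlier_std, +value_outlier_std]
        return max(-stat["value_outlier_std"], min(stat["value_outlier_std"], z))
    return 0.0  # degenerate std: fall back to the standard-normal mean


print(normalize(150.0, lab_stat))  # 3.0 (raw z-score of 5.0, clipped)
print(normalize(95.0, lab_stat))   # -0.5 (within bounds, unchanged)
```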
10 changes: 6 additions & 4 deletions src/cehrbert/models/hf_models/tokenization_utils.py
@@ -1,9 +1,10 @@
import collections
import json
import pickle
+from functools import partial
from typing import Any, Dict

-from femr.stat_utils import OnlineStatistics
+from cehrbert.utils.stat_utils import TruncatedOnlineStatistics


def load_json_file(json_file):
@@ -21,13 +22,14 @@ def agg_helper(*args, map_func):
return {"data": [pickle.dumps(result)]}


def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]:
def map_statistics(batch: Dict[str, Any], capacity=100, value_outlier_std=2.0) -> Dict[str, Any]:
if "units" in batch:
concept_value_units = batch["units"]
else:
concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]]

numeric_stats_by_lab = collections.defaultdict(OnlineStatistics)
numeric_stats_by_lab = collections.defaultdict(
partial(TruncatedOnlineStatistics, capacity=capacity, value_outlier_std=value_outlier_std)
)
for concept_ids, concept_values, concept_value_indicators, units in zip(
batch["concept_ids"],
batch["concept_values"],
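The TruncatedOnlineStatistics class itself lives in src/cehrbert/utils/stat_utils.py and is not part of this diff. Only its constructor arguments (capacity, value_outlier_std) and the mean(), standard_deviation(), and count accessors are visible from how it is used; everything else in the sketch below, including the add() method and the buffer-then-Welford strategy, is an assumption about how such a class could work:

```python
# Assumed sketch; the real class in cehrbert.utils.stat_utils may differ.
import math


class TruncatedOnlineStatistics:
    """Buffer up to `capacity` values offline, then keep Welford-style online
    statistics while discarding values beyond `value_outlier_std` deviations."""

    def __init__(self, capacity: int = 100, value_outlier_std: float = 2.0):
        self.capacity = capacity
        self.value_outlier_std = value_outlier_std
        self.buffer = []        # offline phase: raw values
        self.count = 0          # online phase: Welford accumulators
        self.current_mean = 0.0
        self.sum_of_squares = 0.0

    def add(self, value: float) -> None:  # add() is assumed, not shown in the diff
        if len(self.buffer) < self.capacity:
            self.buffer.append(value)
            if len(self.buffer) == self.capacity:
                self._flush_buffer()
        elif abs(value - self.mean()) <= self.value_outlier_std * self.standard_deviation():
            # Online phase: only values inside the truncation band update the stats.
            self._update(value)

    def _flush_buffer(self) -> None:
        # NOTE: a production version would also handle a buffer that never fills.
        mean = sum(self.buffer) / len(self.buffer)
        std = math.sqrt(sum((v - mean) ** 2 for v in self.buffer) / len(self.buffer))
        for v in self.buffer:
            if abs(v - mean) <= self.value_outlier_std * std:
                self._update(v)

    def _update(self, value: float) -> None:
        # Welford's online mean/variance update.
        self.count += 1
        delta = value - self.current_mean
        self.current_mean += delta / self.count
        self.sum_of_squares += delta * (value - self.current_mean)

    def mean(self) -> float:
        return self.current_mean

    def standard_deviation(self) -> float:
        return math.sqrt(self.sum_of_squares / self.count) if self.count > 1 else 0.0
```

For the distributed aggregation above to be exact, agg_statistics would also need a merge operation over two such objects (e.g., Chan et al.'s parallel mean/variance combination), which is omitted from this sketch.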
10 changes: 10 additions & 0 deletions src/cehrbert/runners/hf_runner_argument_dataclass.py
@@ -156,6 +156,16 @@ class DataTrainingArguments:
        default=False,
        metadata={"help": "Indicates whether to randomly shuffle the records that have the same rank"},
    )
+   offline_stats_capacity: Optional[int] = dataclasses.field(
+       default=100,
+       metadata={
+           "help": "The minimum number of lab values to collect offline before the truncated statistics switch to the online calculation"
+       },
+   )
+   value_outlier_std: Optional[float] = dataclasses.field(
+       default=3.0,
+       metadata={"help": "The number of standard deviations from the mean beyond which lab values are treated as outliers and clipped"},
+   )
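Note that the function-level default in map_statistics (value_outlier_std=2.0) differs from the dataclass default (3.0); the dataclass value wins in practice because train_tokenizer always passes it explicitly. A sketch of supplying the two new flags through HfArgumentParser; "--data_folder" is a placeholder for whatever required arguments DataTrainingArguments actually defines:

```python
# Sketch: the two new flags parsed alongside DataTrainingArguments.
from transformers import HfArgumentParser

from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--data_folder", "/path/to/data",    # placeholder required argument
        "--offline_stats_capacity", "200",   # buffer 200 values before going online
        "--value_outlier_std", "3.0",        # clip beyond 3 standard deviations
    ]
)
print(data_args.offline_stats_capacity, data_args.value_outlier_std)
```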

