diff --git a/src/cehrbert/models/hf_models/tokenization_utils.py b/src/cehrbert/models/hf_models/tokenization_utils.py index f43d8f4c..35d09886 100644 --- a/src/cehrbert/models/hf_models/tokenization_utils.py +++ b/src/cehrbert/models/hf_models/tokenization_utils.py @@ -1,9 +1,10 @@ import collections import json import pickle +from functools import partial from typing import Any, Dict -from femr.stat_utils import OnlineStatistics +from cehrbert.utils.stat_utils import RunningStatistics def load_json_file(json_file): @@ -26,8 +27,9 @@ def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]: concept_value_units = batch["units"] else: concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]] - - numeric_stats_by_lab = collections.defaultdict(OnlineStatistics) + numeric_stats_by_lab = collections.defaultdict( + partial(RunningStatistics, capacity=100, lower_quantile=0.05, upper_quantile=0.95) + ) for concept_ids, concept_values, concept_value_indicators, units in zip( batch["concept_ids"], batch["concept_values"], diff --git a/src/cehrbert/utils/stat_utils.py b/src/cehrbert/utils/stat_utils.py index e69de29b..cc3e5b15 100644 --- a/src/cehrbert/utils/stat_utils.py +++ b/src/cehrbert/utils/stat_utils.py @@ -0,0 +1,98 @@ +import numpy as np +from femr.stat_utils import OnlineStatistics + + +class RunningStatistics(OnlineStatistics): + def __init__(self, capacity=100, lower_quantile=0.05, upper_quantile=0.950): + super().__init__() + self.excluding_outlier_online_statistics = ExcludingOutlierOnlineStatistics( + capacity=capacity, lower_quantile=lower_quantile, upper_quantile=upper_quantile + ) + + def add(self, weight: float, value: float) -> None: + if self.excluding_outlier_online_statistics.is_full(): + super().add(weight, value) + else: + self.excluding_outlier_online_statistics.add(value) + if self.excluding_outlier_online_statistics.is_full(): + self.current_mean = self.excluding_outlier_online_statistics.get_current_mean() + 
self.variance = self.excluding_outlier_online_statistics.get_sum_of_squared() + self.count = self.excluding_outlier_online_statistics.get_count() + + def mean(self) -> float: + """Return the current mean.""" + if self.excluding_outlier_online_statistics.is_full(): + return super().mean() + else: + return self.excluding_outlier_online_statistics.get_current_mean() + + def standard_deviation(self) -> float: + """Return the current standard deviation.""" + if self.excluding_outlier_online_statistics.is_full(): + return super().standard_deviation() + else: + return self.excluding_outlier_online_statistics.standard_deviation() + + +class ExcludingOutlierOnlineStatistics: + def __init__(self, capacity=100, lower_quantile=0.05, upper_quantile=0.950): + super().__init__() + self.lower_quantile = lower_quantile + self.upper_quantile = upper_quantile + self.capacity = capacity + self.raw_data = list() + self.filtered_data = list() + self.updated = False + + def reset(self): + self.raw_data.clear() + self.filtered_data.clear() + self.updated = False + + def is_full(self) -> bool: + return len(self.raw_data) >= self.capacity + + def add(self, value: float) -> None: + if len(self.raw_data) < self.capacity: + self.raw_data.append(value) + # When new data is added to the raw_data array, we need to update the filtered_data later on + self.updated = False + else: + raise ValueError(f"The capacity of the underlying data is full at {self.capacity}") + + def get_count(self) -> int: + self.update_remove_outliers() + return len(self.filtered_data) + + def get_current_mean(self) -> float: + self.update_remove_outliers() + if self.filtered_data: + return np.mean(self.filtered_data) + else: + raise ValueError("There is no value") + + def get_sum_of_squared(self) -> float: + self.update_remove_outliers() + if self.filtered_data: + current_mean = np.mean(self.filtered_data) + return np.sum([(x - current_mean) ** 2 for x in self.filtered_data]) + else: + raise ValueError("There is no value") + + 
def standard_deviation(self) -> float: + self.update_remove_outliers() + if self.filtered_data: + return np.std(self.filtered_data) + else: + raise ValueError(f"There is no value") + + def update_remove_outliers( + self, + ) -> None: + if not self.updated and len(self.raw_data) > 0: + # Toggle the updated variable + self.updated = True + lower_bound = np.quantile(self.raw_data, self.lower_quantile) + upper_bound = np.quantile(self.raw_data, self.upper_quantile) + # Update the filtered_data + self.filtered_data = [x for x in self.raw_data if lower_bound <= x <= upper_bound] diff --git a/tests/unit_tests/utils/__init__.py b/tests/unit_tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/utils/excluding_outliers_online_statistics_test.py b/tests/unit_tests/utils/excluding_outliers_online_statistics_test.py new file mode 100644 index 00000000..5e0c9ced --- /dev/null +++ b/tests/unit_tests/utils/excluding_outliers_online_statistics_test.py @@ -0,0 +1,85 @@ +import unittest + +import numpy as np + +from cehrbert.utils.stat_utils import ExcludingOutlierOnlineStatistics # Replace with the actual module name + + +class TestExcludingOutlierOnlineStatistics(unittest.TestCase): + + def setUp(self): + # Create an instance of ExcludingOutlierOnlineStatistics with default settings + self.stats = ExcludingOutlierOnlineStatistics(capacity=10, lower_quantile=0.05, upper_quantile=0.95) + + def test_add_data_within_capacity(self): + # Test adding data within the capacity + for i in range(10): + self.stats.add(i) + self.assertEqual(len(self.stats.raw_data), 10) + self.stats.reset() + + def test_add_data_beyond_capacity(self): + # Test adding data beyond the capacity + for i in range(10): + self.stats.add(i) + with self.assertRaises(ValueError): + self.stats.add(11) + self.stats.reset() + + def test_remove_outliers(self): + # Test removing outliers + data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000] + for x in data: + self.stats.add(x) 
+ + # Trigger outlier removal + self.stats.update_remove_outliers() + + # The expected filtered data excludes -1000 and 99999 since they are extreme values + expected_filtered_data = [10, 12, 13, 14, 15, 16, 17, 18] + self.assertListEqual(self.stats.filtered_data, expected_filtered_data) + + def test_mean_calculation(self): + # Test the mean calculation after excluding outliers + data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000] + for x in data: + self.stats.add(x) + + # Test mean after excluding outliers + mean = self.stats.get_current_mean() + expected_mean = np.mean([10, 12, 13, 14, 15, 16, 17, 18]) + self.assertAlmostEqual(mean, expected_mean, places=5) + + def test_get_sum_of_squared(self): + # Test the sum of squared deviations after excluding outliers + data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000] + for x in data: + self.stats.add(x) + + actual_sum_of_squared = self.stats.get_sum_of_squared() + expected_sum_of_squares = np.sum( + (np.asarray([10, 12, 13, 14, 15, 16, 17, 18]) - np.mean([10, 12, 13, 14, 15, 16, 17, 18])) ** 2 + ) + self.assertEqual(actual_sum_of_squared, expected_sum_of_squares) + self.stats.reset() + + def test_standard_deviation_calculation(self): + # Test the standard deviation after excluding outliers + data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000] + for x in data: + self.stats.add(x) + + # Test standard deviation after excluding outliers + stddev = self.stats.standard_deviation() + expected_stddev = np.std([10, 12, 13, 14, 15, 16, 17, 18], ddof=0) + self.assertAlmostEqual(stddev, expected_stddev, places=5) + self.stats.reset() + + def test_empty_filtered_data(self): + # Test when no data is present + self.assertRaises(ValueError, self.stats.get_current_mean) + self.assertRaises(ValueError, self.stats.standard_deviation) + + +if __name__ == "__main__": + unittest.main()