Skip to content

Commit

Permalink
added RunningStatistics to remove the extreme outliers when calculating the running mean and std
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaoPang committed Sep 12, 2024
1 parent 34b16d8 commit 75cca56
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/cehrbert/models/hf_models/tokenization_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import collections
import json
import pickle
from functools import partial
from typing import Any, Dict

from femr.stat_utils import OnlineStatistics
from cehrbert.utils.stat_utils import RunningStatistics


def load_json_file(json_file):
Expand All @@ -26,8 +27,9 @@ def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]:
concept_value_units = batch["units"]
else:
concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]]

numeric_stats_by_lab = collections.defaultdict(OnlineStatistics)
numeric_stats_by_lab = collections.defaultdict(
partial(RunningStatistics, capacity=100, lower_quantile=0.05, upper_quantile=0.95)
)
for concept_ids, concept_values, concept_value_indicators, units in zip(
batch["concept_ids"],
batch["concept_values"],
Expand Down
98 changes: 98 additions & 0 deletions src/cehrbert/utils/stat_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np
from femr.stat_utils import OnlineStatistics


class RunningStatistics(OnlineStatistics):
    """Online mean/std tracker whose first ``capacity`` samples are outlier-filtered.

    The first ``capacity`` values are buffered in an
    ``ExcludingOutlierOnlineStatistics`` instance. While that buffer is still
    filling, ``mean`` and ``standard_deviation`` are answered from the buffer
    (after dropping values outside the configured quantile range). On the add
    that fills the buffer, the filtered mean, sum of squared deviations and
    retained-sample count are copied into ``current_mean``, ``variance`` and
    ``count``; every later value is delegated to ``OnlineStatistics.add``.

    NOTE(review): this presumes the parent class keeps its running state in
    ``count`` / ``current_mean`` / ``variance`` and treats ``variance`` as a sum
    of squared deviations -- confirm against femr's ``OnlineStatistics``.

    Args:
        capacity: Number of initial samples to buffer before switching to the
            parent's online update.
        lower_quantile: Buffered values below this quantile are excluded.
        upper_quantile: Buffered values above this quantile are excluded.
    """

    def __init__(self, capacity=100, lower_quantile=0.05, upper_quantile=0.950):
        super().__init__()
        self.excluding_outlier_online_statistics = ExcludingOutlierOnlineStatistics(
            capacity=capacity, lower_quantile=lower_quantile, upper_quantile=upper_quantile
        )

    def add(self, weight: float, value: float) -> None:
        """Record one observation.

        Args:
            weight: Weight of the observation. It is only forwarded to the
                parent once the warm-up buffer is full; buffered samples are
                stored without their weight.
            value: The observed value.
        """
        warmup = self.excluding_outlier_online_statistics
        if warmup.is_full():
            super().add(weight, value)
            return

        warmup.add(value)
        if warmup.is_full():
            # The buffer has just reached capacity: seed the parent's running
            # state from the outlier-filtered samples so later updates continue
            # from there.
            self.current_mean = warmup.get_current_mean()
            self.variance = warmup.get_sum_of_squared()
            self.count = warmup.get_count()

    def mean(self) -> float:
        """Return the current mean.

        Raises:
            ValueError: If no value has been added yet (raised by the warm-up
                buffer).
        """
        if self.excluding_outlier_online_statistics.is_full():
            return super().mean()
        # The original code computed this value but never returned it, so
        # callers received None until the warm-up buffer was full.
        return self.excluding_outlier_online_statistics.get_current_mean()

    def standard_deviation(self) -> float:
        """Return the current standard deviation.

        Raises:
            ValueError: If no value has been added yet (raised by the warm-up
                buffer).
        """
        if self.excluding_outlier_online_statistics.is_full():
            return super().standard_deviation()
        return self.excluding_outlier_online_statistics.standard_deviation()


class ExcludingOutlierOnlineStatistics:
    """Fixed-capacity sample buffer that reports statistics without extreme values.

    Raw samples are stored up to ``capacity``. Statistics are computed over the
    samples lying inside the closed interval between the ``lower_quantile`` and
    ``upper_quantile`` of the raw samples. The filtered list is rebuilt lazily:
    only when a statistic is requested after new samples were added.
    """

    def __init__(self, capacity=100, lower_quantile=0.05, upper_quantile=0.950):
        super().__init__()
        self.capacity = capacity
        self.lower_quantile = lower_quantile
        self.upper_quantile = upper_quantile
        self.raw_data = []
        self.filtered_data = []
        # True when filtered_data reflects the current contents of raw_data.
        self.updated = False

    def reset(self):
        """Drop all stored samples and mark the filtered view as stale."""
        self.updated = False
        self.filtered_data.clear()
        self.raw_data.clear()

    def is_full(self) -> bool:
        """Return True once ``capacity`` samples have been stored."""
        stored = len(self.raw_data)
        return stored >= self.capacity

    def add(self, value: float) -> None:
        """Store one sample.

        Raises:
            ValueError: If the buffer already holds ``capacity`` samples.
        """
        if not len(self.raw_data) < self.capacity:
            raise ValueError(f"The capacity of the underlying data is full at {self.capacity}")
        self.raw_data.append(value)
        # The filtered view no longer matches raw_data and must be rebuilt
        # before the next statistic is computed.
        self.updated = False

    def get_count(self) -> int:
        """Return how many samples survive the outlier filter."""
        self.update_remove_outliers()
        retained = self.filtered_data
        return len(retained)

    def get_current_mean(self) -> float:
        """Return the mean of the retained samples.

        Raises:
            ValueError: If there are no retained samples.
        """
        self.update_remove_outliers()
        if not self.filtered_data:
            raise ValueError(f"There is no value")
        return np.mean(self.filtered_data)

    def get_sum_of_squared(self) -> float:
        """Return the sum of squared deviations of the retained samples from their mean.

        Raises:
            ValueError: If there are no retained samples.
        """
        self.update_remove_outliers()
        if not self.filtered_data:
            raise ValueError(f"There is no value")
        center = np.mean(self.filtered_data)
        squared_deviations = [(sample - center) ** 2 for sample in self.filtered_data]
        return np.sum(squared_deviations)

    def standard_deviation(self) -> float:
        """Return ``np.std`` of the retained samples.

        Raises:
            ValueError: If there are no retained samples.
        """
        self.update_remove_outliers()
        if not self.filtered_data:
            raise ValueError(f"There is no value")
        return np.std(self.filtered_data)

    def update_remove_outliers(self) -> None:
        """Rebuild ``filtered_data`` if it is stale and there is data to filter."""
        if self.updated or len(self.raw_data) == 0:
            return
        # Mark the view as fresh before recomputing it.
        self.updated = True
        low_cut = np.quantile(self.raw_data, self.lower_quantile)
        high_cut = np.quantile(self.raw_data, self.upper_quantile)
        # Keep the original insertion order of the samples that fall in range.
        self.filtered_data = [sample for sample in self.raw_data if low_cut <= sample <= high_cut]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import unittest

import numpy as np

from cehrbert.utils.stat_utils import ExcludingOutlierOnlineStatistics # Replace with the actual module name


class TestExcludingOutlierOnlineStatistics(unittest.TestCase):
    """Unit tests for ExcludingOutlierOnlineStatistics with a capacity of 10."""

    # The values every outlier-contaminated fixture below is expected to keep.
    INLIERS = [10, 12, 13, 14, 15, 16, 17, 18]

    def setUp(self):
        # unittest calls setUp before every test, so each test starts from a
        # fresh, empty instance and no per-test cleanup is required.
        self.stats = ExcludingOutlierOnlineStatistics(capacity=10, lower_quantile=0.05, upper_quantile=0.95)

    def test_add_data_within_capacity(self):
        # Adding exactly `capacity` values must succeed and fill the buffer.
        for i in range(10):
            self.stats.add(i)
        self.assertEqual(len(self.stats.raw_data), 10)
        self.assertTrue(self.stats.is_full())

        # reset() must empty the buffer again.
        self.stats.reset()
        self.assertEqual(len(self.stats.raw_data), 0)
        self.assertFalse(self.stats.is_full())

    def test_add_data_beyond_capacity(self):
        # The eleventh value exceeds the capacity and must be rejected.
        for i in range(10):
            self.stats.add(i)
        with self.assertRaises(ValueError):
            self.stats.add(11)
        self.assertEqual(len(self.stats.raw_data), 10)

    def test_remove_outliers(self):
        data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000]
        for x in data:
            self.stats.add(x)

        # Trigger outlier removal explicitly.
        self.stats.update_remove_outliers()

        # -1000 and 99999 fall outside the 5%-95% quantile range and are
        # dropped; the remaining values keep their insertion order.
        self.assertListEqual(self.stats.filtered_data, self.INLIERS)

    def test_mean_calculation(self):
        data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000]
        for x in data:
            self.stats.add(x)

        # The mean must be computed over the inliers only.
        mean = self.stats.get_current_mean()
        expected_mean = np.mean(self.INLIERS)
        self.assertAlmostEqual(mean, expected_mean, places=5)

    def test_get_sum_of_squared(self):
        data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000]
        for x in data:
            self.stats.add(x)

        # The sum of squared deviations must be computed over the inliers only.
        actual_sum_of_squared = self.stats.get_sum_of_squared()
        expected_sum_of_squares = np.sum((np.asarray(self.INLIERS) - np.mean(self.INLIERS)) ** 2)
        # Floating-point results are compared approximately; exact equality
        # would tie the test to one particular order of arithmetic operations.
        self.assertAlmostEqual(actual_sum_of_squared, expected_sum_of_squares, places=5)

    def test_standard_deviation_calculation(self):
        data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000]
        for x in data:
            self.stats.add(x)

        # The (population, ddof=0) standard deviation over the inliers only.
        stddev = self.stats.standard_deviation()
        expected_stddev = np.std(self.INLIERS, ddof=0)
        self.assertAlmostEqual(stddev, expected_stddev, places=5)

    def test_empty_filtered_data(self):
        # With no data, every statistic must raise and the count must be zero.
        self.assertRaises(ValueError, self.stats.get_current_mean)
        self.assertRaises(ValueError, self.stats.get_sum_of_squared)
        self.assertRaises(ValueError, self.stats.standard_deviation)
        self.assertEqual(self.stats.get_count(), 0)

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 75cca56

Please sign in to comment.