-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added RunningStatistics to remove the extreme outliers when calculating the running mean and std
- Loading branch information
Showing
4 changed files
with
188 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import numpy as np | ||
from femr.stat_utils import OnlineStatistics | ||
|
||
|
||
class RunningStatistics(OnlineStatistics):
    """Online statistics whose first ``capacity`` samples are outlier-trimmed.

    While the warm-up buffer (an ``ExcludingOutlierOnlineStatistics``) is not
    full, samples are collected there and statistics are answered from its
    quantile-trimmed contents. On the sample that fills the buffer, the
    trimmed mean, sum of squared deviations and count are copied into the
    inherited ``current_mean``, ``variance`` and ``count`` attributes; every
    later sample is forwarded to ``OnlineStatistics.add``.

    NOTE(review): this presumes the parent class keeps the running sum of
    squared deviations in ``variance`` -- confirm against ``femr.stat_utils``.
    """

    def __init__(
        self,
        capacity: int = 100,
        lower_quantile: float = 0.05,
        upper_quantile: float = 0.950,
    ) -> None:
        """Create the tracker.

        Args:
            capacity: Number of warm-up samples buffered before switching to
                the inherited online update.
            lower_quantile: Lower quantile bound used to trim warm-up samples.
            upper_quantile: Upper quantile bound used to trim warm-up samples.
        """
        super().__init__()
        self.excluding_outlier_online_statistics = ExcludingOutlierOnlineStatistics(
            capacity=capacity, lower_quantile=lower_quantile, upper_quantile=upper_quantile
        )

    def add(self, weight: float, value: float) -> None:
        """Record one sample.

        During warm-up only ``value`` is stored; ``weight`` is used solely
        once samples are forwarded to the parent class.
        """
        warmup = self.excluding_outlier_online_statistics
        if warmup.is_full():
            super().add(weight, value)
            return

        warmup.add(value)
        if warmup.is_full():
            # The buffer has just filled: seed the inherited running state
            # from the trimmed warm-up samples.
            self.current_mean = warmup.get_current_mean()
            self.variance = warmup.get_sum_of_squared()
            self.count = warmup.get_count()

    def mean(self) -> float:
        """Return the current mean.

        Raises:
            ValueError: If no sample has been added yet.
        """
        warmup = self.excluding_outlier_online_statistics
        if warmup.is_full():
            return super().mean()
        # Fix: this branch previously evaluated the mean but never returned
        # it, so callers received None during warm-up.
        return warmup.get_current_mean()

    def standard_deviation(self) -> float:
        """Return the current standard deviation.

        Raises:
            ValueError: If no sample has been added yet.
        """
        warmup = self.excluding_outlier_online_statistics
        if warmup.is_full():
            return super().standard_deviation()
        return warmup.standard_deviation()
|
||
|
||
class ExcludingOutlierOnlineStatistics:
    """Bounded sample buffer reporting statistics with outliers trimmed away.

    At most ``capacity`` raw values are stored. Statistics are computed over
    the values that fall inside the closed interval spanned by the
    ``lower_quantile`` and ``upper_quantile`` of the raw values. The filtered
    view is rebuilt lazily, only after new values have been added.
    """

    _EMPTY_MESSAGE = "There is no value"

    def __init__(
        self,
        capacity: int = 100,
        lower_quantile: float = 0.05,
        upper_quantile: float = 0.950,
    ) -> None:
        """Create an empty buffer.

        Args:
            capacity: Maximum number of raw values that can be stored.
            lower_quantile: Quantile in [0, 1] below which values are dropped.
            upper_quantile: Quantile in [0, 1] above which values are dropped.

        Raises:
            ValueError: If the quantiles are not ordered within [0, 1].
        """
        # Fail early: an inverted or out-of-range pair would otherwise only
        # surface later as an empty filtered set or a NumPy error.
        if not 0.0 <= lower_quantile <= upper_quantile <= 1.0:
            raise ValueError(
                "Quantiles must satisfy 0 <= lower_quantile <= upper_quantile <= 1, "
                f"got lower_quantile={lower_quantile} and upper_quantile={upper_quantile}"
            )
        self.lower_quantile = lower_quantile
        self.upper_quantile = upper_quantile
        self.capacity = capacity
        self.raw_data = []
        self.filtered_data = []
        # True when filtered_data reflects the current raw_data.
        self.updated = False

    def reset(self) -> None:
        """Discard all stored values."""
        self.raw_data.clear()
        self.filtered_data.clear()
        self.updated = False

    def is_full(self) -> bool:
        """Return True once ``capacity`` raw values have been stored."""
        return len(self.raw_data) >= self.capacity

    def add(self, value: float) -> None:
        """Store one raw value.

        Raises:
            ValueError: If the buffer already holds ``capacity`` values.
        """
        if self.is_full():
            raise ValueError(f"The capacity of the underlying data is full at {self.capacity}")
        self.raw_data.append(value)
        # The filtered view is now stale and must be rebuilt before use.
        self.updated = False

    def get_count(self) -> int:
        """Return how many values survive the outlier filter."""
        self.update_remove_outliers()
        return len(self.filtered_data)

    def get_current_mean(self) -> float:
        """Return the mean of the filtered values.

        Raises:
            ValueError: If there are no filtered values.
        """
        return np.mean(self._require_filtered())

    def get_sum_of_squared(self) -> float:
        """Return the sum of squared deviations of the filtered values from their mean.

        Raises:
            ValueError: If there are no filtered values.
        """
        kept = self._require_filtered()
        deviations = np.asarray(kept) - np.mean(kept)
        return np.sum(deviations ** 2)

    def standard_deviation(self) -> float:
        """Return the population standard deviation (``np.std`` default) of the filtered values.

        Raises:
            ValueError: If there are no filtered values.
        """
        return np.std(self._require_filtered())

    def update_remove_outliers(
        self,
    ) -> None:
        """Rebuild ``filtered_data`` from ``raw_data`` if it is stale."""
        if self.updated or not self.raw_data:
            return
        self.updated = True
        # One pass computes both bounds.
        lower_bound, upper_bound = np.quantile(self.raw_data, [self.lower_quantile, self.upper_quantile])
        self.filtered_data = [x for x in self.raw_data if lower_bound <= x <= upper_bound]

    def _require_filtered(self) -> list:
        """Refresh the filtered view and return it, raising if it is empty."""
        self.update_remove_outliers()
        if not self.filtered_data:
            raise ValueError(self._EMPTY_MESSAGE)
        return self.filtered_data
Empty file.
85 changes: 85 additions & 0 deletions
85
tests/unit_tests/utils/excluding_outliers_online_statistics_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from cehrbert.utils.stat_utils import ExcludingOutlierOnlineStatistics # Replace with the actual module name | ||
|
||
|
||
class TestExcludingOutlierOnlineStatistics(unittest.TestCase):
    """Unit tests for ``ExcludingOutlierOnlineStatistics`` using a 10-slot buffer.

    ``setUp`` builds a fresh instance for every test, so no test needs to
    reset the buffer when it finishes.
    """

    # Values expected to survive the 5%-95% quantile filter in the tests below.
    INLIERS = [10, 12, 13, 14, 15, 16, 17, 18]

    def setUp(self):
        # capacity=10 so that each ten-value data set exactly fills the buffer.
        self.stats = ExcludingOutlierOnlineStatistics(capacity=10, lower_quantile=0.05, upper_quantile=0.95)

    def _fill(self, data):
        """Add every value of ``data`` to the buffer under test."""
        for value in data:
            self.stats.add(value)

    def test_add_data_within_capacity(self):
        # Adding exactly `capacity` values must succeed and keep all of them.
        self._fill(range(10))
        self.assertEqual(len(self.stats.raw_data), 10)

    def test_add_data_beyond_capacity(self):
        # One value more than the capacity must be rejected.
        self._fill(range(10))
        with self.assertRaises(ValueError):
            self.stats.add(11)

    def test_remove_outliers(self):
        self._fill([10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000])

        # Trigger outlier removal
        self.stats.update_remove_outliers()

        # -1000 and 99999 fall outside the quantile bounds and are dropped.
        self.assertListEqual(self.stats.filtered_data, self.INLIERS)

    def test_mean_calculation(self):
        # The mean must ignore the two extreme values.
        self._fill([10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000])

        mean = self.stats.get_current_mean()
        self.assertAlmostEqual(mean, np.mean(self.INLIERS), places=5)

    def test_get_sum_of_squared(self):
        # The sum of squared deviations must ignore the two extreme values.
        self._fill([10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000])

        actual_sum_of_squared = self.stats.get_sum_of_squared()
        expected_sum_of_squares = np.sum((np.asarray(self.INLIERS) - np.mean(self.INLIERS)) ** 2)
        # Compare with a tolerance: exact float equality would tie the test
        # to one particular summation order.
        self.assertAlmostEqual(actual_sum_of_squared, expected_sum_of_squares, places=5)

    def test_standard_deviation_calculation(self):
        # The standard deviation must ignore the two extreme values.
        self._fill([10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000])

        stddev = self.stats.standard_deviation()
        expected_stddev = np.std(self.INLIERS, ddof=0)
        self.assertAlmostEqual(stddev, expected_stddev, places=5)

    def test_empty_filtered_data(self):
        # With no data, every statistic must raise and the count must be zero.
        self.assertRaises(ValueError, self.stats.get_current_mean)
        self.assertRaises(ValueError, self.stats.get_sum_of_squared)
        self.assertRaises(ValueError, self.stats.standard_deviation)
        self.assertEqual(self.stats.get_count(), 0)
|
||
|
||
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()