diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
index 6212a522..e06d1176 100644
--- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
+++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset.py
@@ -12,6 +12,7 @@
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
 
 CEHRBERT_COLUMNS = [
+    "person_id",
     "concept_ids",
     "ages",
     "dates",
diff --git a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
index f5cf17d0..f3377e9c 100644
--- a/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
+++ b/src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py
@@ -115,7 +115,7 @@ def __init__(self, data_args: DataTrainingArguments, is_pretraining: bool = True
 
     def remove_columns(self):
         if self._is_pretraining:
-            return ["visits", "patient_id", "birth_datetime", "index_date"]
+            return ["visits", "birth_datetime", "index_date"]
         else:
             return [
                 "visits",
diff --git a/src/cehrbert/models/hf_models/hf_cehrbert.py b/src/cehrbert/models/hf_models/hf_cehrbert.py
index 4d07361e..96e37d9d 100644
--- a/src/cehrbert/models/hf_models/hf_cehrbert.py
+++ b/src/cehrbert/models/hf_models/hf_cehrbert.py
@@ -96,7 +96,7 @@ def forward(
 
         merged = torch.where(
             concept_value_masks.to(torch.bool),
-            concept_embeddings_with_val,
+            gelu_new(concept_embeddings_with_val),
             concept_embeddings,
         )
 
@@ -104,13 +104,15 @@
 
 
 class ConceptValuePredictionLayer(nn.Module):
-    def __init__(self, embedding_size):
+    def __init__(self, embedding_size, layer_norm_eps):
         super(ConceptValuePredictionLayer, self).__init__()
         self.embedding_size = embedding_size
         self.concept_value_decoder_layer = nn.Sequential(
             nn.Linear(embedding_size, embedding_size // 2),
             gelu_new,
+            nn.LayerNorm(embedding_size // 2, eps=layer_norm_eps),
             nn.Linear(embedding_size // 2, 1),
+            gelu_new,
         )
 
     def forward(self, hidden_states: Optional[torch.FloatTensor]):
@@ -259,7 +261,7 @@ def __init__(self, config: CehrBertConfig):
         self.bert = CehrBert(config)
 
         if self.config.include_value_prediction:
-            self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size)
+            self.concept_value_decoder_layer = ConceptValuePredictionLayer(config.hidden_size, config.layer_norm_eps)
         self.cls = BertOnlyMLMHead(config)
 
         # Initialize weights and apply final processing
@@ -320,10 +322,9 @@ def forward(
         if self.config.include_value_prediction:
             mlm_masks = labels != -100
             predicted_values = self.concept_value_decoder_layer(cehrbert_output.last_hidden_state)
-            num_items = torch.sum(concept_value_masks.to(torch.float32), dim=-1) + 1e-6
-            values_ = (predicted_values.squeeze(-1) - concept_values) ** 2
-            masked_mse = torch.sum(values_ * concept_value_masks * mlm_masks, dim=-1) / num_items
-            total_loss += torch.mean(masked_mse)
+            total_loss += torch.mean(
+                (predicted_values.squeeze(-1) - concept_values) ** 2 * concept_value_masks * mlm_masks
+            )
 
         return CehrBertModelOutput(
             loss=total_loss,
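The rewritten value-prediction loss above is easiest to read in isolation: it takes one global mean over the masked squared errors instead of normalizing per sequence. A minimal, self-contained sketch of the computation (tensor names and shapes are illustrative assumptions, not taken from the codebase):

```python
import torch

# Hypothetical shapes: [batch, seq_len]. Every position counts toward the mean,
# even when masked out (its term is zero) -- that is the behavioral change from
# the old code, which divided by the per-sequence count of valued positions.
predicted_values = torch.randn(2, 4)   # squeezed output of the value decoder
concept_values = torch.randn(2, 4)     # ground-truth numeric lab values
concept_value_masks = torch.tensor([[1., 0., 1., 0.], [0., 1., 0., 0.]])
mlm_masks = torch.tensor([[1., 1., 0., 0.], [1., 1., 1., 0.]])

value_loss = torch.mean((predicted_values - concept_values) ** 2 * concept_value_masks * mlm_masks)
```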
diff --git a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
index 65c4ce61..27ab92df 100644
--- a/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
+++ b/src/cehrbert/models/hf_models/tokenization_hf_cehrbert.py
@@ -11,6 +11,7 @@
 from tokenizers.models import WordLevel
 from tokenizers.pre_tokenizers import WhitespaceSplit
 from tokenizers.trainers import WordLevelTrainer
+from tqdm import tqdm
 from transformers.tokenization_utils_base import PushToHubMixin
 
 from cehrbert.models.hf_models.tokenization_utils import agg_helper, agg_statistics, map_statistics
@@ -64,6 +65,9 @@ def __init__(
                 "unit": lab_stat["unit"],
                 "mean": lab_stat["mean"],
                 "std": lab_stat["std"],
+                "value_outlier_std": lab_stat["value_outlier_std"],
+                "lower_bound": lab_stat["lower_bound"],
+                "upper_bound": lab_stat["upper_bound"],
             }
             for lab_stat in lab_stats
         }
@@ -296,16 +300,22 @@ def batched_generator():
 
         tokenizer.train_from_iterator(generator, trainer=trainer)
 
+        map_statistics_partial = partial(
+            map_statistics,
+            capacity=data_args.offline_stats_capacity,
+            value_outlier_std=data_args.value_outlier_std,
+        )
+
         if data_args.streaming:
             parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics),
+                partial(agg_helper, map_func=map_statistics_partial),
                 batched=True,
                 batch_size=data_args.preprocessing_batch_size,
                 remove_columns=dataset.column_names,
             )
         else:
             parts = dataset.map(
-                partial(agg_helper, map_func=map_statistics),
+                partial(agg_helper, map_func=map_statistics_partial),
                 batched=True,
                 batch_size=data_args.preprocessing_batch_size,
                 remove_columns=dataset.column_names,
@@ -314,12 +324,13 @@
                 new_fingerprint="invalid",
             )
         current = None
-        for stat in parts:
+        for stat in tqdm(parts, desc="Aggregating the lab statistics"):
             fixed_stat = pickle.loads(stat["data"])
             if current is None:
                 current = fixed_stat
             else:
                 current = agg_statistics(current, fixed_stat)
+
         lab_stats = [
             {
                 "concept_id": concept_id,
@@ -327,6 +338,9 @@
                 "mean": online_stats.mean(),
                 "std": online_stats.standard_deviation(),
                 "count": online_stats.count,
+                "value_outlier_std": data_args.value_outlier_std,
+                "lower_bound": online_stats.mean() - data_args.value_outlier_std * online_stats.standard_deviation(),
+                "upper_bound": online_stats.mean() + data_args.value_outlier_std * online_stats.standard_deviation(),
             }
             for (concept_id, unit), online_stats in current["numeric_stats_by_lab"].items()
         ]
@@ -342,8 +356,13 @@ def normalize(self, concept_id, concept_value) -> float:
             mean_ = concept_value - self._lab_stat_mapping[concept_id]["mean"]
             std = self._lab_stat_mapping[concept_id]["std"]
             if std > 0:
+                value_outlier_std = self._lab_stat_mapping[concept_id]["value_outlier_std"]
                 normalized_value = mean_ / self._lab_stat_mapping[concept_id]["std"]
+                # Clip the normalized value to [-value_outlier_std, value_outlier_std] for this lab
+                normalized_value = max(-value_outlier_std, min(value_outlier_std, normalized_value))
             else:
-                normalized_value = mean_
+                # If there is no valid standard deviation,
+                # we just set the normalized value to the mean of the standard normal
+                normalized_value = 0.0
             return normalized_value
         return concept_value
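Stated plainly, the new `normalize` computes a z-score and clips it to plus or minus `value_outlier_std`. A hedged stand-alone sketch (the `lab_stat` dict mirrors the `_lab_stat_mapping` entries built above; `normalize_sketch` is a hypothetical helper, not part of the package):

```python
def normalize_sketch(value: float, lab_stat: dict) -> float:
    mean, std = lab_stat["mean"], lab_stat["std"]
    k = lab_stat["value_outlier_std"]
    if std > 0:
        z = (value - mean) / std
        # Clip the z-score so a single extreme lab value cannot dominate training
        return max(-k, min(k, z))
    # No valid std: fall back to the mean of the standard normal
    return 0.0

# With mean=100, std=10, k=3.0, a raw value of 250 maps to 3.0 rather than 15.0
print(normalize_sketch(250.0, {"mean": 100.0, "std": 10.0, "value_outlier_std": 3.0}))
```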
diff --git a/src/cehrbert/models/hf_models/tokenization_utils.py b/src/cehrbert/models/hf_models/tokenization_utils.py
index f43d8f4c..ee2515c0 100644
--- a/src/cehrbert/models/hf_models/tokenization_utils.py
+++ b/src/cehrbert/models/hf_models/tokenization_utils.py
@@ -1,9 +1,10 @@
 import collections
 import json
 import pickle
+from functools import partial
 from typing import Any, Dict
 
-from femr.stat_utils import OnlineStatistics
+from cehrbert.utils.stat_utils import TruncatedOnlineStatistics
 
 
 def load_json_file(json_file):
@@ -21,13 +22,14 @@ def agg_helper(*args, map_func):
     return {"data": [pickle.dumps(result)]}
 
 
-def map_statistics(batch: Dict[str, Any]) -> Dict[str, Any]:
+def map_statistics(batch: Dict[str, Any], capacity=100, value_outlier_std=2.0) -> Dict[str, Any]:
     if "units" in batch:
         concept_value_units = batch["units"]
     else:
         concept_value_units = [["default_unit" for _ in cons] for cons in batch["concept_ids"]]
-
-    numeric_stats_by_lab = collections.defaultdict(OnlineStatistics)
+    numeric_stats_by_lab = collections.defaultdict(
+        partial(TruncatedOnlineStatistics, capacity=capacity, value_outlier_std=value_outlier_std)
+    )
     for concept_ids, concept_values, concept_value_indicators, units in zip(
         batch["concept_ids"],
         batch["concept_values"],
diff --git a/src/cehrbert/runners/hf_runner_argument_dataclass.py b/src/cehrbert/runners/hf_runner_argument_dataclass.py
index 055fe411..1cef3b56 100644
--- a/src/cehrbert/runners/hf_runner_argument_dataclass.py
+++ b/src/cehrbert/runners/hf_runner_argument_dataclass.py
@@ -156,6 +156,16 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Indicates whether to randomly shuffle the records that have the same rank"},
     )
+    offline_stats_capacity: Optional[int] = dataclasses.field(
+        default=100,
+        metadata={
+            "help": "The minimum number of lab values to collect for the truncated offline statistics before switching to the online statistics calculation"
+        },
+    )
+    value_outlier_std: Optional[float] = dataclasses.field(
+        default=3.0,
+        metadata={"help": "The number of standard deviations from the mean beyond which lab values are treated as outliers and excluded or clipped"},
+    )
 
 
 @dataclasses.dataclass
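A quick note on the `defaultdict(partial(...))` idiom introduced in `map_statistics`: `defaultdict` needs a zero-argument factory, and `partial` freezes the configured `capacity` and `value_outlier_std` into one. An illustrative sketch (the concept ID and unit below are made-up example values):

```python
import collections
from functools import partial

from cehrbert.utils.stat_utils import TruncatedOnlineStatistics

# Freeze the two knobs into a zero-argument factory for defaultdict
factory = partial(TruncatedOnlineStatistics, capacity=100, value_outlier_std=3.0)
numeric_stats_by_lab = collections.defaultdict(factory)

# The first access for a (concept_id, unit) key builds a fresh, configured accumulator
numeric_stats_by_lab[("3004249", "mmHg")].add(1.0, 120.0)
```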
diff --git a/src/cehrbert/utils/stat_utils.py b/src/cehrbert/utils/stat_utils.py
new file mode 100644
index 00000000..d17a96ee
--- /dev/null
+++ b/src/cehrbert/utils/stat_utils.py
@@ -0,0 +1,252 @@
+import numpy as np
+import scipy.stats as stats
+from femr.stat_utils import OnlineStatistics
+
+
+class TruncatedOnlineStatistics(OnlineStatistics):
+
+    def __init__(self, capacity=100, value_outlier_std=2.0):
+        super().__init__()
+        self.is_online_update_started = False
+        self.value_outlier_std = value_outlier_std
+        self.truncated_offline_statistics = TruncatedOfflineStatistics(
+            capacity=capacity, value_outlier_std=value_outlier_std
+        )
+
+    def add(self, weight: float, value: float) -> None:
+        if self.is_online_update_started:
+            std = self.standard_deviation()
+            if (
+                self.current_mean - self.value_outlier_std * std
+                <= value
+                <= self.current_mean + self.value_outlier_std * std
+            ):
+                super().add(weight, value)
+        else:
+            self.truncated_offline_statistics.add(value)
+            if self.truncated_offline_statistics.is_full():
+                self.begin_online_stats()
+
+    def mean(self) -> float:
+        """Return the current mean."""
+        if self.is_online_update_started:
+            return super().mean()
+        else:
+            return self.truncated_offline_statistics.get_current_mean()
+
+    def standard_deviation(self) -> float:
+        """Return the current standard deviation."""
+        # If the count is zero, we don't calculate the standard deviation
+        if self.count == 0:
+            return 0.0
+        elif self.is_online_update_started:
+            return super().standard_deviation()
+        else:
+            return self.truncated_offline_statistics.get_standard_deviation()
+
+    def begin_online_stats(self):
+        # This prevents the online stats from being started twice
+        if self.is_online_update_started:
+            raise RuntimeError("The statistics have already been brought online; you can't start the online mode twice")
+
+        self.is_online_update_started = True
+        self.current_mean = self.truncated_offline_statistics.get_current_mean()
+        self.variance = self.truncated_offline_statistics.get_sum_of_squared()
+        self.count = self.truncated_offline_statistics.get_count()
+
+    def combine(self, other) -> None:
+        """
+        Combine another TruncatedOnlineStatistics object into this one.
+
+        Both objects are brought into online mode first (if they are not already),
+        since the underlying online accumulators can only be merged once their
+        offline buffers have been folded into mean/variance/count.
+
+        Args:
+            other: another TruncatedOnlineStatistics instance to merge into this one.
+        """
+        if not self.is_online_update_started:
+            self.begin_online_stats()
+        if not other.is_online_update_started:
+            other.begin_online_stats()
+        super().combine(other)
+
+
+class TruncatedOfflineStatistics:
+    """
+    Compute and maintain statistics for a dataset while excluding outliers.
+
+    Outliers are defined via a truncated normal distribution bounded at a specified
+    number of standard deviations.
+
+    This class supports offline data collection (i.e., data is accumulated until capacity is reached),
+    and outliers beyond the specified number of standard deviations are excluded before computing
+    statistics such as mean and standard deviation.
+
+    Attributes:
+    -----------
+    capacity : int
+        The maximum number of data points that can be stored.
+    value_outlier_std : float
+        The number of standard deviations used to define outliers. Data points outside this range
+        are excluded when computing statistics.
+    lower_quantile : float
+        The quantile corresponding to the lower bound for valid data, computed as the cumulative
+        distribution function (CDF) of -`value_outlier_std`.
+    upper_quantile : float
+        The quantile corresponding to the upper bound for valid data, computed as the CDF of
+        `value_outlier_std`.
+    raw_data : list
+        A list to store all incoming data points.
+    filtered_data : list
+        A list to store data points after removing outliers.
+    updated : bool
+        A flag that indicates whether the filtered data has been updated after new data points were
+        added.
+    """
+
+    def __init__(self, capacity=100, value_outlier_std=2.0):
+        """
+        Initialize the TruncatedOfflineStatistics instance with a capacity and a
+        standard-deviation threshold for outlier detection.
+
+        Parameters:
+        -----------
+        capacity : int, optional
+            The maximum number of data points to store (default is 100).
+        value_outlier_std : float, optional
+            The number of standard deviations to use for outlier detection (default is 2.0).
+        """
+        super().__init__()
+        self.lower_quantile = stats.norm.cdf(-value_outlier_std)
+        self.upper_quantile = stats.norm.cdf(value_outlier_std)
+        self.capacity = capacity
+        self.raw_data = list()
+        self.filtered_data = list()
+        self.updated = False
+
+    def is_full(self) -> bool:
+        """
+        Check whether the number of data points in the `raw_data` list has reached the capacity.
+
+        Returns:
+        --------
+        bool
+            True if the number of data points is greater than or equal to the capacity, otherwise False.
+        """
+        return len(self.raw_data) >= self.capacity
+
+    def add(self, value: float) -> None:
+        """
+        Add a new data point to the `raw_data` list if the capacity is not full.
+
+        If the capacity is reached, raises a ValueError.
+        Also marks the `updated` flag as False to indicate that the filtered data needs to be refreshed.
+
+        Parameters:
+        -----------
+        value : float
+            The new data point to be added to the dataset.
+
+        Raises:
+        -------
+        ValueError:
+            If the capacity of the underlying data is full.
+        """
+        if len(self.raw_data) < self.capacity:
+            self.raw_data.append(value)
+            # When new data is added to the raw_data array, we need to update the filtered_data later on
+            self.updated = False
+        else:
+            raise ValueError(f"The capacity of the underlying data is full at {self.capacity}")
+
+    def get_count(self) -> int:
+        """
+        Return the count of data points in `filtered_data` after removing outliers.
+
+        Returns:
+        --------
+        int
+            The number of data points that remain after outliers are filtered out.
+        """
+        self._update_filtered_data()
+        return len(self.filtered_data)
+ """ + self._update_filtered_data() + return len(self.filtered_data) + + def get_current_mean(self) -> float: + """ + Computes and returns the mean of the `filtered_data` (excluding outliers). + + Returns: + -------- + float + The mean of the filtered data. Returns 0.0 if there are no valid data points. + """ + self._update_filtered_data() + if self.filtered_data: + return np.mean(self.filtered_data) + else: + return 0.0 + + def get_sum_of_squared(self) -> float: + """ + Computes the sum of squared differences from the mean for the `filtered_data`. + + Returns: + -------- + float + The sum of squared differences from the mean for the filtered data. + Returns 0.0 if no valid data points are present. + """ + self._update_filtered_data() + if self.filtered_data: + current_mean = np.mean(self.filtered_data) + return np.sum([(x - current_mean) ** 2 for x in self.filtered_data]) + else: + return 0.0 + + def get_standard_deviation(self) -> float: + """ + Computes the standard deviation of the `filtered_data` (excluding outliers). + + Returns: + -------- + float + The standard deviation of the filtered data. + Returns 0.0 if there are no valid data points. + """ + self._update_filtered_data() + if self.filtered_data: + return np.std(self.filtered_data) + else: + return 0.0 + + def _update_filtered_data( + self, + ) -> None: + """ + Filters the `raw_data` to remove outliers based on the `value_outlier_std` threshold. + + This method is called internally before any computation of statistics to ensure + that the data being used is current and valid. + + This method uses the `lower_quantile` and `upper_quantile` to filter the data points. + """ + if not self.updated and len(self.raw_data) > 0: + # Toggle the updated variable + self.updated = True + lower_bound = np.quantile(self.raw_data, self.lower_quantile) + upper_bound = np.quantile(self.raw_data, self.upper_quantile) + # Update the filtered_data + self.filtered_data = [x for x in self.raw_data if lower_bound <= x <= upper_bound] + + def reset(self): + """ + Resets the raw and filtered data, clearing all stored data points. + + This method also resets the `updated` flag to indicate that the data needs to be re-filtered + when new data is added. 
diff --git a/tests/unit_tests/utils/__init__.py b/tests/unit_tests/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit_tests/utils/truncated_offline_statistics_test.py b/tests/unit_tests/utils/truncated_offline_statistics_test.py
new file mode 100644
index 00000000..d90ee6c5
--- /dev/null
+++ b/tests/unit_tests/utils/truncated_offline_statistics_test.py
@@ -0,0 +1,93 @@
+import unittest
+
+import numpy as np
+
+from cehrbert.utils.stat_utils import TruncatedOfflineStatistics
+
+
+class TestTruncatedOfflineStatistics(unittest.TestCase):
+
+    def setUp(self):
+        # Create an instance of TruncatedOfflineStatistics with default settings
+        self.stats = TruncatedOfflineStatistics(capacity=10, value_outlier_std=2.0)
+
+    def test_two_elements(self):
+        # With only two data points, the quantile bounds exclude both values
+        for i in range(2):
+            self.stats.add(i)
+        self.stats._update_filtered_data()
+        self.assertEqual(0, len(self.stats.filtered_data))
+
+    def test_add_data_within_capacity(self):
+        # Test adding data within the capacity
+        for i in range(10):
+            self.stats.add(i)
+        self.assertEqual(len(self.stats.raw_data), 10)
+        self.stats.reset()
+
+    def test_add_data_beyond_capacity(self):
+        # Test adding data beyond the capacity
+        for i in range(10):
+            self.stats.add(i)
+        with self.assertRaises(ValueError):
+            self.stats.add(11)
+        self.stats.reset()
+
+    def test_remove_outliers(self):
+        # Test removing outliers
+        data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000]
+        for x in data:
+            self.stats.add(x)
+
+        # Trigger outlier removal
+        self.stats._update_filtered_data()
+
+        # The expected filtered data excludes 99999 and -1000 since they are extreme values
+        expected_filtered_data = [10, 12, 13, 14, 15, 16, 17, 18]
+        self.assertListEqual(self.stats.filtered_data, expected_filtered_data)
+
+    def test_mean_calculation(self):
+        # Test the mean calculation after excluding outliers
+        data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000]
+        for x in data:
+            self.stats.add(x)
+
+        # Test mean after excluding outliers
+        mean = self.stats.get_current_mean()
+        expected_mean = np.mean([10, 12, 13, 14, 15, 16, 17, 18])
+        self.assertAlmostEqual(mean, expected_mean, places=5)
+
+    def test_get_sum_of_squared(self):
+        # Test the sum of squared differences after excluding outliers
+        data = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000]
+        for x in data:
+            self.stats.add(x)
+
+        actual_sum_of_squared = self.stats.get_sum_of_squared()
+        expected_sum_of_squares = np.sum(
+            (np.asarray([10, 12, 13, 14, 15, 16, 17, 18]) - np.mean([10, 12, 13, 14, 15, 16, 17, 18])) ** 2
+        )
+        self.assertEqual(actual_sum_of_squared, expected_sum_of_squares)
+        self.stats.reset()
+
+    def test_standard_deviation_calculation(self):
+        # Test the standard deviation after excluding outliers
+        data = [10, 12, 13, 14, 1000, 15, 16, 17, 18, -1000]
+        for x in data:
+            self.stats.add(x)
+
+        # Test standard deviation after excluding outliers
+        stddev = self.stats.get_standard_deviation()
+        expected_stddev = np.std([10, 12, 13, 14, 15, 16, 17, 18], ddof=0)
+        self.assertAlmostEqual(stddev, expected_stddev, places=5)
+        self.stats.reset()
+
+    def test_empty_filtered_data(self):
+        # Test when no data is present
+        self.assertEqual(0.0, self.stats.get_current_mean())
+        self.assertEqual(0.0, self.stats.get_standard_deviation())
+        self.assertEqual(0.0, self.stats.get_sum_of_squared())
+
+
+if __name__ == "__main__":
+    unittest.main()
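The offline tests above hinge on the quantile-based truncation performed by `_update_filtered_data`. A compact sketch of that filtering rule, using the same example data as `test_remove_outliers`:

```python
import numpy as np
import scipy.stats as stats

k = 2.0
raw = [10, 12, 13, 14, 99999, 15, 16, 17, 18, -1000]
# Map the +/- k-sigma band of a standard normal to empirical quantiles of the sample
lower_q, upper_q = stats.norm.cdf(-k), stats.norm.cdf(k)  # ~0.0228 and ~0.9772
lo, hi = np.quantile(raw, lower_q), np.quantile(raw, upper_q)
filtered = [x for x in raw if lo <= x <= hi]
print(filtered)  # 99999 and -1000 fall outside [lo, hi] and are dropped
```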
diff --git a/tests/unit_tests/utils/truncated_online_statistics_test.py b/tests/unit_tests/utils/truncated_online_statistics_test.py
new file mode 100644
index 00000000..0cd69df6
--- /dev/null
+++ b/tests/unit_tests/utils/truncated_online_statistics_test.py
@@ -0,0 +1,125 @@
+import unittest
+
+import numpy as np
+
+from cehrbert.utils.stat_utils import TruncatedOfflineStatistics, TruncatedOnlineStatistics
+
+
+class TestTruncatedOnlineStatistics(unittest.TestCase):
+
+    def test_add_data_before_online_mode(self):
+        # Set up instances of TruncatedOnlineStatistics and TruncatedOfflineStatistics for testing
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        offline_stats = TruncatedOfflineStatistics(capacity=10, value_outlier_std=2.0)
+        # Test adding data before transitioning to online mode
+        data = [1, 2, 3, 4, 5, 100, -100]
+        for x in data:
+            online_stats.add(1.0, x)
+
+        # Before switching to online, check that offline stats are still active
+        self.assertFalse(online_stats.is_online_update_started)
+        self.assertEqual(offline_stats.get_count(), 0)  # Offline mode hasn't reached full capacity yet
+
+    def test_add_data_and_switch_to_online_mode(self):
+        # Set up a TruncatedOnlineStatistics instance for testing
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        # Add enough data to trigger the switch to online mode
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        for x in data:
+            online_stats.add(1.0, x)
+
+        # Online update should have been started
+        self.assertTrue(online_stats.is_online_update_started)
+
+    def test_mean_calculation_in_online_mode(self):
+        # Set up a TruncatedOnlineStatistics instance for testing
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        # Test that the mean is calculated correctly in online mode
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        for x in data:
+            online_stats.add(1.0, x)
+
+        # Mean after adding all data points (in online mode)
+        expected_mean = np.mean(data)
+        self.assertAlmostEqual(online_stats.mean(), expected_mean, places=5)
+
+    def test_standard_deviation_calculation_in_online_mode(self):
+        # Set up a TruncatedOnlineStatistics instance for testing
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        # Test that the standard deviation is calculated correctly in online mode
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        for x in data:
+            online_stats.add(1.0, x)
+
+        # Standard deviation after adding all data points (in online mode)
+        # The truncated offline stats removes the first and the last number from the array
+        expected_stddev = np.std(data[1:-1], ddof=0)
+        self.assertAlmostEqual(online_stats.standard_deviation(), expected_stddev, places=5)
+
+    def test_online_mean_with_outliers(self):
+        online_stats = TruncatedOnlineStatistics(capacity=11, value_outlier_std=2.0)
+        # Add data with outliers, ensuring they are excluded
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1000, -1000]
+        for x in data:
+            online_stats.add(1.0, x)
+        # After excluding the outliers, calculate the mean
+        expected_mean = np.mean([1, 2, 3, 4, 5, 6, 7, 8, 9])  # Exclude outliers
+        self.assertAlmostEqual(online_stats.mean(), expected_mean, places=5)
+
+    def test_online_standard_deviation_with_outliers(self):
+        online_stats = TruncatedOnlineStatistics(capacity=11, value_outlier_std=2.0)
+        # Add data with outliers, ensuring they are excluded
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 1000, -1000]
+        for x in data:
+            online_stats.add(1.0, x)
+
+        # After excluding the outliers, calculate the standard deviation
+        expected_stddev = np.std([1, 2, 3, 4, 5, 6, 7, 8, 9], ddof=0)  # Exclude outliers
+        self.assertAlmostEqual(online_stats.standard_deviation(), expected_stddev, places=5)
+
+    def test_combining_two_truncated_online_stats(self):
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        # Combine two TruncatedOnlineStatistics objects and check the mean and standard deviation
+        data1 = [1, 2, 3, 4, 5]
+        data2 = [6, 7, 8, 9, 10]
+
+        for x in data1:
+            online_stats.add(1.0, x)
+
+        other_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        for x in data2:
+            other_stats.add(1.0, x)
+
+        online_stats.combine(other_stats)
+
+        # Check the combined mean and standard deviation; each offline buffer
+        # drops its first and last value during truncation
+        combined_data = data1[1:-1] + data2[1:-1]
+        expected_mean = np.mean(combined_data)
+        expected_stddev = np.std(combined_data, ddof=0)
+
+        self.assertAlmostEqual(online_stats.mean(), expected_mean, places=5)
+        self.assertAlmostEqual(online_stats.standard_deviation(), expected_stddev, places=5)
+
+    def test_add_outlier_beyond_capacity(self):
+        online_stats = TruncatedOnlineStatistics(capacity=10, value_outlier_std=2.0)
+        # Test that an outlier added after the capacity is reached is ignored
+        data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
+        for x in data[:-1]:  # Add data up to capacity
+            online_stats.add(1.0, x)
+        prev_mean = online_stats.mean()
+        prev_std = online_stats.standard_deviation()
+        # The last value falls outside the range (mean - 2 * std, mean + 2 * std),
+        # therefore the mean and std do not change
+        online_stats.add(1.0, data[-1])
+        self.assertEqual(prev_mean, online_stats.mean())
+        self.assertEqual(prev_std, online_stats.standard_deviation())
+
+
+if __name__ == "__main__":
+    unittest.main()
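One caveat for reviewers: the merge in `combine` ultimately delegates to femr's `OnlineStatistics.combine`, which this diff does not show. For intuition, a typical parallel merge of two Welford-style accumulators (Chan et al.) looks like the sketch below; femr's actual implementation may differ in details such as sample weighting.

```python
def merge(count_a, mean_a, m2_a, count_b, mean_b, m2_b):
    # Chan et al. pairwise update: combine counts, means, and sums of squared
    # deviations (m2) without revisiting the raw data
    count = count_a + count_b
    delta = mean_b - mean_a
    mean = mean_a + delta * count_b / count
    m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
    return count, mean, m2
```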