From 877a855cc890edbc32a515f755be29ed3726e95f Mon Sep 17 00:00:00 2001 From: "Kilitcioglu, Doruk" Date: Thu, 19 Jan 2023 08:32:07 -0500 Subject: [PATCH 1/4] Fix double warm start issue Signed-off-by: Kilitcioglu, Doruk --- mabwiser/base_mab.py | 18 ++++++++------ mabwiser/greedy.py | 3 +++ mabwiser/linear.py | 3 +++ mabwiser/softmax.py | 3 +++ mabwiser/thompson.py | 3 +++ mabwiser/treebandit.py | 3 +++ mabwiser/ucb.py | 3 +++ tests/test_greedy.py | 21 ++++++++++++++++ tests/test_lingreedy.py | 31 ++++++++++++++++++++++++ tests/test_lints.py | 29 ++++++++++++++++++++++ tests/test_linucb.py | 27 ++++++++++++++++++--- tests/test_popularity.py | 22 +++++++++++++++++ tests/test_ridge.py | 33 ++++++++++++++++++++++++- tests/test_softmax.py | 25 +++++++++++++++++++ tests/test_thompson.py | 24 +++++++++++++++++++ tests/test_treebandit.py | 52 ++++++++++++++++++++++++++++++++++++++++ tests/test_ucb.py | 22 +++++++++++++++++ 17 files changed, 311 insertions(+), 11 deletions(-) diff --git a/mabwiser/base_mab.py b/mabwiser/base_mab.py index c3294a1..d41c24e 100755 --- a/mabwiser/base_mab.py +++ b/mabwiser/base_mab.py @@ -139,8 +139,10 @@ def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Union[D pass def warm_start(self, arm_to_features: Dict[Arm, List[Num]], distance_quantile: float) -> NoReturn: - self.cold_arm_to_warm_arm = self._get_cold_arm_to_warm_arm(arm_to_features, distance_quantile) - self._copy_arms(self.cold_arm_to_warm_arm) + new_cold_arm_to_warm_arm = self._get_cold_arm_to_warm_arm(self.cold_arm_to_warm_arm, arm_to_features, + distance_quantile) + self._copy_arms(new_cold_arm_to_warm_arm) + self.cold_arm_to_warm_arm = {**self.cold_arm_to_warm_arm, **new_cold_arm_to_warm_arm} @abc.abstractmethod def _copy_arms(self, cold_arm_to_warm_arm: Dict[Arm, Arm]) -> NoReturn: @@ -360,7 +362,7 @@ def _get_distance_threshold(distance_from_to: Dict[Arm, Dict[Arm, Num]], quantil return threshold - def _get_cold_arm_to_warm_arm(self, arm_to_features, 
distance_quantile): + def _get_cold_arm_to_warm_arm(self, cold_arm_to_warm_arm, arm_to_features, distance_quantile): # Calculate from-to distances between all pairs of arms based on features # and then find minimum distance (threshold) required to warm start an untrained arm @@ -368,9 +370,11 @@ def _get_cold_arm_to_warm_arm(self, arm_to_features, distance_quantile): distance_threshold = self._get_distance_threshold(distance_from_to, quantile=distance_quantile) # Cold arms - cold_arms = [arm for arm in self.arms if arm not in self.trained_arms] + cold_arms = [arm for arm in self.arms if ((arm not in self.trained_arms) and (arm not in cold_arm_to_warm_arm))] + + # New cold arm to warm arm dictionary + new_cold_arm_to_warm_arm = dict() - cold_arm_to_warm_arm = {} for cold_arm in cold_arms: # Collect distance from cold arm to warm arms @@ -387,6 +391,6 @@ def _get_cold_arm_to_warm_arm(self, arm_to_features, distance_quantile): # Warm start if closest distance lower than minimum required distance if closest_distance <= distance_threshold: - cold_arm_to_warm_arm[cold_arm] = closest_arm + new_cold_arm_to_warm_arm[cold_arm] = closest_arm - return cold_arm_to_warm_arm + return new_cold_arm_to_warm_arm diff --git a/mabwiser/greedy.py b/mabwiser/greedy.py index 2615349..2bfc64a 100755 --- a/mabwiser/greedy.py +++ b/mabwiser/greedy.py @@ -27,6 +27,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = reset(self.arm_to_count, 0) reset(self.arm_to_expectation, 0) + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + self._parallel_fit(decisions, rewards, contexts) def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = None) -> NoReturn: diff --git a/mabwiser/linear.py b/mabwiser/linear.py index d98800d..a1700f4 100755 --- a/mabwiser/linear.py +++ b/mabwiser/linear.py @@ -135,6 +135,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = for arm in self.arms: 
self.arm_to_model[arm].init(num_features=self.num_features) + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + # Perform parallel fit self._parallel_fit(decisions, rewards, contexts) diff --git a/mabwiser/softmax.py b/mabwiser/softmax.py index d74ce5d..7a7d77f 100755 --- a/mabwiser/softmax.py +++ b/mabwiser/softmax.py @@ -31,6 +31,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = reset(self.arm_to_count, 0) reset(self.arm_to_mean, 0) + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + # Calculate fit self._parallel_fit(decisions, rewards) self._expectation_operation() diff --git a/mabwiser/thompson.py b/mabwiser/thompson.py index 48454c2..c81d3f0 100755 --- a/mabwiser/thompson.py +++ b/mabwiser/thompson.py @@ -31,6 +31,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = reset(self.arm_to_success_count, 1) reset(self.arm_to_fail_count, 1) + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + # Calculate fit self._parallel_fit(decisions, rewards) diff --git a/mabwiser/treebandit.py b/mabwiser/treebandit.py index 09bd333..1c6d85c 100644 --- a/mabwiser/treebandit.py +++ b/mabwiser/treebandit.py @@ -39,6 +39,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = self.arm_to_tree = {arm: DecisionTreeRegressor(**self.tree_parameters) for arm in self.arms} self.arm_to_leaf_to_rewards = {arm: defaultdict(partial(np.ndarray, 0)) for arm in self.arms} + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + # If TS and a binarizer function is given, binarize the rewards if isinstance(self.lp, _ThompsonSampling) and self.lp.binarizer: self.lp.is_contextual_binarized = False diff --git a/mabwiser/ucb.py b/mabwiser/ucb.py index 90011e1..7658e7d 100755 --- a/mabwiser/ucb.py +++ b/mabwiser/ucb.py @@ -31,6 +31,9 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = reset(self.arm_to_mean, 0) 
reset(self.arm_to_expectation, 0) + # Reset warm started arms + self.cold_arm_to_warm_arm = dict() + # Total number of decisions self.total_count = len(decisions) diff --git a/tests/test_greedy.py b/tests/test_greedy.py index 5272853..a6e3dc5 100755 --- a/tests/test_greedy.py +++ b/tests/test_greedy.py @@ -384,6 +384,27 @@ def test_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.5, 2: 0.0, 3: 0.5}) + def test_double_warm_start(self): + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 2, 2, 2, 1, 1, 1], + rewards=[0, 0, 0, 0, 0, 0, 1, 1, 1], + learning_policy=LearningPolicy.EpsilonGreedy(epsilon=0.0), + seed=7, + num_run=1, + is_predict=False) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.5, 2: 0.0, 3: 0.0}) + + # Warm start, #3 gets warm started by #2 + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.5, 2: 0.0, 3: 0.0}) + + # Warm start again, #3 is closest to #1 but shouldn't get warm started again + mab.warm_start(arm_to_features={1: [0, 1], 2: [-1, -1], 3: [0, 1]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.5, 2: 0.0, 3: 0.0}) + def test_greedy_contexts(self): arms, mab = self.predict(arms=[1, 2, 3], decisions=[1, 1, 1, 3, 2, 2, 3, 1, 3], diff --git a/tests/test_lingreedy.py b/tests/test_lingreedy.py index bb1ae05..8281a9d 100644 --- a/tests/test_lingreedy.py +++ b/tests/test_lingreedy.py @@ -859,3 +859,34 @@ def test_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + def test_double_warm_start(self): + _, mab = 
self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 1, 2, 2, 2, 1, 2, 1], + rewards=[0, 1, 1, 0, 1, 0, 1, 1, 1, 1], + learning_policy=LearningPolicy.LinGreedy(epsilon=0.25), + context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1], [0, 0, 0, 0, 0], + [0, 1, 4, 3, 5], [0, 1, 2, 4, 5], [1, 2, 1, 1, 3], + [0, 2, 1, 0, 0]], + contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + seed=123456, + num_run=4, + is_predict=True) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.0, 2: 0.0, 3: 0.0}) + self.assertListAlmostEqual(mab._imp.arm_to_model[1].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0, 0, 0, 0, 0]) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + # Warm start again, #3 shouldn't change even though it's closer to #2 now + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) diff --git a/tests/test_lints.py b/tests/test_lints.py index f76e18b..88cbb7e 100644 --- a/tests/test_lints.py +++ b/tests/test_lints.py @@ -576,3 +576,32 @@ def test_warm_start(self): # Warm start mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + def test_double_warm_start(self): + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 1, 2, 2, 2, 1, 2, 1], + rewards=[0, 1, 1, 0, 1, 0, 1, 1, 1, 1], + learning_policy=LearningPolicy.LinTS(alpha=0.24), + 
context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1], [0, 0, 0, 0, 0], + [0, 1, 4, 3, 5], [0, 1, 2, 4, 5], [1, 2, 1, 1, 3], + [0, 2, 1, 0, 0]], + contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + seed=123456, + num_run=4, + is_predict=True) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.0, 2: 0.0, 3: 0.0}) + self.assertListAlmostEqual(mab._imp.arm_to_model[1].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0, 0, 0, 0, 0]) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + # Warm start again, #3 shouldn't change even though it's closer to #2 now + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) diff --git a/tests/test_linucb.py b/tests/test_linucb.py index 00930b6..02547ab 100755 --- a/tests/test_linucb.py +++ b/tests/test_linucb.py @@ -665,9 +665,30 @@ def test_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + def test_double_warm_start(self): + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 1, 2, 2, 2, 1, 2, 1], + rewards=[0, 1, 1, 0, 1, 0, 1, 1, 1, 1], + learning_policy=LearningPolicy.LinTS(alpha=0.24), + context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1], [0, 0, 0, 0, 0], + [0, 1, 4, 3, 5], [0, 1, 2, 4, 5], [1, 2, 1, 1, 3], + [0, 2, 1, 0, 0]], + 
contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + seed=123456, + num_run=4, + is_predict=True) + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.0, 2: 0.0, 3: 0.0}) + self.assertListAlmostEqual(mab._imp.arm_to_model[1].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0, 0, 0, 0, 0]) + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) - - - + # Warm start again, #3 shouldn't change even though it's closer to #2 now + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) diff --git a/tests/test_popularity.py b/tests/test_popularity.py index 29d9c7e..ba6da18 100644 --- a/tests/test_popularity.py +++ b/tests/test_popularity.py @@ -299,6 +299,28 @@ def test_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0, 1]}, distance_quantile=0.5) self.assertDictEqual(mab._imp.arm_to_expectation, {1: 1.0, 2: 0.0, 3: 1.0}) + def test_double_warm_start(self): + + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 2, 2, 2, 1, 1, 1], + rewards=[0, 0, 0, 0, 0, 0, 1, 1, 1], + learning_policy=LearningPolicy.Popularity(), + seed=7, + num_run=1, + is_predict=False) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 1.0, 2: 0.0, 3: 0.0}) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 1.0, 2: 0.0, 3: 1.0}) + + # Warm start again, #3 shouldn't change even though it's 
closer to #2 now + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 1.0, 2: 0.0, 3: 1.0}) + def test_popularity_contexts(self): arms, mab = self.predict(arms=[1, 2, 3], decisions=[1, 1, 1, 3, 2, 2, 3, 1, 3], diff --git a/tests/test_ridge.py b/tests/test_ridge.py index fc31b71..27c8fa9 100755 --- a/tests/test_ridge.py +++ b/tests/test_ridge.py @@ -516,4 +516,35 @@ def test_warm_start(self): # Warm start mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) - self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) \ No newline at end of file + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + def test_double_warm_start(self): + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 1, 2, 2, 2, 1, 2, 1], + rewards=[0, 1, 1, 0, 1, 0, 1, 1, 1, 1], + learning_policy=LearningPolicy.LinTS(alpha=0.24), + context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1], [0, 0, 0, 0, 0], + [0, 1, 4, 3, 5], [0, 1, 2, 4, 5], [1, 2, 1, 1, 3], + [0, 2, 1, 0, 0]], + contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + seed=123456, + num_run=4, + is_predict=True) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.0, 2: 0.0, 3: 0.0}) + self.assertListAlmostEqual(mab._imp.arm_to_model[1].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0, 0, 0, 0, 0]) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + # Warm start again, 
#3 shouldn't change even though it's closer to #2 now + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, + [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index 634c7b5..421c1db 100755 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -390,6 +390,31 @@ def test_warm_start(self): self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.38365173119055074, 2: 0.23269653761889864, 3: 0.38365173119055074}) + def test_double_warm_start(self): + + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 2, 2, 2, 1, 1, 1], + rewards=[0, 0, 0, 0, 0, 0, 1, 1, 1], + learning_policy=LearningPolicy.Softmax(tau=1), + seed=7, + num_run=1, + is_predict=False) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.45186276187760605, 2: 0.274068619061197, + 3: 0.274068619061197}) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.38365173119055074, 2: 0.23269653761889864, + 3: 0.38365173119055074}) + + # Warm start again, #3 is closest to #2 but shouldn't get warm started again + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_expectation, {1: 0.38365173119055074, 2: 0.23269653761889864, + 3: 0.38365173119055074}) + def test_softmax_contexts(self): arms, mab = self.predict(arms=[1, 2, 3], decisions=[1, 1, 1, 3, 2, 2, 3, 1, 3], diff --git a/tests/test_thompson.py b/tests/test_thompson.py index 9aa2fd3..1d75862 100755 --- a/tests/test_thompson.py +++ b/tests/test_thompson.py @@ -515,6 +515,30 @@ def test_warm_start(self): self.assertDictEqual(mab._imp.arm_to_fail_count, {1: 3, 2: 4, 3: 3}) 
self.assertDictEqual(mab._imp.arm_to_success_count, {1: 5, 2: 1, 3: 5}) + def test_double_warm_start(self): + _, mab = self.predict(arms=[1, 2, 3], + decisions=[1, 1, 1, 2, 2, 2, 1, 1, 1], + rewards=[1, 0, 0, 0, 0, 0, 1, 1, 1], + learning_policy=LearningPolicy.ThompsonSampling(), + seed=7, + num_run=1, + is_predict=False) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_fail_count, {1: 3, 2: 4, 3: 1}) + self.assertDictEqual(mab._imp.arm_to_success_count, {1: 5, 2: 1, 3: 1}) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_fail_count, {1: 3, 2: 4, 3: 3}) + self.assertDictEqual(mab._imp.arm_to_success_count, {1: 5, 2: 1, 3: 5}) + + # Warm start again, #3 is closest to #2 but shouldn't get warm started again + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_fail_count, {1: 3, 2: 4, 3: 3}) + self.assertDictEqual(mab._imp.arm_to_success_count, {1: 5, 2: 1, 3: 5}) + def test_ts_contexts(self): arms, mab = self.predict(arms=[1, 2, 3], decisions=[1, 1, 1, 3, 2, 2, 3, 1, 3], diff --git a/tests/test_treebandit.py b/tests/test_treebandit.py index 1353b25..f722054 100644 --- a/tests/test_treebandit.py +++ b/tests/test_treebandit.py @@ -386,3 +386,55 @@ def test_remove_arm(self): self.assertTrue(3 not in mab._imp.arm_to_tree) self.assertTrue(3 not in mab._imp.arm_to_leaf_to_rewards) self.assertTrue(3 not in mab._imp.lp.arms) + + def test_warm_start(self): + arms_1, mab = self.predict(arms=[1, 2, 4], + decisions=[1, 1, 1, 2, 2], + rewards=[0, 1, 1, 0, 0], + learning_policy=LearningPolicy.EpsilonGreedy(epsilon=0), + neighborhood_policy=NeighborhoodPolicy.TreeBandit(), + context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1]], + contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + 
seed=123456, + num_run=1, + is_predict=True) + + mab.add_arm(3) + + arms_2 = mab._imp.predict([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]]) + + self.assertListEqual(arms_1, [1, 1]) + self.assertListEqual(arms_1, arms_2) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1], 4: [10, 10]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.predict_expectations([[1, 1, 1, 1, 1]]), {1: 1, 2: 0, 3: 1, 4: 0}) + + def test_double_warm_start(self): + arms_1, mab = self.predict(arms=[1, 2, 4], + decisions=[1, 1, 1, 2, 2], + rewards=[0, 1, 1, 0, 0], + learning_policy=LearningPolicy.EpsilonGreedy(epsilon=0), + neighborhood_policy=NeighborhoodPolicy.TreeBandit(), + context_history=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1], [0, 0, 1, 0, 0], + [0, 2, 2, 3, 5], [1, 3, 1, 1, 1]], + contexts=[[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]], + seed=123456, + num_run=1, + is_predict=True) + + mab.add_arm(3) + + arms_2 = mab._imp.predict([[0, 1, 2, 3, 5], [1, 1, 1, 1, 1]]) + + self.assertListEqual(arms_1, [1, 1]) + self.assertListEqual(arms_1, arms_2) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1], 4: [10, 10]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.predict_expectations([[1, 1, 1, 1, 1]]), {1: 1, 2: 0, 3: 1, 4: 0}) + + # Warm start again, #3 is closest to #2 but shouldn't get warm started again + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5], 4: [10, 10]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.predict_expectations([[1, 1, 1, 1, 1]]), {1: 1, 2: 0, 3: 1, 4: 0}) diff --git a/tests/test_ucb.py b/tests/test_ucb.py index b57470b..1f8ba00 100755 --- a/tests/test_ucb.py +++ b/tests/test_ucb.py @@ -408,6 +408,28 @@ def test_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0, 0], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertDictEqual(mab._imp.arm_to_mean, {1: 0.5, 2: 0.0, 3: 0.5}) + def test_double_warm_start(self): + + _, mab = self.predict(arms=[1, 2, 3], + 
decisions=[1, 1, 1, 2, 2, 2, 1, 1, 1], + rewards=[0, 0, 0, 0, 0, 0, 1, 1, 1], + learning_policy=LearningPolicy.UCB1(1.0), + seed=7, + num_run=1, + is_predict=False) + + # Before warm start + self.assertEqual(mab._imp.trained_arms, [1, 2]) + self.assertDictEqual(mab._imp.arm_to_mean, {1: 0.5, 2: 0.0, 3: 0.0}) + + # Warm start + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0, 1]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_mean, {1: 0.5, 2: 0.0, 3: 0.5}) + + # Warm start again, #3 is closest to #2 but shouldn't get warm started again + mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) + self.assertDictEqual(mab._imp.arm_to_mean, {1: 0.5, 2: 0.0, 3: 0.5}) + def test_ucb_contexts(self): arms, mab = self.predict(arms=[1, 2, 3], decisions=[1, 1, 1, 3, 2, 2, 3, 1, 3], From 93038aa49af1eb684eacf38b4fb094e4253feed7 Mon Sep 17 00:00:00 2001 From: "Kilitcioglu, Doruk" Date: Thu, 19 Jan 2023 08:40:00 -0500 Subject: [PATCH 2/4] Remove py 3.6 from CI Signed-off-by: Kilitcioglu, Doruk --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d34b7e..08f3375 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} From 014ba9af0f596b645233c13c8edd0d6108bb00bb Mon Sep 17 00:00:00 2001 From: "Kilitcioglu, Doruk" Date: Thu, 19 Jan 2023 11:41:54 -0500 Subject: [PATCH 3/4] Add scaler fix Signed-off-by: Kilitcioglu, Doruk --- mabwiser/linear.py | 21 +++++++++++++++++++++ tests/test_ridge.py | 22 +++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/mabwiser/linear.py b/mabwiser/linear.py index a1700f4..b02df44 100755 --- 
a/mabwiser/linear.py +++ b/mabwiser/linear.py @@ -10,6 +10,26 @@ from mabwiser.base_mab import BaseMAB from mabwiser.utils import Arm, Num, argmax, _BaseRNG, create_rng +SCALER_TOLERANCE = 1e-6 + + +def fix_small_variance(scaler: StandardScaler) -> NoReturn: + """ + Set variances close to zero to be equal to one in trained standard scaler to make computations stable. + + :param scaler: the scaler to check and fix variances for + """ + if hasattr(scaler, 'scale_') and hasattr(scaler, 'var_'): + # Get a mask to pull indices where std smaller than scaler_tolerance + mask = scaler.scale_ <= SCALER_TOLERANCE + + # Fix standard deviation + scaler.scale_[mask] = 1.0e+00 + + # Fix variance accordingly. var_ is allowed to be 0 in scaler. + # This helps to track if scale_ are set as ones due to zeros in variances. + scaler.var_[mask] = 0.0e+00 + class _RidgeRegression: @@ -45,6 +65,7 @@ def fit(self, X, y): self.scaler.fit(X) else: self.scaler.partial_fit(X) + fix_small_variance(self.scaler) X = self.scaler.transform(X) # X transpose diff --git a/tests/test_ridge.py b/tests/test_ridge.py index 27c8fa9..f4ec652 100755 --- a/tests/test_ridge.py +++ b/tests/test_ridge.py @@ -6,7 +6,7 @@ from sklearn.preprocessing import StandardScaler from mabwiser.mab import LearningPolicy -from mabwiser.linear import _RidgeRegression +from mabwiser.linear import _RidgeRegression, fix_small_variance from tests.test_base import BaseTest @@ -548,3 +548,23 @@ def test_double_warm_start(self): mab.warm_start(arm_to_features={1: [0, 1], 2: [0.5, 0.5], 3: [0.5, 0.5]}, distance_quantile=0.5) self.assertListAlmostEqual(mab._imp.arm_to_model[3].beta, [0.19635284, 0.11556404, 0.57675997, 0.30597964, -0.39100933]) + + def test_fix_small_variance(self): + rng = np.random.default_rng(1234) + context = rng.random((10000, 10)) + + # Set first feature to have variance close to zero + context[0, 0] = 0.0001 + context[1:, 0] = [0] * (10000 - 1) + + scaler = StandardScaler() + scaler.fit(context) + + 
self.assertAlmostEqual(scaler.scale_[0], 9.99949999e-07) + self.assertAlmostEqual(scaler.var_[0], 9.99900000e-13) + + # Fix small variance + fix_small_variance(scaler) + + self.assertAlmostEqual(scaler.scale_[0], 1) + self.assertAlmostEqual(scaler.var_[0], 0) From 925d902fbef82dadc4dd57920b45228885260359 Mon Sep 17 00:00:00 2001 From: "Kilitcioglu, Doruk" Date: Thu, 19 Jan 2023 15:17:35 -0500 Subject: [PATCH 4/4] Bump version and docs Signed-off-by: Kilitcioglu, Doruk --- .github/workflows/ci.yml | 3 +- CHANGELOG.txt | 8 +++++ docs/.buildinfo | 2 +- docs/_static/documentation_options.js | 2 +- docs/_static/pygments.css | 34 ++++++++++---------- docs/about.html | 2 +- docs/api.html | 6 ++-- docs/contributing.html | 2 +- docs/examples.html | 2 +- docs/genindex.html | 2 +- docs/index.html | 46 +++++++++++++-------------- docs/installation.html | 2 +- docs/new_bandit.html | 2 +- docs/py-modindex.html | 2 +- docs/quick.html | 2 +- docs/search.html | 2 +- docs/searchindex.js | 2 +- mabwiser/_version.py | 2 +- 18 files changed, 66 insertions(+), 57 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 08f3375..6acf0d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,10 +10,11 @@ on: jobs: Test: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10"] + os: [ubuntu-latest, macos-latest, windows-latest] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/CHANGELOG.txt b/CHANGELOG.txt index ea072ce..e1ee9e0 100755 --- a/CHANGELOG.txt +++ b/CHANGELOG.txt @@ -2,6 +2,14 @@ MABWiser CHANGELOG ===================== +January, 19, 2023 2.5.0 +------------------------------------------------------------------------------- +major: +- Update warm start logic to only warm start an arm once + +minor: +- Implement fix for fitting scalers in Linear policies when variance is too small + March, 28, 2022 2.4.1
------------------------------------------------------------------------------- minor: diff --git a/docs/.buildinfo b/docs/.buildinfo index d9e4d7f..d399e1b 100644 --- a/docs/.buildinfo +++ b/docs/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 3d55ee819e2d2bd60e727943c52c4b22 +config: e4fe041dc60ac002d2f7d6fda6f40d56 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/_static/documentation_options.js b/docs/_static/documentation_options.js index 2e89cf7..0012242 100644 --- a/docs/_static/documentation_options.js +++ b/docs/_static/documentation_options.js @@ -1,6 +1,6 @@ var DOCUMENTATION_OPTIONS = { URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), - VERSION: '2.4.0', + VERSION: '2.5.0', LANGUAGE: 'None', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/_static/pygments.css b/docs/_static/pygments.css index 08bec68..582d5c3 100644 --- a/docs/_static/pygments.css +++ b/docs/_static/pygments.css @@ -5,22 +5,22 @@ td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5 span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } .highlight .hll { background-color: #ffffcc } .highlight { background: #f8f8f8; } -.highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ +.highlight .c { color: #408080; font-style: italic } /* Comment */ .highlight .err { border: 1px solid #FF0000 } /* Error */ .highlight .k { color: #008000; font-weight: bold } /* Keyword */ .highlight .o { color: #666666 } /* Operator */ -.highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ -.highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ -.highlight .cp { color: #9C6500 } /* Comment.Preproc */ -.highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 
-.highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ -.highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ +.highlight .ch { color: #408080; font-style: italic } /* Comment.Hashbang */ +.highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */ +.highlight .cp { color: #BC7A00 } /* Comment.Preproc */ +.highlight .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */ +.highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */ +.highlight .cs { color: #408080; font-style: italic } /* Comment.Special */ .highlight .gd { color: #A00000 } /* Generic.Deleted */ .highlight .ge { font-style: italic } /* Generic.Emph */ -.highlight .gr { color: #E40000 } /* Generic.Error */ +.highlight .gr { color: #FF0000 } /* Generic.Error */ .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ -.highlight .gi { color: #008400 } /* Generic.Inserted */ -.highlight .go { color: #717171 } /* Generic.Output */ +.highlight .gi { color: #00A000 } /* Generic.Inserted */ +.highlight .go { color: #888888 } /* Generic.Output */ .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ .highlight .gs { font-weight: bold } /* Generic.Strong */ .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ @@ -33,15 +33,15 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: .highlight .kt { color: #B00040 } /* Keyword.Type */ .highlight .m { color: #666666 } /* Literal.Number */ .highlight .s { color: #BA2121 } /* Literal.String */ -.highlight .na { color: #687822 } /* Name.Attribute */ +.highlight .na { color: #7D9029 } /* Name.Attribute */ .highlight .nb { color: #008000 } /* Name.Builtin */ .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ .highlight .no { color: #880000 } /* Name.Constant */ .highlight .nd { color: #AA22FF } /* Name.Decorator */ -.highlight .ni { color: #717171; font-weight: bold } /* 
Name.Entity */ -.highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ +.highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */ +.highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */ .highlight .nf { color: #0000FF } /* Name.Function */ -.highlight .nl { color: #767600 } /* Name.Label */ +.highlight .nl { color: #A0A000 } /* Name.Label */ .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ .highlight .nv { color: #19177C } /* Name.Variable */ @@ -58,11 +58,11 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ -.highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ +.highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */ .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ -.highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ +.highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */ .highlight .sx { color: #008000 } /* Literal.String.Other */ -.highlight .sr { color: #A45A77 } /* Literal.String.Regex */ +.highlight .sr { color: #BB6688 } /* Literal.String.Regex */ .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ .highlight .ss { color: #19177C } /* Literal.String.Symbol */ .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ diff --git a/docs/about.html b/docs/about.html index 7de1105..30eeaf1 100644 --- a/docs/about.html +++ b/docs/about.html @@ -4,7 +4,7 @@ - About Multi-Armed Bandits — MABWiser 2.4.0 documentation + About Multi-Armed Bandits — MABWiser 2.5.0 documentation