Skip to content

Commit

Permalink
Merge pull request #82 from AshishPvjs/master
Browse files Browse the repository at this point in the history
Fixes #75
  • Loading branch information
skadio authored Aug 2, 2023
2 parents 3c0b5b1 + 4ddd394 commit 91d97fc
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 31 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
MABWiser CHANGELOG
=====================

August, 02, 2023 2.7.1
-------------------------------------------------------------------------------
minor:
- Implemented LearningPolicyType and NeighborhoodPolicyType to simplify input for MAB.
- Updated tests to accommodate LearningPolicyType and NeighborhoodPolicyType.

February, 07, 2023 2.7.0
-------------------------------------------------------------------------------
major:
Expand Down
2 changes: 1 addition & 1 deletion mabwiser/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

__author__ = "FMR LLC"
__email__ = "[email protected]"
__version__ = "2.7.0"
__version__ = "2.7.1"
__copyright__ = "Copyright (C), FMR LLC"
54 changes: 31 additions & 23 deletions mabwiser/mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- ``NeighborhoodPolicy``
"""

from typing import List, Union, Dict, NamedTuple, NoReturn, Callable, Optional
from typing import List, Union, Dict, NamedTuple, NoReturn, Callable, Optional, NewType

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -667,6 +667,27 @@ def _is_compatible(self, learning_policy: LearningPolicy):
LearningPolicy.ThompsonSampling))


# LearningPolicyType is the Union of all possible learning policies.
# It is declared as a plain type alias rather than via typing.NewType:
# NewType expects a subclassable base class (a Union is not one) and declares
# a distinct type, so static type checkers would reject both the definition
# and ordinary policy instances passed where this type is expected.
LearningPolicyType = Union[LearningPolicy.EpsilonGreedy,
                           LearningPolicy.Popularity,
                           LearningPolicy.Random,
                           LearningPolicy.Softmax,
                           LearningPolicy.ThompsonSampling,
                           LearningPolicy.UCB1,
                           LearningPolicy.LinGreedy,
                           LearningPolicy.LinTS,
                           LearningPolicy.LinUCB]


# NeighborhoodPolicyType is the Union of all possible neighborhood policies,
# including None for the case where no neighborhood policy is given.
# It is declared as a plain type alias rather than via typing.NewType:
# NewType expects a subclassable base class (a Union is not one) and declares
# a distinct type, so static type checkers would reject both the definition
# and ordinary values (None or a policy instance) passed as this type.
NeighborhoodPolicyType = Union[None,
                               NeighborhoodPolicy.LSHNearest,
                               NeighborhoodPolicy.Clusters,
                               NeighborhoodPolicy.KNearest,
                               NeighborhoodPolicy.Radius,
                               NeighborhoodPolicy.TreeBandit]


class MAB:
"""**MABWiser: Contextual Multi-Armed Bandit Library**
Expand All @@ -676,10 +697,10 @@ class MAB:
Attributes
----------
arms : list
The list of all of the arms available for decisions. Arms can be integers, strings, etc.
learning_policy : LearningPolicy
The list of all the arms available for decisions. Arms can be integers, strings, etc.
learning_policy : LearningPolicyType
The learning policy.
neighborhood_policy : NeighborhoodPolicy
neighborhood_policy : NeighborhoodPolicyType
The neighborhood policy.
is_contextual : bool
True if contextual policy is given, false otherwise. This is a read-only data field.
Expand All @@ -695,7 +716,7 @@ class MAB:
- “loky” used by default, can induce some communication and memory overhead when exchanging input and
output data with the worker Python processes.
- “multiprocessing” previous process-based backend based on multiprocessing.Pool. Less robust than loky.
- “threading” is a very low-overhead backend but it suffers from the Python Global Interpreter Lock if the
- “threading” is a very low-overhead backend, but it suffers from the Python Global Interpreter Lock if the
called function relies a lot on Python objects.
Default value is None. In this case the default backend selected by joblib will be used.
Expand Down Expand Up @@ -731,21 +752,8 @@ class MAB:

def __init__(self,
arms: List[Arm], # The list of arms
learning_policy: Union[LearningPolicy.EpsilonGreedy,
LearningPolicy.Popularity,
LearningPolicy.Random,
LearningPolicy.Softmax,
LearningPolicy.ThompsonSampling,
LearningPolicy.UCB1,
LearningPolicy.LinGreedy,
LearningPolicy.LinTS,
LearningPolicy.LinUCB], # The learning policy
neighborhood_policy: Union[None,
NeighborhoodPolicy.LSHNearest,
NeighborhoodPolicy.Clusters,
NeighborhoodPolicy.KNearest,
NeighborhoodPolicy.Radius,
NeighborhoodPolicy.TreeBandit] = None, # The context policy, optional
learning_policy: LearningPolicyType, # The learning policy
neighborhood_policy: NeighborhoodPolicyType = None, # The context policy, optional
seed: int = Constants.default_seed, # The random seed
n_jobs: int = 1, # Number of parallel jobs
backend: str = None # Parallel backend implementation
Expand All @@ -757,11 +765,11 @@ def __init__(self,
Parameters
----------
arms : List[Union[int, float, str]]
The list of all of the arms available for decisions.
The list of all the arms available for decisions.
Arms can be integers, strings, etc.
learning_policy : LearningPolicy
learning_policy : LearningPolicyType
The learning policy.
neighborhood_policy : NeighborhoodPolicy, optional
neighborhood_policy : NeighborhoodPolicyType, optional
The context policy. Default value is None.
seed : numbers.Rational, optional
The random seed to initialize the random number generator.
Expand Down
10 changes: 3 additions & 7 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy
from mabwiser.mab import MAB, LearningPolicy, NeighborhoodPolicy, LearningPolicyType, NeighborhoodPolicyType
from mabwiser.utils import Arm, Num


Expand Down Expand Up @@ -76,12 +76,8 @@ class BaseTest(unittest.TestCase):
def predict(arms: List[Arm],
decisions: Union[List, np.ndarray, pd.Series],
rewards: Union[List, np.ndarray, pd.Series],
learning_policy: Union[LearningPolicy.EpsilonGreedy, LearningPolicy.Popularity, LearningPolicy.Random,
LearningPolicy.Softmax, LearningPolicy.ThompsonSampling, LearningPolicy.UCB1,
LearningPolicy.LinGreedy, LearningPolicy.LinTS, LearningPolicy.LinUCB],
neighborhood_policy: Union[None, NeighborhoodPolicy.Clusters, NeighborhoodPolicy.KNearest,
NeighborhoodPolicy.LSHNearest, NeighborhoodPolicy.Radius,
NeighborhoodPolicy.TreeBandit] = None,
learning_policy: LearningPolicyType,
neighborhood_policy: NeighborhoodPolicyType = None,
context_history: Union[None, List[Num], List[List[Num]], np.ndarray, pd.DataFrame, pd.Series] = None,
contexts: Union[None, List[Num], List[List[Num]], np.ndarray, pd.DataFrame, pd.Series] = None,
seed: Optional[int] = 123456,
Expand Down

0 comments on commit 91d97fc

Please sign in to comment.