Skip to content

Commit

Permalink
add epsilon greedy bandit
Browse files Browse the repository at this point in the history
  • Loading branch information
polgreen committed Mar 4, 2021
1 parent cb9f945 commit d2c975c
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
4 changes: 3 additions & 1 deletion bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main():
help="select mechanism for choosing orderings of solvers",
type=str,
default="neighbor",
choices=["neighbor", "random", "MLP", "thompson", "linear", "preset", "exp3", "knearest", "perfect"]
choices=["neighbor", "random", "MLP", "thompson", "linear", "preset", "exp3", "knearest", "perfect", "greedy"]
)

global_parser.add_argument(
Expand Down Expand Up @@ -224,6 +224,8 @@ def main():
classifier = KNearest(args.k, args.epsilon, args.epsilon_decay, args.time_k)
elif args.classifier == "perfect":
classifier = PerfectSelector(args.time_k)
elif args.classifier == "greedy":
classifier = EpsilonGreedy(args.time_k, args.epsilon)
else:
raise RuntimeError("classifier not properly set")

Expand Down
42 changes: 42 additions & 0 deletions medleysolver/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,48 @@ def update(self, solved_prob, rewards):
if is_solved(solved_prob.result):
self.solved.append(solved_prob)

class EpsilonGreedy(ClassifierInterface):
    """Epsilon-greedy bandit that orders the solvers in ``SOLVERS``.

    One running mean reward is kept per solver, indexed by the solver's
    position in ``SOLVERS``. With probability ``epsilon`` the ordering is a
    random shuffle (exploration); otherwise solvers are ordered by
    decreasing mean reward (exploitation).
    """

    def __init__(self, time_k, epsilon):
        """Create the bandit.

        Args:
            time_k: forwarded unchanged to ``ClassifierInterface.__init__``.
            epsilon: probability of exploring (random ordering) on a call to
                ``get_ordering``; passed to ``np.random.binomial`` as ``p``.
        """
        self.epsilon = epsilon
        n_solvers = len(SOLVERS)
        self.counts = [0] * n_solvers    # number of non-negative rewards seen
        self.values = [0.0] * n_solvers  # running mean reward per solver
        self.totals = [0.0] * n_solvers  # summed reward per solver
        self.initialized = False
        super(EpsilonGreedy, self).__init__(time_k)

    def initialize(self):
        """No-op: the bandit needs no setup beyond ``__init__``."""
        return

    def get_ordering(self, point, count, problem):
        """Return the solver names in the order they should be tried.

        ``point``, ``count`` and ``problem`` are accepted for interface
        compatibility and are not used by this classifier.

        Returns:
            A new list containing every key of ``SOLVERS`` exactly once.
        """
        order = list(SOLVERS.keys())
        explore = np.random.binomial(1, self.epsilon) == 1
        # Until some solver has earned a positive mean reward there is
        # nothing to exploit, so fall back to a random ordering.
        has_estimates = sum(self.values) > 0

        if explore or not has_estimates:
            # Explore: try the solvers in a uniformly random order.
            np.random.shuffle(order)
            return order

        # Exploit: best mean reward first. ``sorted`` is stable, so ties
        # keep their relative order from ``SOLVERS``.
        ranked = sorted(
            range(len(order)),
            key=lambda idx: self.values[idx],
            reverse=True,
        )
        return [order[idx] for idx in ranked]

    def update(self, solved_prob, rewards):
        """Fold the observed rewards into the per-solver running means.

        Args:
            solved_prob: the problem just attempted; appended to
                ``self.solved`` when ``is_solved(solved_prob.result)`` holds.
            rewards: iterable of numeric rewards where position ``i`` refers
                to the ``i``-th solver of ``SOLVERS``. Negative entries are
                skipped and leave that solver's statistics untouched
                (presumably they mark solvers that were not run -- confirm
                against the caller).
        """
        for idx, reward in enumerate(rewards):
            if reward >= 0:
                self.totals[idx] += reward
                self.counts[idx] += 1
                self.values[idx] = self.totals[idx] / self.counts[idx]
        if is_solved(solved_prob.result):
            self.solved.append(solved_prob)



class LinearBandit(ClassifierInterface):
def __init__(self, time_k, alpha=2.358):
self.initialized = False
Expand Down

0 comments on commit d2c975c

Please sign in to comment.