Skip to content

Commit

Permalink
fix and reverting to old timeout punishment
Browse files Browse the repository at this point in the history
  • Loading branch information
FedericoAureliano committed Oct 13, 2020
1 parent 754b558 commit de1d670
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
1 change: 0 additions & 1 deletion bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def main():
)
global_parser.add_argument(
"--kind",
"-k",
help="configurations for neighbor and thompson",
type=str,
default="full",
Expand Down
13 changes: 8 additions & 5 deletions medleysolver/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def __init__(self, epsilon, decay, kind):
self.decay = decay
self.counter = 0
self.kind = kind
self.k = k

def get_ordering(self, point, count):
if self.kind == "greedy":
Expand Down Expand Up @@ -169,8 +168,9 @@ def get_ordering(self, point, count):
for i in range(len(SOLVERS))]

ps = [thetas[i].T @ point + beta.T @ point + self.alpha * np.sqrt(sigmas[i]) for i in range(len(SOLVERS))]

i_order = sorted(random.shuffle(list(range(len(ps)))), key=lambda x: -1 * ps[x])
ss = list(range(len(ps)))
random.shuffle(ss)
i_order = sorted(ss, key=lambda x: -1 * ps[x])
order = [list(SOLVERS.keys())[int(choice)] for choice in i_order]
return list(unique_everseen(order))

Expand Down Expand Up @@ -209,11 +209,14 @@ def get_ordering(self, point, count):
if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved:
candidates = sorted(self.solved, key=lambda entry: np.linalg.norm(entry.datapoint - point))[:self.k]
methods = [x.solve_method for x in candidates]
order = sorted(random.shuffle(list(SOLVERS.keys())), key= lambda x: methods.count(x))
ss = list(SOLVERS.keys())
random.shuffle(ss)
order = sorted(ss, key= lambda x: methods.count(x))
else:
order = Random.get_ordering(self, point, count)

return order
return list(unique_everseen(order))

def update(self, solved_prob, rewards):
#TODO: Implement pruning
if is_solved(solved_prob.result):
Expand Down
7 changes: 1 addition & 6 deletions medleysolver/timers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,5 @@ def update(self, solver, time, timeout, success, error):
if error:
self.timers[solver].add_error()
else:
if time < timeout/3:
# give more time
self.timers[solver].add_timeout()
else:
# remove time (assuming 1 is small compared to timeout)
self.timers[solver].add_sample(1)
self.timers[solver].add_timeout()

0 comments on commit de1d670

Please sign in to comment.