Skip to content

Commit

Permalink
contextual time
Browse files Browse the repository at this point in the history
  • Loading branch information
FedericoAureliano committed Oct 13, 2020
1 parent de1d670 commit 2cc4482
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 18 deletions.
8 changes: 6 additions & 2 deletions bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import glob
import dill
from medleysolver.runner import execute
from medleysolver.constants import SOLVERS
from medleysolver.timers import Exponential, Constant
from medleysolver.timers import Exponential, Constant, NearestExponential
from medleysolver.classifiers import *

def main():
Expand Down Expand Up @@ -61,7 +61,7 @@ def main():
help="choose how timeout is distributed amongst solvers",
type=str,
default="expo",
choices=["expo", "const"]
choices=["expo", "const", "nearest"]
)

global_parser.add_argument(
Expand Down Expand Up @@ -163,6 +163,10 @@ def main():
if not args.set_const:
args.set_const = args.timeout // len(SOLVERS)
timeout_manager = Constant(args.set_const)
elif args.timeout_manager == "nearest":
if not args.set_lambda:
args.set_lambda = 1 / (args.timeout / len(SOLVERS))
timeout_manager = NearestExponential(args.set_lambda, args.confidence)
else:
raise RuntimeError("timeout_manager not properly set")

Expand Down
26 changes: 17 additions & 9 deletions medleysolver/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class ClassifierInterface(object):
def get_ordering(self, point, count):
raise NotImplementedError

def get_nearby_times(self, point, count):
return []

def update(self, solved_prob, rewards):
raise NotImplementedError

Expand All @@ -21,7 +24,7 @@ def get_ordering(self, point, count):
order = list(SOLVERS.keys())
random.shuffle(order)
return order

def update(self, solved_prob, rewards):
return

Expand Down Expand Up @@ -205,15 +208,20 @@ def __init__(self, k, epsilon, decay):
self.solved = []
self.counter = 0

def get_nearby_times(self, point, count):
positions = sorted(self.solved, key=lambda entry: np.linalg.norm(entry.datapoint - point))[:self.k]
positions = [(x.solve_method, x.time) for x in positions]
return positions

def get_ordering(self, point, count):
if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved:
candidates = sorted(self.solved, key=lambda entry: np.linalg.norm(entry.datapoint - point))[:self.k]
methods = [x.solve_method for x in candidates]
ss = list(SOLVERS.keys())
random.shuffle(ss)
order = sorted(ss, key= lambda x: methods.count(x))
else:
order = Random.get_ordering(self, point, count)
# if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved:
candidates = sorted(self.solved, key=lambda entry: np.linalg.norm(entry.datapoint - point))[:self.k]
methods = [x.solve_method for x in candidates]
ss = list(SOLVERS.keys())
random.shuffle(ss)
order = sorted(ss, key= lambda x: methods.count(x), reverse=True)
# else:
# order = Random.get_ordering(self, point, count)

return list(unique_everseen(order))

Expand Down
7 changes: 4 additions & 3 deletions medleysolver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,23 @@ def execute(problems, output, classifier, time_manager, timeout, extra_time_to_f
point = point / (mean+1e-9)

order = classifier.get_ordering(point, c)
times = classifier.get_nearby_times(point, c)
end = time.time()

solver, elapsed, result, rewards, time_spent = apply_ordering(prob, order, timeout - (end - start), time_manager, extra_time_to_first)
solver, elapsed, result, rewards, time_spent = apply_ordering(prob, order, timeout - (end - start), time_manager, extra_time_to_first, times)
solved_prob = Solved_Problem(prob, point, solver, elapsed + (end - start), result, order, time_spent)

classifier.update(solved_prob, rewards)

writer.writerow(solved_prob)


def apply_ordering(problem, order, timeout, time_manager, extra_time_to_first):
def apply_ordering(problem, order, timeout, time_manager, extra_time_to_first, times):
elapsed = 0
rewards = [-1 for _ in SOLVERS] # negative rewards should be ignored.
time_spent = []

budgets = [int(time_manager.get_timeout(solver))+1 for solver in order]
budgets = [int(time_manager.get_timeout(solver, times))+1 for solver in order]

for i in range(len(budgets)):
budgets[i] = min(budgets[i], int(timeout - sum(budgets[:i])))
Expand Down
23 changes: 19 additions & 4 deletions medleysolver/timers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
from medleysolver.constants import SOLVERS

class TimerInterface(object):
def get_timeout(self, solver):
def get_timeout(self, solver, position):
raise NotImplementedError

def update(self, solver, time, success):
def update(self, solver, time, success, error):
raise NotImplementedError

class Constant(TimerInterface):
def __init__(self, const):
self.const = const

def get_timeout(self, solver):
def get_timeout(self, solver, position):
return self.const

def update(self, solver, time, timeout, success, error):
Expand All @@ -22,7 +22,7 @@ class Exponential(TimerInterface):
def __init__(self, init_lambda, confidence):
self.timers = {solver:ExponentialDist(init_lambda, confidence) for solver in SOLVERS}

def get_timeout(self, solver):
def get_timeout(self, solver, position):
return self.timers[solver].get_cutoff()

def update(self, solver, time, timeout, success, error):
Expand All @@ -35,3 +35,18 @@ def update(self, solver, time, timeout, success, error):
else:
self.timers[solver].add_timeout()

class NearestExponential(TimerInterface):
def __init__(self, init_lambda, confidence):
self.init_lambda = init_lambda
self.confidence = confidence

def get_timeout(self, solver, times):
# want time based on times for same solver at nearby points
timer = ExponentialDist(self.init_lambda, self.confidence)
for (s, t) in times:
if s == solver:
timer.add_sample(t)
return timer.get_cutoff()

def update(self, solver, time, timeout, success, error):
pass

0 comments on commit 2cc4482

Please sign in to comment.