diff --git a/bin/medley b/bin/medley index b778c52..e3f39ce 100755 --- a/bin/medley +++ b/bin/medley @@ -5,7 +5,7 @@ import glob import dill from medleysolver.runner import execute from medleysolver.constants import SOLVERS -from medleysolver.timers import Exponential, Constant, NearestExponential, PerfectTimer +from medleysolver.timers import Exponential, Constant, NearestExponential, PerfectTimer, SGD from medleysolver.classifiers import * def main(): @@ -69,7 +69,7 @@ def main(): help="choose how timeout is distributed amongst solvers", type=str, default="expo", - choices=["expo", "const", "nearest", "perfect"] + choices=["expo", "const", "nearest", "perfect", "sgd"] ) global_parser.add_argument( @@ -201,6 +201,8 @@ def main(): timeout_manager = NearestExponential(args.set_lambda, args.confidence, args.timeout) elif args.timeout_manager == "perfect": timeout_manager = PerfectTimer() + elif args.timeout_manager == "sgd": + timeout_manager = SGD() else: raise RuntimeError("timeout_manager not properly set") diff --git a/medleysolver/timers.py b/medleysolver/timers.py index caa9add..2ed4f3c 100644 --- a/medleysolver/timers.py +++ b/medleysolver/timers.py @@ -2,7 +2,7 @@ from medleysolver.constants import SOLVERS, ERROR_RESULT, SAT_RESULT, UNSAT_RESULT from sklearn.linear_model import SGDRegressor from medleysolver.dispatch import output2result - +import numpy as np import csv class TimerInterface(object): @@ -62,21 +62,24 @@ def update(self, solver, time, timeout, success, error, point): self.naughtylist.add(solver) class SGD(TimerInterface): - def __init__(self, init_lambda, confidence, T): + def __init__(self): self.fitted = [False for _ in SOLVERS] self.models = [SGDRegressor() for _ in SOLVERS] + self.solvers_to_i = {s: i for i, s in enumerate(list(SOLVERS.keys()))} def get_timeout(self, solver, times, problem, point): - if not self.fitted[solver]: return 60 - clf = self.models[solver] + point = point.reshape(1, -1) + sindex = self.solvers_to_i[solver] + if not self.fitted[sindex]: return 60 + clf = self.models[sindex] return clf.predict(point) def update(self, solver, time, timeout, success, error, point): - clf = self.models[solver] - if self.fitted[solver]: - clf.partial_fit(point, time) - else: - clf.fit(point, time) + point = point.reshape(1, -1) + time = np.array([time]) + sindex = self.solvers_to_i[solver] + clf = self.models[sindex] + clf.partial_fit(point, time) class PerfectTimer(TimerInterface): def get_timeout(self, solver, position, problem, point):