Skip to content

Commit

Permalink
completed sgd impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Pimpalkhare authored and Nikhil Pimpalkhare committed Oct 15, 2020
1 parent e706737 commit 5f5d30e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
6 changes: 4 additions & 2 deletions bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import glob
import dill
from medleysolver.runner import execute
from medleysolver.constants import SOLVERS
from medleysolver.timers import Exponential, Constant, NearestExponential, PerfectTimer
from medleysolver.timers import Exponential, Constant, NearestExponential, PerfectTimer, SGD
from medleysolver.classifiers import *

def main():
Expand Down Expand Up @@ -69,7 +69,7 @@ def main():
help="choose how timeout is distributed amongst solvers",
type=str,
default="expo",
choices=["expo", "const", "nearest", "perfect"]
choices=["expo", "const", "nearest", "perfect", "sgd"]
)

global_parser.add_argument(
Expand Down Expand Up @@ -201,6 +201,8 @@ def main():
timeout_manager = NearestExponential(args.set_lambda, args.confidence, args.timeout)
elif args.timeout_manager == "perfect":
timeout_manager = PerfectTimer()
elif args.timeout_manager == "sgd":
timeout_manager = SGD()
else:
raise RuntimeError("timeout_manager not properly set")

Expand Down
21 changes: 12 additions & 9 deletions medleysolver/timers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from medleysolver.constants import SOLVERS, ERROR_RESULT, SAT_RESULT, UNSAT_RESULT
from sklearn.linear_model import SGDRegressor
from medleysolver.dispatch import output2result

import numpy as np
import csv

class TimerInterface(object):
Expand Down Expand Up @@ -62,21 +62,24 @@ def update(self, solver, time, timeout, success, error, point):
self.naughtylist.add(solver)

class SGD(TimerInterface):
def __init__(self, init_lambda, confidence, T):
def __init__(self):
self.fitted = [False for _ in SOLVERS]
self.models = [SGDRegressor() for _ in SOLVERS]
self.solvers_to_i = {s: i for i, s in enumerate(list(SOLVERS.keys()))}

def get_timeout(self, solver, times, problem, point):
if not self.fitted[solver]: return 60
clf = self.models[solver]
point = point.reshape(1, -1)
sindex = self.solvers_to_i[solver]
if not self.fitted[sindex]: return 60
clf = self.models[sindex]
return clf.predict(point)

def update(self, solver, time, timeout, success, error, point):
clf = self.models[solver]
if self.fitted[solver]:
clf.partial_fit(point, time)
else:
clf.fit(point, time)
point = point.reshape(1, -1)
time = np.array([time])
sindex = self.solvers_to_i[solver]
clf = self.models[sindex]
clf.partial_fit(point, time)

class PerfectTimer(TimerInterface):
def get_timeout(self, solver, position, problem, point):
Expand Down

0 comments on commit 5f5d30e

Please sign in to comment.