Skip to content

Commit

Permalink
bringing in knearest
Browse files Browse the repository at this point in the history
  • Loading branch information
FedericoAureliano committed Oct 13, 2020
2 parents 5df67d7 + 1c3a06f commit 754b558
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 13 deletions.
11 changes: 10 additions & 1 deletion bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main():
help="select mechanism for choosing orderings of solvers",
type=str,
default="neighbor",
choices=["neighbor", "random", "MLP", "thompson", "linear", "preset", "exp3"]
choices=["neighbor", "random", "MLP", "thompson", "linear", "preset", "exp3", "knearest"]
)

global_parser.add_argument(
Expand Down Expand Up @@ -135,6 +135,13 @@ def main():
type=str
)

global_parser.add_argument(
"--k",
help="set k value for knearest neighbor",
type=int,
default=10
)

args = global_parser.parse_args()

if not args.output.endswith(".csv"):
Expand Down Expand Up @@ -175,6 +182,8 @@ def main():
timeout_manager = Constant(args.timeout)
elif args.classifier == "exp3":
classifier = Exp3(0.07)
elif args.classifier == "knearest":
classifier = KNearest(args.k, args.epsilon, args.epsilon_decay)
else:
raise RuntimeError("classifier not properly set")

Expand Down
38 changes: 29 additions & 9 deletions medleysolver/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, epsilon, decay, kind):
self.decay = decay
self.counter = 0
self.kind = kind
self.k = k

def get_ordering(self, point, count):
if self.kind == "greedy":
Expand Down Expand Up @@ -124,14 +125,14 @@ def __init__(self, kind):

def get_ordering(self, point, count):
if self.kind == "single":
choices = self.dist.get_choice(self.kind)
order = [[list(SOLVERS.keys())[i] for i in choices][0]]
t_order = self.dist.get_ordering()
order = [[list(SOLVERS.keys())[int(choice)] for choice in t_order][0]]
remaining = [x for x in SOLVERS.keys() if x not in order]
random.shuffle(remaining)
order = order + remaining
else:
choices = self.dist.get_choice(self.kind)
order = [list(SOLVERS.keys())[i] for i in choices]
t_order = self.dist.get_ordering()
order = [list(SOLVERS.keys())[int(choice)] for choice in t_order]
return list(unique_everseen(order))

def update(self, solved_prob, rewards):
Expand Down Expand Up @@ -168,12 +169,9 @@ def get_ordering(self, point, count):
for i in range(len(SOLVERS))]

ps = [thetas[i].T @ point + beta.T @ point + self.alpha * np.sqrt(sigmas[i]) for i in range(len(SOLVERS))]
choice = np.random.choice(np.flatnonzero(np.isclose(ps, max(ps)))) #running argmax while arbitrarily breaking ties

order = [list(SOLVERS.keys())[int(choice)]]
remaining = [x for x in SOLVERS.keys() if x not in order]
random.shuffle(remaining)
order = order + remaining
i_order = sorted(random.shuffle(list(range(len(ps)))), key=lambda x: -1 * ps[x])
order = [list(SOLVERS.keys())[int(choice)] for choice in i_order]
return list(unique_everseen(order))

def update(self, solved_prob, rewards):
Expand All @@ -198,3 +196,25 @@ def get_ordering(self, point, count):

def update(self, solved_prob, rewards):
pass

class KNearest(ClassifierInterface):
def __init__(self, k, epsilon, decay):
self.k = k
self.epsilon = epsilon
self.decay = decay
self.solved = []
self.counter = 0

def get_ordering(self, point, count):
if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved:
candidates = sorted(self.solved, key=lambda entry: np.linalg.norm(entry.datapoint - point))[:self.k]
methods = [x.solve_method for x in candidates]
order = sorted(random.shuffle(list(SOLVERS.keys())), key= lambda x: methods.count(x))
else:
order = Random.get_ordering(self, point, count)

return order
def update(self, solved_prob, rewards):
#TODO: Implement pruning
if is_solved(solved_prob.result):
self.solved.append(solved_prob)
4 changes: 2 additions & 2 deletions medleysolver/compute_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_features(file_path,logic="",track=""):
g = z3.Goal()
g.add(z3.parse_smt2_file(file_path))
results = [z3.Probe(x)(g) for x in PROBES]
features = results
features = features + results

cache[file_path][logic][track] = features
return features
Expand All @@ -75,4 +75,4 @@ def get_check_sat(file_path):
line = line[:line.find(';')]
ret += line.count('check-sat')
cached_checksats[file_path] = ret
return ret
return ret
5 changes: 5 additions & 0 deletions medleysolver/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def get_choice(self, kind="full"):
i = sorted(range(self.n), key=lambda x: samples[x], reverse=True)
return i

def get_ordering(self):
samples = [np.random.beta(self._as[x], self._bs[x]) for x in range(self.n)]
order = sorted(list(range(self.n)), key=lambda x: -1 * samples[x])
return order

def update(self, choice, reward):
"""
didSolve (bool): whether or not previous choice
Expand Down
2 changes: 1 addition & 1 deletion medleysolver/timers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, const):
def get_timeout(self, solver):
return self.const

def update(self, solver, time, success):
def update(self, solver, time, timeout, success, error):
pass

class Exponential(TimerInterface):
Expand Down

0 comments on commit 754b558

Please sign in to comment.