From 55ad970faf383074429401308605d9187a87d4cd Mon Sep 17 00:00:00 2001 From: Nikhil Pimpalkhare Date: Tue, 13 Oct 2020 00:22:33 -0700 Subject: [PATCH] added featurization options --- bin/medley | 11 ++++++++++- medleysolver/classifiers.py | 1 - medleysolver/compute_features.py | 18 +++++++++++------- medleysolver/runner.py | 4 ++-- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/bin/medley b/bin/medley index 6c3a5b0..054e2c9 100755 --- a/bin/medley +++ b/bin/medley @@ -125,6 +125,15 @@ def main(): default=10 ) + global_parser.add_argument( + "--feature_setting", + "-f", + help="choose how queries are featurized", + type=str, + choices=["bow", "probes", "both"], + default="both" + ) + args = global_parser.parse_args() if not args.output.endswith(".csv"): @@ -174,7 +183,7 @@ def main(): with open(args.load_classifier, "rb") as f: classifier = dill.load(f) - execute(problems, args.output, classifier, timeout_manager, args.timeout) + execute(problems, args.output, classifier, timeout_manager, args.timeout, args.feature_setting) if args.save_classifier: classifier.save(args.save_classifier) diff --git a/medleysolver/classifiers.py b/medleysolver/classifiers.py index 2fc6bf4..fc2a948 100644 --- a/medleysolver/classifiers.py +++ b/medleysolver/classifiers.py @@ -31,7 +31,6 @@ def __init__(self, epsilon, decay): self.epsilon = epsilon self.decay = decay self.counter = 0 - self.k = k def get_ordering(self, point, count): if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved: diff --git a/medleysolver/compute_features.py b/medleysolver/compute_features.py index 5f0293b..c571d3d 100644 --- a/medleysolver/compute_features.py +++ b/medleysolver/compute_features.py @@ -48,19 +48,23 @@ def get_syntactic_count_features(file_path): return features cache = {} -def get_features(file_path,logic="",track=""): +def get_features(file_path, feature_setting, logic="",track=""): if file_path not in cache: cache[file_path] = {} if logic not in cache[file_path]: cache[file_path][logic] = {} if track not in cache[file_path][logic]: cache[file_path][logic][track] = {} - - features = get_syntactic_count_features(file_path) - g = z3.Goal() - g.add(z3.parse_smt2_file(file_path)) - results = [z3.Probe(x)(g) for x in PROBES] - features = features + results + if feature_setting == "bow": + features = get_syntactic_count_features(file_path) + elif feature_setting == "probes": + g = z3.Goal() + g.add(z3.parse_smt2_file(file_path)) + features = [z3.Probe(x)(g) for x in PROBES] + else: + g = z3.Goal() + g.add(z3.parse_smt2_file(file_path)) + features = get_syntactic_count_features(file_path) + [z3.Probe(x)(g) for x in PROBES] cache[file_path][logic][track] = features return features diff --git a/medleysolver/runner.py b/medleysolver/runner.py index 22c5d2c..fb4d6ce 100644 --- a/medleysolver/runner.py +++ b/medleysolver/runner.py @@ -6,13 +6,13 @@ from medleysolver.distributions import ExponentialDist from medleysolver.dispatch import run_problem -def execute(problems, output, classifier, time_manager, timeout): +def execute(problems, output, classifier, time_manager, timeout, feature_setting): mean = 0 writer = csv.writer(open(output, 'w')) for c, prob in tqdm.tqdm(enumerate(problems, 1)): start = time.time() - point = np.array(get_features(prob)) + point = np.array(get_features(prob, feature_setting)) #normalizing point mean = (c - 1) / c * mean + 1 / c * point point = point / (mean+1e-9)