Skip to content

Commit

Permalink
added featurization options
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Pimpalkhare authored and Nikhil Pimpalkhare committed Oct 13, 2020
1 parent 918ad9b commit 55ad970
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
11 changes: 10 additions & 1 deletion bin/medley
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def main():
default=10
)

global_parser.add_argument(
"--feature_setting",
"-f",
help="choose how queries are featurized",
type=str,
choices=["bow", "probes", "both"],
default="both"
)

args = global_parser.parse_args()

if not args.output.endswith(".csv"):
Expand Down Expand Up @@ -174,7 +183,7 @@ def main():
with open(args.load_classifier, "rb") as f:
classifier = dill.load(f)

execute(problems, args.output, classifier, timeout_manager, args.timeout)
execute(problems, args.output, classifier, timeout_manager, args.timeout, args.feature_setting)

if args.save_classifier:
classifier.save(args.save_classifier)
Expand Down
1 change: 0 additions & 1 deletion medleysolver/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(self, epsilon, decay):
self.epsilon = epsilon
self.decay = decay
self.counter = 0
self.k = k

def get_ordering(self, point, count):
if np.random.rand() >= self.epsilon * (self.decay ** count) and self.solved:
Expand Down
18 changes: 11 additions & 7 deletions medleysolver/compute_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,23 @@ def get_syntactic_count_features(file_path):
return features

cache = {}
def get_features(file_path,logic="",track=""):
def get_features(file_path, feature_setting, logic="",track=""):
if file_path not in cache:
cache[file_path] = {}
if logic not in cache[file_path]:
cache[file_path][logic] = {}
if track not in cache[file_path][logic]:
cache[file_path][logic][track] = {}

features = get_syntactic_count_features(file_path)
g = z3.Goal()
g.add(z3.parse_smt2_file(file_path))
results = [z3.Probe(x)(g) for x in PROBES]
features = features + results
if feature_setting == "bow":
features = get_syntactic_count_features(file_path)
elif feature_setting == "probes":
g = z3.Goal()
g.add(z3.parse_smt2_file(file_path))
features = [z3.Probe(x)(g) for x in PROBES]
else:
g = z3.Goal()
g.add(z3.parse_smt2_file(file_path))
features = get_syntactic_count_features(file_path) + [z3.Probe(x)(g) for x in PROBES]

cache[file_path][logic][track] = features
return features
Expand Down
4 changes: 2 additions & 2 deletions medleysolver/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from medleysolver.distributions import ExponentialDist
from medleysolver.dispatch import run_problem

def execute(problems, output, classifier, time_manager, timeout):
def execute(problems, output, classifier, time_manager, timeout, feature_setting):
mean = 0
writer = csv.writer(open(output, 'w'))

for c, prob in tqdm.tqdm(enumerate(problems, 1)):
start = time.time()
point = np.array(get_features(prob))
point = np.array(get_features(prob, feature_setting))
#normalizing point
mean = (c - 1) / c * mean + 1 / c * point
point = point / (mean+1e-9)
Expand Down

0 comments on commit 55ad970

Please sign in to comment.