From 29ade86f147d0cbb20bbf7ba04a1bdcf42697eac Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Wed, 22 Jan 2025 15:10:55 -0800 Subject: [PATCH 1/4] Get started --- src/chomper.rs | 7 +++---- src/language.rs | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/chomper.rs b/src/chomper.rs index af03d50..1ddfccd 100644 --- a/src/chomper.rs +++ b/src/chomper.rs @@ -296,10 +296,9 @@ pub trait Chomper { format!( r#" (rule - (({UNIVERSAL_RELATION} {lhs})) - ((let temp (ite {cond} {rhs} {lhs})) - ({UNIVERSAL_RELATION} temp) - (union {lhs} temp)) + (({UNIVERSAL_RELATION} {lhs}) + (= (Condition {cond}) (TRUE))) + ((union {lhs} {rhs})) :ruleset cond-rewrites) "# ) diff --git a/src/language.rs b/src/language.rs index 752f323..986ba19 100644 --- a/src/language.rs +++ b/src/language.rs @@ -225,14 +225,29 @@ pub trait ChompyLanguage { (function Const ({const_type}) {name}) (function Var (String) {name}) {func_defs_str} + (function eclass ({name}) i64 :merge (min old new)) (relation universe ({name})) (relation cond-equal ({name} {name})) +(datatype Predicate + (TRUE) + (Condition {name})) + +(relation Implies (Predicate Predicate)) + + ;;; forward ruleset definitions (ruleset eclass-report) (ruleset non-cond-rewrites) (ruleset cond-rewrites) +(ruleset condition-propogation) + +(rule + ((Implies ?a ?b) + (= ?a (TRUE))) + ((union ?b (TRUE))) + :ruleset condition-propogation) ;;; a "function", more or less, that prints out each e-class and its ;;; term. From b85f9d507dbc4a6e19043dd0406e61b0dce6ea01 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Fri, 24 Jan 2025 19:38:49 -0800 Subject: [PATCH 2/4] Do a little more --- src/chomper.rs | 259 +++++++++++++++++++++++++++++++++++++++++----- src/language.rs | 11 +- tests/math/mod.rs | 3 +- 3 files changed, 239 insertions(+), 34 deletions(-) diff --git a/src/chomper.rs b/src/chomper.rs index 1ddfccd..568fe4b 100644 --- a/src/chomper.rs +++ b/src/chomper.rs @@ -1,10 +1,12 @@ -use crate::language::MathLang; +use crate::language::{mathlang_to_z3, MathLang}; use crate::PredicateInterpreter; use std::fmt::Debug; +use std::hash::Hash; use std::{fmt::Display, str::FromStr, sync::Arc}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use z3::ast::Ast; use crate::{ ite::DummySort, @@ -75,7 +77,7 @@ impl Ord for Rule { /// Chompers manage the state of the e-graph. pub trait Chomper { - type Constant: Display + Debug + Clone + PartialEq; + type Constant: Display + Debug + Clone + PartialEq + Hash + Eq; fn get_language(&self) -> Box>; fn make_pred_interpreter() -> impl PredicateInterpreter + 'static; @@ -101,6 +103,12 @@ pub trait Chomper { egraph .parse_and_run_program(None, &self.get_language().to_egglog_src()) .unwrap(); + + for implication_prog in self.predicate_implication_ruleset() { + egraph + .parse_and_run_program(None, &implication_prog.to_string()) + .unwrap(); + } egraph } @@ -146,6 +154,7 @@ pub trait Chomper { info!("running rewrites: {}", prog); let results = egraph.parse_and_run_program(None, &prog).unwrap(); + log_rewrite_stats(results); } @@ -173,6 +182,70 @@ pub trait Chomper { eclass_term_map } + fn validate_implication(&self, p1: &Sexp, p2: &Sexp) -> bool; + + /// Returns a vector of Egglog rules, i.e., Egglog programs. + fn predicate_implication_ruleset(&self) -> Vec { + fn generalize_predicate(sexp: &Sexp, vars: Vec) -> Sexp { + match sexp { + Sexp::Atom(a) => { + if vars.contains(a) { + Sexp::from_str(&format!("?{}", a)).unwrap() + } else { + sexp.clone() + } + } + Sexp::List(l) => Sexp::List( + l.iter() + .map(|s| generalize_predicate(s, vars.clone())) + .collect(), + ), + } + } + + // I don't know if this is smart. + let predicates: Vec = self + .get_language() + .get_predicates() + .filter(Filter::Canon(self.get_language().get_vars())) + .force(); + + let mut result: Vec = vec![]; + + let vars = self.get_language().get_vars(); + + // go pairwise + for p in &predicates { + for q in &predicates { + // creating dummy rule to ensure + if p == q + || !all_variables_bound(&Rule { + condition: None, + lhs: p.clone(), + rhs: q.clone(), + }) + { + continue; + } + + // if p => q, then add (p -> q) to the list of implications. + if self.validate_implication(&p, &q) { + let p = generalize_predicate(&p, vars.clone()); + let q = generalize_predicate(&q, vars.clone()); + result.push(format!( + r#" +(rule + ((= (Condition {p}) (TRUE))) + ((union (Condition {q}) (TRUE))) + :ruleset condition-propagation) + "# + )); + } + } + } + result + } + /// Returns a vector of candidate rules between e-classes in the e-graph. fn cvec_match( &self, @@ -187,9 +260,27 @@ pub trait Chomper { for i in 0..ec_keys.len() { let ec1 = ec_keys[i]; let term1 = eclass_term_map.get(ec1).unwrap(); + // TODO: + // if term1 = (Const x) { + // continue; + // } + if let Sexp::List(l) = term1 { + if l[0] == Sexp::Atom("Const".to_string()) { + continue; + } + } let cvec1 = self.get_language().eval(term1, env); + // if all terms in the cvec are equal, cotinue. + if cvec1.iter().all(|x| x == cvec1.first().unwrap()) { + continue; + } for ec2 in ec_keys.iter().skip(i + 1) { let term2 = eclass_term_map.get(ec2).unwrap(); + if let Sexp::List(l) = term2 { + if l[0] == Sexp::Atom("Const".to_string()) { + continue; + } + } let cvec2 = self.get_language().eval(term2, env); if cvec1 == cvec2 { // we add (l ~> r) and (r ~> l) as candidate rules, because @@ -215,10 +306,42 @@ pub trait Chomper { if mask.iter().all(|x| *x) { panic!("cvec1 != cvec2, yet we have a mask of all true"); } + if mask.iter().all(|x| !x) { continue; } + // if under the mask, all of cvec1 is the same value, then skip. + let cvec1_vals_under_pred = cvec1 + .iter() + .zip(mask.iter()) + .filter(|(_, &b)| b) + .map(|(x, _)| x.clone()) + .collect::>(); + + let cvec2_vals_under_pred = cvec2 + .iter() + .zip(mask.iter()) + .filter(|(_, &b)| b) + .map(|(x, _)| x.clone()) + .collect::>(); + + // TODO: make this happen conditionally via a flag + // get num of unique values under the predicate. + let num_unique_vals = + cvec1_vals_under_pred.iter().collect::>().len(); + + if num_unique_vals == 1 { + continue; + } + + let num_unique_vals = + cvec2_vals_under_pred.iter().collect::>().len(); + + if num_unique_vals == 1 { + continue; + } + if let Some(preds) = predicate_map.get(&mask) { for pred in preds { candidate_rules.push(Rule { @@ -244,6 +367,24 @@ pub trait Chomper { &self, ) -> HashMap>>; + fn run_condition_propagation(&self, egraph: &mut EGraph, iters: Option) { + if let Some(iters) = iters { + egraph + .parse_and_run_program( + None, + &format!( + "(run-schedule (repeat {iters} (run condition-propagation)))", + iters = iters + ), + ) + .unwrap(); + } else { + egraph + .parse_and_run_program(None, "(run-schedule (saturate condition-propagation))") + .unwrap(); + } + } + /// Returns if the given rule can be derived from the ruleset within the given e-graph. /// Assumes that `rule` has been generalized (see `ChompyLanguage::generalize_rule`). fn rule_is_derivable( @@ -252,30 +393,71 @@ pub trait Chomper { rule: &Rule, env_cache: &mut HashMap<(String, String), Vec>>, ) -> bool { + // TODO: make a cleaner implementation of below: + if let Some(cond) = &rule.condition { + if cond.to_string() == rule.lhs.to_string() { + info!( + "skipping bad rule with bad form of if c then c ~> r : {}", + rule + ); + return true; + } + } + info!("assessing rule: {}", rule); - // terms is a vector of (lhs, rhs) pairs with NO variables--not even 1... - let terms: Vec<(Sexp, Sexp)> = self.get_language().concretize_rule(rule, env_cache); - const MAX_DERIVABILITY_ITERATIONS: usize = 7; - for (lhs, rhs) in terms { - let mut egraph = initial_egraph.clone(); - self.add_term(&lhs, &mut egraph, None); - self.add_term(&rhs, &mut egraph, None); - let l_sexpr = format_sexp(&lhs); - let r_sexpr = format_sexp(&rhs); - self.run_rewrites(&mut egraph, Some(MAX_DERIVABILITY_ITERATIONS)); - // self.run_rewrites(&mut egraph, None); - let result = - egraph.parse_and_run_program(None, &format!("(check (= {l_sexpr} {r_sexpr}))")); - if !result.is_ok() { - // the existing ruleset was unable to derive the equality. - return false; + fn simple_concretize(sexp: &Sexp) -> Sexp { + match sexp { + Sexp::Atom(a) => { + if a.starts_with("?") { + Sexp::from_str(format!("(Var {})", a[1..].to_string()).as_str()).unwrap() + } else { + sexp.clone() + } + } + Sexp::List(l) => Sexp::List(l.iter().map(simple_concretize).collect()), } } + + // terms is a vector of (lhs, rhs) pairs with NO variables--not even 1... + + let lhs = simple_concretize(&rule.lhs); + let rhs = simple_concretize(&rule.rhs); + const MAX_DERIVABILITY_ITERATIONS: usize = 3; + + let mut egraph = initial_egraph.clone(); + self.add_term(&lhs, &mut egraph, None); + self.add_term(&rhs, &mut egraph, None); + let l_sexpr = format_sexp(&lhs); + let r_sexpr = format_sexp(&rhs); + if let Some(cond) = &rule.condition { + let cond = format_sexp(&simple_concretize(cond)); + egraph + .parse_and_run_program(None, &format!("(union (Condition {cond}) (TRUE))")) + .unwrap(); + self.run_condition_propagation(&mut egraph, Some(MAX_DERIVABILITY_ITERATIONS)); + } + self.run_rewrites(&mut egraph, Some(MAX_DERIVABILITY_ITERATIONS)); + let result = self.check_equality(&mut egraph, &lhs, &rhs); + if !result { + // the existing ruleset was unable to derive the equality. + return false; + } // the existing ruleset was able to derive the equality on all the given examples. true } + fn check_equality(&self, egraph: &mut EGraph, lhs: &Sexp, rhs: &Sexp) -> bool { + let res = egraph + .parse_and_run_program( + None, + &format!("(check (= {} {}))", format_sexp(lhs), format_sexp(rhs)), + ) + .is_ok(); + let log = egraph.parse_and_run_program(None, "(print-stats)").unwrap(); + res + } + fn add_rewrite(&self, egraph: &mut EGraph, rule: &Rule) { let lhs = format_sexp(&rule.lhs); let rhs = format_sexp(&rule.rhs); @@ -349,13 +531,13 @@ pub trait Chomper { let mut old_workload = atoms.clone(); let mut max_eclass_id: usize = 1; for size in 1..=max_size { - info!("CONSIDERING PROGRAMS OF SIZE {}:", size); + println!("CONSIDERING PROGRAMS OF SIZE {}:", size); let new_workload = atoms.clone().append( language .produce(&old_workload.clone()) .filter(Filter::MetricEq(Metric::Atoms, size)), ); - info!("workload len: {}", new_workload.force().len()); + println!("workload len: {}", new_workload.force().len()); for term in &new_workload.force() { self.add_term(term, &mut egraph, Some(max_eclass_id)); max_eclass_id += 1; @@ -384,6 +566,7 @@ pub trait Chomper { { break; } + println!("NUM CANDIDATES: {}", candidates.len()); seen_rules.extend(candidates.iter().map(|rule| rule.to_string())); let mut just_rewrite_egraph = self.get_initial_egraph(); for rule in rules.iter() { @@ -520,7 +703,7 @@ fn log_rewrite_stats(outputs: Vec) { s.split('s').next().unwrap().parse().unwrap() } - let long_time = 0.0001; + let long_time = 0.1; let last_two = outputs.iter().rev().take(2).collect::>(); for line in last_two.iter().rev() { // the last, third to last, and fifth to last tokens are the relevant ones. @@ -529,10 +712,13 @@ fn log_rewrite_stats(outputs: Vec) { let apply_time = chop_off_seconds(tokens[tokens.len() - 3]); let search_time = chop_off_seconds(tokens[tokens.len() - 5]); if search_time > long_time || apply_time > long_time || rebuild_time > long_time { - info!("Running rewrites took a long time!"); - info!("Egglog output:"); - info!("{}", line); + info!("LONG TIME"); } + // if search_time > long_time || apply_time > long_time || rebuild_time > long_time { + // + info!("Egglog output:"); + info!("{}", line); + // } } } /// A sample implementation of the Chomper trait for the MathLang language. @@ -541,6 +727,31 @@ pub struct MathChomper; impl Chomper for MathChomper { type Constant = i64; + fn validate_implication(&self, p1: &Sexp, p2: &Sexp) -> bool { + // TODO: Vivien suggests using Z3's incremental mode to avoid having to tear down and + // rebuild the context every time. + let p1: MathLang = p1.clone().into(); + let p2: MathLang = p2.clone().into(); + let mut cfg = z3::Config::new(); + cfg.set_timeout_msec(1000); + let ctx = z3::Context::new(&cfg); + let solver = z3::Solver::new(&ctx); + let p1 = mathlang_to_z3(&ctx, &MathLang::from(p1.clone())); + let p2 = mathlang_to_z3(&ctx, &MathLang::from(p2.clone())); + let one = z3::ast::Int::from_i64(&ctx, 1); + let assert_prog = &z3::ast::Bool::implies(&p1._eq(&one), &p2._eq(&one)); + solver.assert(&assert_prog); + let result = solver.check(); + match result { + z3::SatResult::Sat => true, + z3::SatResult::Unsat => false, + z3::SatResult::Unknown => { + info!("Z3 could not determine the validity of the implication."); + false + } + } + } + fn make_pred_interpreter() -> impl crate::PredicateInterpreter { #[derive(Debug)] struct DummyPredicateInterpreter; diff --git a/src/language.rs b/src/language.rs index 986ba19..6412cb2 100644 --- a/src/language.rs +++ b/src/language.rs @@ -234,20 +234,13 @@ pub trait ChompyLanguage { (TRUE) (Condition {name})) -(relation Implies (Predicate Predicate)) ;;; forward ruleset definitions (ruleset eclass-report) (ruleset non-cond-rewrites) (ruleset cond-rewrites) -(ruleset condition-propogation) - -(rule - ((Implies ?a ?b) - (= ?a (TRUE))) - ((union ?b (TRUE))) - :ruleset condition-propogation) +(ruleset condition-propagation) ;;; a "function", more or less, that prints out each e-class and its ;;; term. @@ -738,7 +731,7 @@ impl ChompyLanguage for MathLang { /// Converts the given `MathLang` term to a `z3::ast::Int`. This function is useful for /// validating rules in the `MathLang` language. -fn mathlang_to_z3<'a>(ctx: &'a z3::Context, math_lang: &MathLang) -> z3::ast::Int<'a> { +pub fn mathlang_to_z3<'a>(ctx: &'a z3::Context, math_lang: &MathLang) -> z3::ast::Int<'a> { let zero = z3::ast::Int::from_i64(ctx, 0); let one = z3::ast::Int::from_i64(ctx, 1); match math_lang { diff --git a/tests/math/mod.rs b/tests/math/mod.rs index 767c548..4e404ba 100644 --- a/tests/math/mod.rs +++ b/tests/math/mod.rs @@ -20,7 +20,8 @@ pub mod tests { for p in predicates.force() { println!("Predicate: {}", p); } - let rules = chomper.run_chompy(8); + + let rules = chomper.run_chompy(10); let hand_picked_rules = vec![ Rule { condition: Sexp::from_str("(Neq ?x (Const 0))").ok(), From 4ecdbf519e7c203b720bc134dca3013f6e18bbde Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Mon, 27 Jan 2025 13:27:03 -0800 Subject: [PATCH 3/4] Some debugging prints --- src/chomper.rs | 41 +++++++++++++++++++++++++++++++---------- src/language.rs | 36 ++++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/src/chomper.rs b/src/chomper.rs index 568fe4b..6e267fa 100644 --- a/src/chomper.rs +++ b/src/chomper.rs @@ -166,11 +166,25 @@ pub trait Chomper { (saturate eclass-report)) (pop) "#; + let size_prog = egraph.parse_and_run_program(None, "(print-size)").unwrap(); + println!("EGRAPH SIZE BEFORE ECLASS REPORT:"); + for line in size_prog { + println!("{}", line); + } + let mut outputs = egraph .parse_and_run_program(None, eclass_report_prog) .unwrap() .into_iter() .peekable(); + + let size_prog = egraph.parse_and_run_program(None, "(print-size)").unwrap(); + + println!("EGRAPH SIZE AFTER ECLASS REPORT:"); + for line in size_prog { + println!("{}", line); + } + let mut eclass_term_map = HashMap::default(); while outputs.peek().is_some() { outputs.next().unwrap(); @@ -257,6 +271,11 @@ pub trait Chomper { let mut candidate_rules = vec![]; let ec_keys: Vec<&usize> = eclass_term_map.keys().collect(); + println!( + "how many eclasses do we have? well, we have {} eclasses", + ec_keys.len() + ); + for i in 0..ec_keys.len() { let ec1 = ec_keys[i]; let term1 = eclass_term_map.get(ec1).unwrap(); @@ -528,15 +547,20 @@ pub trait Chomper { for term in atoms.force() { self.add_term(&term, &mut egraph, None); } - let mut old_workload = atoms.clone(); let mut max_eclass_id: usize = 1; + + // chompy does not consider terms with constants as subterms after programs + // the size exceeds a certain threshold. + let const_threshold = 5; + for size in 1..=max_size { println!("CONSIDERING PROGRAMS OF SIZE {}:", size); - let new_workload = atoms.clone().append( - language - .produce(&old_workload.clone()) - .filter(Filter::MetricEq(Metric::Atoms, size)), - ); + let mut new_workload = atoms.clone().append(language.produce(size)); + + if size > const_threshold { + new_workload = new_workload.filter(Filter::Excludes("Const".parse().unwrap())); + } + println!("workload len: {}", new_workload.force().len()); for term in &new_workload.force() { self.add_term(term, &mut egraph, Some(max_eclass_id)); @@ -545,13 +569,10 @@ pub trait Chomper { panic!("max eclass id reached"); } } - if !new_workload.force().is_empty() { - old_workload = new_workload; - } let mut seen_rules: HashSet = Default::default(); loop { info!("trying to merge new terms using existing rewrites..."); - self.run_rewrites(&mut egraph, None); + self.run_rewrites(&mut egraph, Some(7)); info!("i'm done running rewrites"); let mut candidates = self diff --git a/src/language.rs b/src/language.rs index 6412cb2..8d41bcc 100644 --- a/src/language.rs +++ b/src/language.rs @@ -5,6 +5,7 @@ use rand::rngs::StdRng; use rand::Rng; use ruler::{ enumo::{Sexp, Workload}, + recipe_utils::iter_metric, HashMap, }; @@ -148,21 +149,28 @@ pub trait ChompyLanguage { /// MathLang::from(x.clone())).collect::>(); /// assert_eq!(expected, actual); /// ``` - fn produce(&self, old_workload: &Workload) -> Workload { - let mut result_workload = Workload::empty(); - let funcs = self.get_funcs(); - for arity in 0..funcs.len() { - let sketch = "(FUNC ".to_string() + &" EXPR ".repeat(arity) + ")"; - let funcs = Workload::new(funcs[arity].clone()); - - result_workload = Workload::append( - result_workload, - Workload::new(&[sketch.to_string()]) - .plug("FUNC", &funcs) - .plug("EXPR", old_workload), - ); + /// TODO: we can probably get away with just using `iter_metric` instead of + /// rewriting this subroutine that just ends up calling `iter_metric` anyway. + fn produce(&self, size: usize) -> Workload { + let mut funcs_and_atoms: Workload = Workload::empty(); + // add all the base atoms. + funcs_and_atoms = Workload::append(funcs_and_atoms, self.base_atoms()); + + // add all the functions. + for arity in 0..self.get_funcs().len() { + let funcs = self.get_funcs()[arity].clone(); + let mut new_workload = Workload::empty(); + for func in funcs { + let sketch = "(FUNC ".to_string() + &" EXPR ".repeat(arity) + ")"; + new_workload = Workload::append( + new_workload, + Workload::new(&[sketch.to_string()]).plug("FUNC", &Workload::new(&[func])), + ); + } + funcs_and_atoms = Workload::append(funcs_and_atoms, new_workload); } - result_workload + + iter_metric(funcs_and_atoms, "EXPR", ruler::enumo::Metric::Atoms, size) } /// Returns the base set of atoms in the language. From 255ba4718696ea2f045ac4488e036b17ea9d5b27 Mon Sep 17 00:00:00 2001 From: Andrew Cheung Date: Thu, 30 Jan 2025 11:22:05 -0800 Subject: [PATCH 4/4] Clippy --- src/chomper.rs | 53 +++++++++++++---------------------------------- src/language.rs | 18 +++++++++------- tests/math/mod.rs | 2 +- tests/mod.rs | 2 +- 4 files changed, 27 insertions(+), 48 deletions(-) diff --git a/src/chomper.rs b/src/chomper.rs index 6e267fa..2dc1094 100644 --- a/src/chomper.rs +++ b/src/chomper.rs @@ -15,7 +15,7 @@ use crate::{ use egglog::{sort::EqSort, EGraph}; use log::info; use ruler::{ - enumo::{Filter, Metric, Sexp}, + enumo::{Filter, Sexp}, HashMap, HashSet, }; @@ -166,11 +166,6 @@ pub trait Chomper { (saturate eclass-report)) (pop) "#; - let size_prog = egraph.parse_and_run_program(None, "(print-size)").unwrap(); - println!("EGRAPH SIZE BEFORE ECLASS REPORT:"); - for line in size_prog { - println!("{}", line); - } let mut outputs = egraph .parse_and_run_program(None, eclass_report_prog) @@ -178,13 +173,6 @@ pub trait Chomper { .into_iter() .peekable(); - let size_prog = egraph.parse_and_run_program(None, "(print-size)").unwrap(); - - println!("EGRAPH SIZE AFTER ECLASS REPORT:"); - for line in size_prog { - println!("{}", line); - } - let mut eclass_term_map = HashMap::default(); while outputs.peek().is_some() { outputs.next().unwrap(); @@ -243,9 +231,9 @@ pub trait Chomper { } // if p => q, then add (p -> q) to the list of implications. - if self.validate_implication(&p, &q) { - let p = generalize_predicate(&p, vars.clone()); - let q = generalize_predicate(&q, vars.clone()); + if self.validate_implication(p, q) { + let p = generalize_predicate(p, vars.clone()); + let q = generalize_predicate(q, vars.clone()); result.push(format!( r#" (rule @@ -271,10 +259,7 @@ pub trait Chomper { let mut candidate_rules = vec![]; let ec_keys: Vec<&usize> = eclass_term_map.keys().collect(); - println!( - "how many eclasses do we have? well, we have {} eclasses", - ec_keys.len() - ); + info!("number of e-classes: {}", ec_keys.len()); for i in 0..ec_keys.len() { let ec1 = ec_keys[i]; @@ -406,12 +391,7 @@ pub trait Chomper { /// Returns if the given rule can be derived from the ruleset within the given e-graph. /// Assumes that `rule` has been generalized (see `ChompyLanguage::generalize_rule`). - fn rule_is_derivable( - &self, - initial_egraph: &EGraph, - rule: &Rule, - env_cache: &mut HashMap<(String, String), Vec>>, - ) -> bool { + fn rule_is_derivable(&self, initial_egraph: &EGraph, rule: &Rule) -> bool { // TODO: make a cleaner implementation of below: if let Some(cond) = &rule.condition { if cond.to_string() == rule.lhs.to_string() { @@ -428,8 +408,8 @@ pub trait Chomper { fn simple_concretize(sexp: &Sexp) -> Sexp { match sexp { Sexp::Atom(a) => { - if a.starts_with("?") { - Sexp::from_str(format!("(Var {})", a[1..].to_string()).as_str()).unwrap() + if let Some(stripped) = a.strip_prefix("?") { + Sexp::from_str(format!("(Var {})", stripped).as_str()).unwrap() } else { sexp.clone() } @@ -447,8 +427,6 @@ pub trait Chomper { let mut egraph = initial_egraph.clone(); self.add_term(&lhs, &mut egraph, None); self.add_term(&rhs, &mut egraph, None); - let l_sexpr = format_sexp(&lhs); - let r_sexpr = format_sexp(&rhs); if let Some(cond) = &rule.condition { let cond = format_sexp(&simple_concretize(cond)); egraph @@ -467,14 +445,12 @@ pub trait Chomper { } fn check_equality(&self, egraph: &mut EGraph, lhs: &Sexp, rhs: &Sexp) -> bool { - let res = egraph + egraph .parse_and_run_program( None, &format!("(check (= {} {}))", format_sexp(lhs), format_sexp(rhs)), ) - .is_ok(); - let log = egraph.parse_and_run_program(None, "(print-stats)").unwrap(); - res + .is_ok() } fn add_rewrite(&self, egraph: &mut EGraph, rule: &Rule) { @@ -539,7 +515,6 @@ pub trait Chomper { let mut egraph = self.get_initial_egraph(); let env = self.initialize_env(); - let env_cache = &mut HashMap::default(); let language = self.get_language(); let mut rules: Vec = vec![]; let atoms = language.base_atoms(); @@ -597,7 +572,7 @@ pub trait Chomper { for rule in &candidates[..] { let valid = language.validate_rule(rule); let rule = language.generalize_rule(&rule.clone()); - let derivable = self.rule_is_derivable(&just_rewrite_egraph, &rule, env_cache); + let derivable = self.rule_is_derivable(&just_rewrite_egraph, &rule); info!("candidate rule: {}", rule); info!("validation result: {:?}", valid); info!("is derivable? {}", if derivable { "yes" } else { "no" }); @@ -757,11 +732,11 @@ impl Chomper for MathChomper { cfg.set_timeout_msec(1000); let ctx = z3::Context::new(&cfg); let solver = z3::Solver::new(&ctx); - let p1 = mathlang_to_z3(&ctx, &MathLang::from(p1.clone())); - let p2 = mathlang_to_z3(&ctx, &MathLang::from(p2.clone())); + let p1 = mathlang_to_z3(&ctx, &p1.clone()); + let p2 = mathlang_to_z3(&ctx, &p2.clone()); let one = z3::ast::Int::from_i64(&ctx, 1); let assert_prog = &z3::ast::Bool::implies(&p1._eq(&one), &p2._eq(&one)); - solver.assert(&assert_prog); + solver.assert(assert_prog); let result = solver.check(); match result { z3::SatResult::Sat => true, diff --git a/src/language.rs b/src/language.rs index 8d41bcc..023c274 100644 --- a/src/language.rs +++ b/src/language.rs @@ -254,7 +254,12 @@ pub trait ChompyLanguage { ;;; term. ;;; i'm not 100% sure why this only runs once per e-class -- it's because ;;; the (eclass ?term) can only be matched on once? -(rule ((eclass ?term)) ((extract "eclass:") (extract (eclass ?term)) (extract "candidate term:") (extract ?term)) :ruleset eclass-report) +(rule ((eclass ?term)) + ((extract "eclass:") + (extract (eclass ?term)) + (extract "candidate term:") + (extract ?term)) + :ruleset eclass-report) "# ); src.to_string() @@ -483,10 +488,7 @@ impl ChompyLanguage for MathLang { } if let Some(cond) = &rule.condition { - if let Some(cached) = env_cache.get(&( - rule.condition.clone().unwrap().to_string(), - rule.lhs.to_string(), - )) { + if let Some(cached) = env_cache.get(&(cond.to_string(), rule.lhs.to_string())) { info!("cache hit for : {:?}", rule); let result: Vec<(Sexp, Sexp)> = cached .iter() @@ -590,14 +592,16 @@ impl ChompyLanguage for MathLang { concretized_rules } + // TODO: change back. fn get_funcs(&self) -> Vec> { vec![ vec![], - vec!["Abs".to_string(), "Neg".to_string()], + // vec!["Abs".to_string(), "Neg".to_string()], + vec!["Neg".to_string()], vec![ "Add".to_string(), "Sub".to_string(), - "Mul".to_string(), + // "Mul".to_string(), "Div".to_string(), "Neq".to_string(), "Gt".to_string(), diff --git a/tests/math/mod.rs b/tests/math/mod.rs index 4e404ba..14e7072 100644 --- a/tests/math/mod.rs +++ b/tests/math/mod.rs @@ -21,7 +21,7 @@ pub mod tests { println!("Predicate: {}", p); } - let rules = chomper.run_chompy(10); + let rules = chomper.run_chompy(6); let hand_picked_rules = vec![ Rule { condition: Sexp::from_str("(Neq ?x (Const 0))").ok(), diff --git a/tests/mod.rs b/tests/mod.rs index 76c5740..8712aa3 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -51,7 +51,7 @@ pub fn evaluate_ruleset( let b = Box::new(language); for rule in other_rules { let rule = transform_rule(rule, &b); - let result = chomper.rule_is_derivable(&egraph, &rule, &mut Default::default()); + let result = chomper.rule_is_derivable(&egraph, &rule); if result { println!("Rule is derivable: {:?}", rule); } else {