diff --git a/src/chomper.rs b/src/chomper.rs index 88acd8a..2b48711 100644 --- a/src/chomper.rs +++ b/src/chomper.rs @@ -1,8 +1,10 @@ +use crate::cvec::CvecSort; use crate::language::{mathlang_to_z3, MathLang}; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{fmt::Display, str::FromStr, sync::Arc}; +use egglog::ast::Span; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use z3::ast::Ast; @@ -82,6 +84,10 @@ pub trait Chomper { fn get_initial_egraph(&self) -> EGraph { let mut egraph = EGraph::default(); + egraph + .add_arcsort(Arc::new(CvecSort::new()), Span::Panic) + .unwrap(); + egraph .parse_and_run_program(None, &self.get_language().to_egglog_src()) .unwrap(); @@ -99,7 +105,12 @@ pub trait Chomper { fn add_term(&self, term: &Sexp, egraph: &mut EGraph) { info!("adding term: {}", term); let term = format_sexp(term); - let prog = format!("({} {})", UNIVERSAL_RELATION, term); + let prog = format!( + r#" + {term} + (universe {term}) + "# + ); egraph.parse_and_run_program(None, &prog).unwrap(); } @@ -276,11 +287,19 @@ pub trait Chomper { let lhs = simple_concretize(&rule.lhs); let rhs = simple_concretize(&rule.rhs); + + if !lhs.to_string().contains("?") && !rhs.to_string().contains("?") { + // we don't want to assert equalities between constants. + return true; + } + const MAX_DERIVABILITY_ITERATIONS: usize = 3; let mut egraph = initial_egraph.clone(); self.add_term(&lhs, &mut egraph); self.add_term(&rhs, &mut egraph); + println!("lhs: {}", lhs); + println!("rhs: {}", rhs); if let Some(cond) = &rule.condition { let cond = format_sexp(&simple_concretize(cond)); egraph @@ -379,9 +398,9 @@ pub trait Chomper { let sexps = terms.iter().map(|term| Sexp::from_str(term).unwrap()); - let mut hasher = DefaultHasher::new(); for sexp in sexps { let cvec = self.get_language().eval(&sexp, &env); + let mut hasher = DefaultHasher::new(); cvec.hash(&mut hasher); let hash: u64 = hasher.finish(); if hash == 0 { @@ -390,72 +409,77 @@ pub trait Chomper { sexp ); } + let hash = format!("{:?}", cvec); let sexp = format_sexp(&sexp); egraph - .parse_and_run_program(None, &format!("(HasCvecHash {sexp} {hash})")) + .parse_and_run_program( + None, + &format!("(set (HasCvecHash {sexp}) (SomeCvec \"{hash}\"))"), + ) .unwrap(); } } fn get_candidates(&self, egraph: &mut EGraph) -> Vec { let mut candidates: Vec = vec![]; - let get_cond_candidates_prog = r#" - (run-schedule - (saturate discover-cond-candidates)) - "#; - let get_total_candidates_prog = r#" + + println!("BEFORE"); + let size_info = egraph + .parse_and_run_program(None, r#"(print-size)"#) + .unwrap(); + for info in size_info { + println!("{}", info); + } + + let get_candidates_prog = r#" (run-schedule + (saturate discover-cond-candidates) (saturate discover-total-candidates)) + + (run-schedule + (saturate print-candidates)) "#; - let cond_candidates = egraph - .parse_and_run_program(None, get_cond_candidates_prog) + let found_candidates = egraph + .parse_and_run_program(None, get_candidates_prog) .unwrap(); - for candidate in cond_candidates { - let sexp = Sexp::from_str(&candidate).unwrap(); - if let Sexp::List(l) = sexp { - // strip off the identifier - assert_eq!( - l.len(), - 4, - "unexpected length {} of candidate: {:?}", - l.len(), - l - ); - assert_eq!(l[0], Sexp::Atom("ConditionalRule".to_string())); - candidates.push(Rule { - condition: Some(l[1].clone()), - lhs: l[2].clone(), - rhs: l[3].clone(), - }); - } else { - panic!("Why is your rule not a list? : {}", sexp); - } - } - let total_candidates = egraph - .parse_and_run_program(None, get_total_candidates_prog) + println!("AFTER"); + let size_info = egraph + .parse_and_run_program(None, r#"(print-size)"#) .unwrap(); + for info in size_info { + println!("{}", info); + } - for candidate in total_candidates { + for candidate in found_candidates { let sexp = Sexp::from_str(&candidate).unwrap(); - if let Sexp::List(l) = sexp { - // strip off the identifier - assert_eq!( - l.len(), - 3, - "unexpected length {} of candidate: {:?}", - l.len(), - l - ); - assert_eq!(l[0], Sexp::Atom("TotalRule".to_string())); - candidates.push(Rule { - condition: None, - lhs: l[1].clone(), - rhs: l[2].clone(), - }); + if let Sexp::List(ref l) = sexp { + match l.len() { + 4 => { + assert_eq!(l[0], Sexp::Atom("ConditionalRule".to_string())); + candidates.push(Rule { + condition: Some(l[1].clone()), + lhs: l[2].clone(), + rhs: l[3].clone(), + }); + } + 3 => { + assert_eq!(l[0], Sexp::Atom("TotalRule".to_string())); + candidates.push(Rule { + condition: None, + lhs: l[1].clone(), + rhs: l[2].clone(), + }); + } + _ => panic!( + "Unexpected length of rule sexpression {:?}: {}", + sexp, + l.len() + ), + } } else { - panic!("Why is your rule not a list? : {}", sexp); + panic!("Unexpected sexp: {:?}", sexp); } } @@ -495,7 +519,10 @@ pub trait Chomper { self.add_term(term, &mut egraph); let sexp = format_sexp(term); egraph - .parse_and_run_program(None, format!("(set (HasCvecHash {sexp}) 0)").as_str()) + .parse_and_run_program( + None, + format!("(set (HasCvecHash {sexp}) (NoneCvec))").as_str(), + ) .unwrap(); max_eclass_id += 1; @@ -509,9 +536,13 @@ pub trait Chomper { self.run_rewrites(&mut egraph, Some(7)); info!("i'm done running rewrites"); + println!("assigning cvecs..."); self.assign_cvecs(&mut egraph, &env); + println!("done assigning cvecs"); + println!("getting candidates..."); let mut candidates = self.get_candidates(&mut egraph); + println!("done getting candidates"); if candidates.is_empty() || candidates @@ -531,9 +562,9 @@ pub trait Chomper { let valid = language.validate_rule(rule); let rule = language.generalize_rule(&rule.clone()); let derivable = self.rule_is_derivable(&just_rewrite_egraph, &rule); - info!("candidate rule: {}", rule); - info!("validation result: {:?}", valid); - info!("is derivable? {}", if derivable { "yes" } else { "no" }); + println!("candidate rule: {}", rule); + println!("validation result: {:?}", valid); + println!("is derivable? {}", if derivable { "yes" } else { "no" }); if valid == ValidationResult::Valid && !derivable { let rule = language.generalize_rule(&rule.clone()); info!("rule: {}", rule); diff --git a/src/cvec.rs b/src/cvec.rs index b5f91b8..81a3b22 100644 --- a/src/cvec.rs +++ b/src/cvec.rs @@ -1,11 +1,15 @@ use std::sync::{Arc, Mutex}; +use std::str::FromStr; + +use ruler::enumo::Sexp; + use egglog::{ - ast::Symbol, + ast::{Span, Symbol}, constraint::SimpleTypeConstraint, sort::{FromSort, Sort, StringSort}, util::IndexMap, - PrimitiveLike, Value, + EGraph, PrimitiveLike, Value, }; const NONE_CVEC_IDENTIFIER: &str = "NoneCvec"; @@ -13,6 +17,12 @@ const NONE_CVEC_IDENTIFIER: &str = "NoneCvec"; type CvecMap = IndexMap; /// Wow, I can't believe we're back to creating this! +pub fn to_cvec_val(s: Symbol) -> Value { + Value { + tag: "Cvec".into(), + bits: Value::from(s).bits, + } +} #[derive(Debug)] pub struct CvecSort { @@ -22,8 +32,8 @@ pub struct CvecSort { impl CvecSort { pub fn new() -> Self { let mut cvecs: CvecMap = Default::default(); - let none_sym: Symbol = NONE_CVEC_IDENTIFIER.into(); - cvecs.insert(Value::from(none_sym), NONE_CVEC_IDENTIFIER.to_string()); + let val = to_cvec_val(NONE_CVEC_IDENTIFIER.into()); + cvecs.insert(val, NONE_CVEC_IDENTIFIER.to_string()); let cvecs = Mutex::new(cvecs); Self { cvecs } } @@ -43,11 +53,19 @@ impl Sort for CvecSort { fn extract_term( &self, _egraph: &egglog::EGraph, - _value: egglog::Value, + value: egglog::Value, _extractor: &egglog::extract::Extractor, - _termdag: &mut egglog::TermDag, + termdag: &mut egglog::TermDag, ) -> Option<(egglog::extract::Cost, egglog::Term)> { - todo!() + let cvecs = self.cvecs.lock().unwrap(); + let val: String = cvecs.get(&value).unwrap().to_string(); + + if val == NONE_CVEC_IDENTIFIER { + Some((1, termdag.app(Symbol::from("NoneCvec"), vec![]))) + } else { + let args = vec![termdag.lit(egglog::ast::Literal::String(Symbol::from(val)))]; + Some((1, termdag.app(Symbol::from("SomeCvec"), args))) + } } fn register_primitives(self: std::sync::Arc, info: &mut egglog::TypeInfo) { @@ -63,7 +81,7 @@ struct NoneCvec { impl PrimitiveLike for NoneCvec { fn name(&self) -> Symbol { - "NoneCvec".into() + NONE_CVEC_IDENTIFIER.into() } fn get_type_constraints( @@ -80,8 +98,8 @@ impl PrimitiveLike for NoneCvec { _egraph: Option<&mut egglog::EGraph>, ) -> Option { assert!(values.is_empty()); - let none_sym: Symbol = "None".into(); - Some(Value::from(none_sym)) + let none_cvec_value = to_cvec_val(NONE_CVEC_IDENTIFIER.into()); + Some(none_cvec_value) } } @@ -115,14 +133,13 @@ impl PrimitiveLike for SomeCvec { ) -> Option { assert_eq!(values.len(), 1); let cvecs = &mut *self.sort.cvecs.lock().unwrap(); - let param = Symbol::load(&StringSort, &values[0]).to_string(); - println!("param: {}", param); - if param == "NoneCvec" { + let param_str = Symbol::load(&egglog::sort::StringSort, &values[0]).to_string(); + if param_str == "NoneCvec" { panic!("SomeCvec called on NoneCvec"); } - cvecs.insert(values[0].clone(), param.clone()); - let string_sym: Symbol = param.into(); - Some(Value::from(string_sym)) + let param = to_cvec_val(param_str.clone().into()); + cvecs.insert(param, param_str); + Some(param) } } @@ -165,3 +182,109 @@ impl PrimitiveLike for MergeCvecs { } } } + +#[test] +pub fn none_ctor() { + let mut egraph = EGraph::default(); + egraph + .add_arcsort(Arc::new(CvecSort::new()), Span::Panic) + .unwrap(); + + let res = egraph + .parse_and_run_program( + None, + r#" + (let x (NoneCvec)) + (let y (NoneCvec)) + (check (= x y)) + (extract x) + "#, + ) + .unwrap(); + + assert_eq!(res.len(), 1); + assert_eq!( + Sexp::from_str(res[0].as_str()).unwrap(), + Sexp::from_str("(NoneCvec)").unwrap() + ); +} + +#[test] +pub fn some_ctor() { + let mut egraph = EGraph::default(); + egraph + .add_arcsort(Arc::new(CvecSort::new()), Span::Panic) + .unwrap(); + + let res = egraph + .parse_and_run_program( + None, + r#" + (let x (SomeCvec "hello")) + (let y (SomeCvec "hello")) + (let z (NoneCvec)) + (check (= x y)) + (check (!= x z)) + (extract x) + "#, + ) + .unwrap(); + + assert_eq!(res.len(), 1); + assert_eq!( + Sexp::from_str(res[0].as_str()).unwrap(), + Sexp::from_str("(SomeCvec \"hello\")").unwrap() + ); +} + +#[test] +pub fn merge_ctor() { + let mut egraph = EGraph::default(); + egraph + .add_arcsort(Arc::new(CvecSort::new()), Span::Panic) + .unwrap(); + + let res = egraph + .parse_and_run_program( + None, + r#" + (let x (NoneCvec)) + (let y (SomeCvec "hello")) + (let z (MergeCvecs x y)) + + + (datatype Term + (Const i64)) + + (function HasCvecHash (Term) Cvec :merge (MergeCvecs old new)) + + (let t1 (Const 1)) + (let t2 (Const 2)) + + (set (HasCvecHash t1) (SomeCvec "[Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1)]")) + (set (HasCvecHash t2) (SomeCvec "[Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(1)]")) + + (rule + ((= (HasCvecHash ?t1) (HasCvecHash ?t2)) + (!= ?t1 ?t2)) + ((extract ?t1))) + + (run 10) + + (check (= (HasCvecHash t1) (HasCvecHash t2))) + + (check (= y z)) + (check (!= x z)) + (extract z) + "#, + ) + .unwrap(); + + println!("res: {:?}", res); + + assert_eq!(res.len(), 1); + assert_eq!( + Sexp::from_str(res[0].as_str()).unwrap(), + Sexp::from_str("(SomeCvec \"hello\")").unwrap() + ); +} diff --git a/src/language.rs b/src/language.rs index 2cedae3..a4e6dbc 100644 --- a/src/language.rs +++ b/src/language.rs @@ -232,6 +232,8 @@ pub trait ChompyLanguage { (Condition {name})) +(relation GoodCvec Cvec) + ;;; note that these are NOT rewrite rules; ;;; they're just likely candidates for rewrite rules. (datatype CandidateRule @@ -243,8 +245,8 @@ pub trait ChompyLanguage { ;;; cvec = i64 is not great, because if a cvec hashes to 0 (our default value for "no cvec"), then ;;; we're in a pickle. but i don't think we'll run into that issue. on the rust side, we just need ;;; to assert that hash(cvec) != 0. -(function HasCvecHash ({name}) i64 :merge (max old new)) -(relation ConditionallyEqual (Predicate i64 i64)) +(function HasCvecHash ({name}) Cvec :merge (MergeCvecs old new)) +(relation ConditionallyEqual (Predicate Cvec Cvec)) (relation universe ({name})) (relation cond-equal ({name} {name})) @@ -254,7 +256,7 @@ pub trait ChompyLanguage { ;;; extract the terms that don't have a cvec. (rule - ((= (HasCvecHash ?a) 0)) + ((= (HasCvecHash ?a) (NoneCvec))) ((extract ?a)) :ruleset find-no-cvec-terms) @@ -411,7 +413,8 @@ impl ChompyLanguage for MathLang { } fn get_vals(&self) -> Vec { - vec![-1, 0, 1] + vec![] + // TODO: change back to: vec![-1, 0, 1] } fn get_vars(&self) -> Vec { @@ -471,12 +474,11 @@ impl ChompyLanguage for MathLang { fn get_funcs(&self) -> Vec> { vec![ vec![], - // vec!["Abs".to_string(), "Neg".to_string()], - vec!["Neg".to_string()], + vec!["Abs".to_string(), "Neg".to_string()], vec![ "Add".to_string(), "Sub".to_string(), - // "Mul".to_string(), + "Mul".to_string(), "Div".to_string(), "Neq".to_string(), "Gt".to_string(),