Skip to content

Commit

Permalink
Update Egglog version
Browse files Browse the repository at this point in the history
  • Loading branch information
ninehusky committed Jan 31, 2025
1 parent f52e2b4 commit f04b2f7
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 174 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
egglog = "0.3.0"
egglog = "0.4.0"
egraph-serialize = "0.2.0"
env_logger = "0.11.5"
indexmap = "2.6.0"
Expand Down
284 changes: 127 additions & 157 deletions src/chomper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::language::{mathlang_to_z3, MathLang};
use crate::PredicateInterpreter;
use std::fmt::Debug;
use std::hash::Hash;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{fmt::Display, str::FromStr, sync::Arc};

use rand::rngs::StdRng;
Expand Down Expand Up @@ -82,11 +81,7 @@ pub trait Chomper {
/// `get_language()`.
fn get_initial_egraph(&self) -> EGraph {
let mut egraph = EGraph::default();
let sort = Arc::new(EqSort {
name: self.get_language().get_name().into(),
});

egraph.add_arcsort(sort.clone()).unwrap();
egraph
.parse_and_run_program(None, &self.get_language().to_egglog_src())
.unwrap();
Expand All @@ -101,16 +96,11 @@ pub trait Chomper {

/// Adds the given term to the e-graph.
/// Optionally, sets the eclass id of the term to the given id.
fn add_term(&self, term: &Sexp, egraph: &mut EGraph, eclass_id: Option<usize>) {
fn add_term(&self, term: &Sexp, egraph: &mut EGraph) {
info!("adding term: {}", term);
let term = format_sexp(term);
let prog = format!("({} {})", UNIVERSAL_RELATION, term);
egraph.parse_and_run_program(None, &prog).unwrap();
if let Some(id) = eclass_id {
let prog = format!("(set (eclass {term}) {id})", term = term, id = id);
info!("running program: {}", prog);
egraph.parse_and_run_program(None, &prog).unwrap();
}
}

/// Runs the existing set of `total-rewrites` and `cond-rewrites` in the e-graph
Expand Down Expand Up @@ -145,32 +135,6 @@ pub trait Chomper {
log_rewrite_stats(results);
}

/// Returns a map from e-class id to a candidate term in the e-class.
fn get_eclass_term_map(&self, egraph: &mut EGraph) -> HashMap<usize, Sexp> {
let eclass_report_prog = r#"
(push)
(run-schedule
(saturate eclass-report))
(pop)
"#;

let mut outputs = egraph
.parse_and_run_program(None, eclass_report_prog)
.unwrap()
.into_iter()
.peekable();

let mut eclass_term_map = HashMap::default();
while outputs.peek().is_some() {
outputs.next().unwrap();
let eclass = outputs.next().unwrap().to_string().parse::<i64>().unwrap();
outputs.next().unwrap();
let term = Sexp::from_str(&outputs.next().unwrap()).unwrap();
eclass_term_map.insert(eclass.try_into().unwrap(), term);
}
eclass_term_map
}

fn validate_implication(&self, p1: &Sexp, p2: &Sexp) -> bool;

/// Returns a vector of Egglog rules, i.e., Egglog programs.
Expand Down Expand Up @@ -236,121 +200,24 @@ pub trait Chomper {
}

/// Returns a vector of candidate rules between e-classes in the e-graph.
/// It better be the case that by the time we're calling this,
/// that every term in the e-graph has a Cvec associated with it.
fn cvec_match(
&self,
egraph: &mut EGraph,
predicate_map: &HashMap<Vec<bool>, Vec<Sexp>>,
// Yeah.
cvecs: &HashMap<i64, CVec<dyn ChompyLanguage<Constant = Self::Constant>>>,
env: &HashMap<String, CVec<dyn ChompyLanguage<Constant = Self::Constant>>>,
) -> Vec<Rule> {
let eclass_term_map = self.get_eclass_term_map(egraph);
let mut candidate_rules = vec![];
let ec_keys: Vec<&usize> = eclass_term_map.keys().collect();

info!("number of e-classes: {}", ec_keys.len());

for i in 0..ec_keys.len() {
let ec1 = ec_keys[i];
let term1 = eclass_term_map.get(ec1).unwrap();
// TODO:
// if term1 = (Const x) {
// continue;
// }
if let Sexp::List(l) = term1 {
if l[0] == Sexp::Atom("Const".to_string()) {
continue;
}
}
let cvec1 = self.get_language().eval(term1, env);
// if all terms in the cvec are equal, cotinue.
if cvec1.iter().all(|x| x == cvec1.first().unwrap()) {
continue;
}
for ec2 in ec_keys.iter().skip(i + 1) {
let term2 = eclass_term_map.get(ec2).unwrap();
if let Sexp::List(l) = term2 {
if l[0] == Sexp::Atom("Const".to_string()) {
continue;
}
}
let cvec2 = self.get_language().eval(term2, env);
if cvec1 == cvec2 {
// we add (l ~> r) and (r ~> l) as candidate rules, because
// we don't know which ordering is "sound" according to
// variable binding -- see `all_variables_bound`.
candidate_rules.push(Rule {
condition: None,
lhs: term1.clone(),
rhs: term2.clone(),
});
candidate_rules.push(Rule {
condition: None,
lhs: term2.clone(),
rhs: term1.clone(),
});
} else {
let mask = cvec1
.iter()
.zip(cvec2.iter())
.map(|(a, b)| a == b)
.collect::<Vec<bool>>();
// if they never match, we can't generate a rule.
if mask.iter().all(|x| *x) {
panic!("cvec1 != cvec2, yet we have a mask of all true");
}

if mask.iter().all(|x| !x) {
continue;
}

// if under the mask, all of cvec1 is the same value, then skip.
let cvec1_vals_under_pred = cvec1
.iter()
.zip(mask.iter())
.filter(|(_, &b)| b)
.map(|(x, _)| x.clone())
.collect::<Vec<_>>();

let cvec2_vals_under_pred = cvec2
.iter()
.zip(mask.iter())
.filter(|(_, &b)| b)
.map(|(x, _)| x.clone())
.collect::<Vec<_>>();

// TODO: make this happen conditionally via a flag
// get num of unique values under the predicate.
let num_unique_vals =
cvec1_vals_under_pred.iter().collect::<HashSet<_>>().len();

if num_unique_vals == 1 {
continue;
}

let num_unique_vals =
cvec2_vals_under_pred.iter().collect::<HashSet<_>>().len();

if num_unique_vals == 1 {
continue;
}
let find_cvec_prog = r#"
(run-schedule
(saturate discover-candidates))
"#;
let terms = egraph.parse_and_run_program(None, find_cvec_prog).unwrap();
// TODO: let's just see what's in "terms"

if let Some(preds) = predicate_map.get(&mask) {
for pred in preds {
candidate_rules.push(Rule {
condition: Some(pred.clone()),
lhs: term1.clone(),
rhs: term2.clone(),
});
candidate_rules.push(Rule {
condition: Some(pred.clone()),
lhs: term2.clone(),
rhs: term1.clone(),
});
}
}
}
}
}
candidate_rules
vec![]
}

/// Returns a map from variable names to their values.
Expand Down Expand Up @@ -412,8 +279,8 @@ pub trait Chomper {
const MAX_DERIVABILITY_ITERATIONS: usize = 3;

let mut egraph = initial_egraph.clone();
self.add_term(&lhs, &mut egraph, None);
self.add_term(&rhs, &mut egraph, None);
self.add_term(&lhs, &mut egraph);
self.add_term(&rhs, &mut egraph);
if let Some(cond) = &rule.condition {
let cond = format_sexp(&simple_concretize(cond));
egraph
Expand Down Expand Up @@ -497,6 +364,104 @@ pub trait Chomper {
result
}

fn assign_cvecs(
&self,
egraph: &mut EGraph,
env: &HashMap<String, CVec<dyn ChompyLanguage<Constant = Self::Constant>>>,
) -> () {
let get_unassigned_terms_prog = r#"
(run-schedule
(saturate find-no-cvec-terms))
"#;
let terms = egraph
.parse_and_run_program(None, get_unassigned_terms_prog)
.unwrap();

let sexps = terms.iter().map(|term| Sexp::from_str(term).unwrap());

let mut hasher = DefaultHasher::new();
for sexp in sexps {
let cvec = self.get_language().eval(&sexp, &env);
cvec.hash(&mut hasher);
let hash: u64 = hasher.finish();
if hash == 0 {
panic!(
"Can't have cvec hash of 0. We just can't. Term was: {}",
sexp
);
}
let sexp = format_sexp(&sexp);
egraph
.parse_and_run_program(None, &format!("(HasCvecHash {sexp} {hash})"))
.unwrap();
}
}

fn get_candidates(&self, egraph: &mut EGraph) -> Vec<Rule> {
let mut candidates: Vec<Rule> = vec![];
let get_cond_candidates_prog = r#"
(run-schedule
(saturate discover-cond-candidates))
"#;
let get_total_candidates_prog = r#"
(run-schedule
(saturate discover-total-candidates))
"#;

let cond_candidates = egraph
.parse_and_run_program(None, get_cond_candidates_prog)
.unwrap();
for candidate in cond_candidates {
let sexp = Sexp::from_str(&candidate).unwrap();
if let Sexp::List(l) = sexp {
// strip off the identifier
assert_eq!(
l.len(),
4,
"unexpected length {} of candidate: {:?}",
l.len(),
l
);
assert_eq!(l[0], Sexp::Atom("ConditionalRule".to_string()));
candidates.push(Rule {
condition: Some(l[1].clone()),
lhs: l[2].clone(),
rhs: l[3].clone(),
});
} else {
panic!("Why is your rule not a list? : {}", sexp);
}
}

let total_candidates = egraph
.parse_and_run_program(None, get_total_candidates_prog)
.unwrap();

for candidate in total_candidates {
let sexp = Sexp::from_str(&candidate).unwrap();
if let Sexp::List(l) = sexp {
// strip off the identifier
assert_eq!(
l.len(),
3,
"unexpected length {} of candidate: {:?}",
l.len(),
l
);
assert_eq!(l[0], Sexp::Atom("TotalRule".to_string()));
candidates.push(Rule {
condition: None,
lhs: l[1].clone(),
rhs: l[2].clone(),
});
} else {
panic!("Why is your rule not a list? : {}", sexp);
}
}

candidates
}

fn run_chompy(&self, max_size: usize) -> Vec<Rule> {
const MAX_ECLASS_ID: usize = 6000;
let mut egraph = self.get_initial_egraph();
Expand All @@ -506,9 +471,11 @@ pub trait Chomper {
let mut rules: Vec<Rule> = vec![];
let atoms = language.base_atoms();
let pvecs = self.get_pvecs(&env);
for term in atoms.force() {
self.add_term(&term, &mut egraph, None);
}
// TODO: we should remove commented out portion entirely; the idea of
// new workload is that it should be the set union of atoms and new stuff.
// for term in atoms.force() {
// self.add_term(&term, &mut egraph, None);
// }
let mut max_eclass_id: usize = 1;

// chompy does not consider terms with constants as subterms after programs
Expand All @@ -525,7 +492,11 @@ pub trait Chomper {

println!("workload len: {}", new_workload.force().len());
for term in &new_workload.force() {
self.add_term(term, &mut egraph, Some(max_eclass_id));
self.add_term(term, &mut egraph);
let sexp = format_sexp(term);
egraph
.parse_and_run_program(None, format!("(set (HasCvecHash {sexp}) 0)").as_str())
.unwrap();

max_eclass_id += 1;
if max_eclass_id > MAX_ECLASS_ID {
Expand All @@ -538,11 +509,10 @@ pub trait Chomper {
self.run_rewrites(&mut egraph, Some(7));
info!("i'm done running rewrites");

let mut candidates = self
.cvec_match(&mut egraph, &pvecs, &env)
.into_iter()
.filter(all_variables_bound)
.collect::<Vec<Rule>>();
self.assign_cvecs(&mut egraph, &env);

let mut candidates = self.get_candidates(&mut egraph);

if candidates.is_empty()
|| candidates
.iter()
Expand Down
17 changes: 17 additions & 0 deletions src/cvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use egglog::{ast::Symbol, sort::Sort};

/// Wow, I can't believe we're back to creating this!
pub struct CvecSort {}

// impl Sort for CvecSort {
// fn name(&self) -> Symbol {
// "Cvec".into()
// }
//
// fn as_arc_any(
// self: std::sync::Arc<Self>,
// ) -> std::sync::Arc<dyn std::any::Any + Send + Sync + 'static> {
// self
// }
// }
Loading

0 comments on commit f04b2f7

Please sign in to comment.