Skip to content

Commit

Permalink
Unoptimized cvec matching
Browse files Browse the repository at this point in the history
  • Loading branch information
ninehusky committed Feb 1, 2025
1 parent b0ee40a commit b436ebc
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 77 deletions.
139 changes: 85 additions & 54 deletions src/chomper.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use crate::cvec::CvecSort;
use crate::language::{mathlang_to_z3, MathLang};
use std::fmt::Debug;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{fmt::Display, str::FromStr, sync::Arc};

use egglog::ast::Span;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use z3::ast::Ast;
Expand Down Expand Up @@ -82,6 +84,10 @@ pub trait Chomper {
fn get_initial_egraph(&self) -> EGraph {
let mut egraph = EGraph::default();

egraph
.add_arcsort(Arc::new(CvecSort::new()), Span::Panic)
.unwrap();

egraph
.parse_and_run_program(None, &self.get_language().to_egglog_src())
.unwrap();
Expand All @@ -99,7 +105,12 @@ pub trait Chomper {
fn add_term(&self, term: &Sexp, egraph: &mut EGraph) {
info!("adding term: {}", term);
let term = format_sexp(term);
let prog = format!("({} {})", UNIVERSAL_RELATION, term);
let prog = format!(
r#"
{term}
(universe {term})
"#
);
egraph.parse_and_run_program(None, &prog).unwrap();
}

Expand Down Expand Up @@ -276,11 +287,19 @@ pub trait Chomper {

let lhs = simple_concretize(&rule.lhs);
let rhs = simple_concretize(&rule.rhs);

if !lhs.to_string().contains("?") && !rhs.to_string().contains("?") {
// we don't want to assert equalities between constants.
return true;
}

const MAX_DERIVABILITY_ITERATIONS: usize = 3;

let mut egraph = initial_egraph.clone();
self.add_term(&lhs, &mut egraph);
self.add_term(&rhs, &mut egraph);
println!("lhs: {}", lhs);
println!("rhs: {}", rhs);
if let Some(cond) = &rule.condition {
let cond = format_sexp(&simple_concretize(cond));
egraph
Expand Down Expand Up @@ -379,9 +398,9 @@ pub trait Chomper {

let sexps = terms.iter().map(|term| Sexp::from_str(term).unwrap());

let mut hasher = DefaultHasher::new();
for sexp in sexps {
let cvec = self.get_language().eval(&sexp, &env);
let mut hasher = DefaultHasher::new();
cvec.hash(&mut hasher);
let hash: u64 = hasher.finish();
if hash == 0 {
Expand All @@ -390,72 +409,77 @@ pub trait Chomper {
sexp
);
}
let hash = format!("{:?}", cvec);
let sexp = format_sexp(&sexp);
egraph
.parse_and_run_program(None, &format!("(HasCvecHash {sexp} {hash})"))
.parse_and_run_program(
None,
&format!("(set (HasCvecHash {sexp}) (SomeCvec \"{hash}\"))"),
)
.unwrap();
}
}

fn get_candidates(&self, egraph: &mut EGraph) -> Vec<Rule> {
let mut candidates: Vec<Rule> = vec![];
let get_cond_candidates_prog = r#"
(run-schedule
(saturate discover-cond-candidates))
"#;
let get_total_candidates_prog = r#"

println!("BEFORE");
let size_info = egraph
.parse_and_run_program(None, r#"(print-size)"#)
.unwrap();
for info in size_info {
println!("{}", info);
}

let get_candidates_prog = r#"
(run-schedule
(saturate discover-cond-candidates)
(saturate discover-total-candidates))
(run-schedule
(saturate print-candidates))
"#;

let cond_candidates = egraph
.parse_and_run_program(None, get_cond_candidates_prog)
let found_candidates = egraph
.parse_and_run_program(None, get_candidates_prog)
.unwrap();
for candidate in cond_candidates {
let sexp = Sexp::from_str(&candidate).unwrap();
if let Sexp::List(l) = sexp {
// strip off the identifier
assert_eq!(
l.len(),
4,
"unexpected length {} of candidate: {:?}",
l.len(),
l
);
assert_eq!(l[0], Sexp::Atom("ConditionalRule".to_string()));
candidates.push(Rule {
condition: Some(l[1].clone()),
lhs: l[2].clone(),
rhs: l[3].clone(),
});
} else {
panic!("Why is your rule not a list? : {}", sexp);
}
}

let total_candidates = egraph
.parse_and_run_program(None, get_total_candidates_prog)
println!("AFTER");
let size_info = egraph
.parse_and_run_program(None, r#"(print-size)"#)
.unwrap();
for info in size_info {
println!("{}", info);
}

for candidate in total_candidates {
for candidate in found_candidates {
let sexp = Sexp::from_str(&candidate).unwrap();
if let Sexp::List(l) = sexp {
// strip off the identifier
assert_eq!(
l.len(),
3,
"unexpected length {} of candidate: {:?}",
l.len(),
l
);
assert_eq!(l[0], Sexp::Atom("TotalRule".to_string()));
candidates.push(Rule {
condition: None,
lhs: l[1].clone(),
rhs: l[2].clone(),
});
if let Sexp::List(ref l) = sexp {
match l.len() {
4 => {
assert_eq!(l[0], Sexp::Atom("ConditionalRule".to_string()));
candidates.push(Rule {
condition: Some(l[1].clone()),
lhs: l[2].clone(),
rhs: l[3].clone(),
});
}
3 => {
assert_eq!(l[0], Sexp::Atom("TotalRule".to_string()));
candidates.push(Rule {
condition: None,
lhs: l[1].clone(),
rhs: l[2].clone(),
});
}
_ => panic!(
"Unexpected length of rule sexpression {:?}: {}",
sexp,
l.len()
),
}
} else {
panic!("Why is your rule not a list? : {}", sexp);
panic!("Unexpected sexp: {:?}", sexp);
}
}

Expand Down Expand Up @@ -495,7 +519,10 @@ pub trait Chomper {
self.add_term(term, &mut egraph);
let sexp = format_sexp(term);
egraph
.parse_and_run_program(None, format!("(set (HasCvecHash {sexp}) 0)").as_str())
.parse_and_run_program(
None,
format!("(set (HasCvecHash {sexp}) (NoneCvec))").as_str(),
)
.unwrap();

max_eclass_id += 1;
Expand All @@ -509,9 +536,13 @@ pub trait Chomper {
self.run_rewrites(&mut egraph, Some(7));
info!("i'm done running rewrites");

println!("assigning cvecs...");
self.assign_cvecs(&mut egraph, &env);
println!("done assigning cvecs");

println!("getting candidates...");
let mut candidates = self.get_candidates(&mut egraph);
println!("done getting candidates");

if candidates.is_empty()
|| candidates
Expand All @@ -531,9 +562,9 @@ pub trait Chomper {
let valid = language.validate_rule(rule);
let rule = language.generalize_rule(&rule.clone());
let derivable = self.rule_is_derivable(&just_rewrite_egraph, &rule);
info!("candidate rule: {}", rule);
info!("validation result: {:?}", valid);
info!("is derivable? {}", if derivable { "yes" } else { "no" });
println!("candidate rule: {}", rule);
println!("validation result: {:?}", valid);
println!("is derivable? {}", if derivable { "yes" } else { "no" });
if valid == ValidationResult::Valid && !derivable {
let rule = language.generalize_rule(&rule.clone());
info!("rule: {}", rule);
Expand Down
Loading

0 comments on commit b436ebc

Please sign in to comment.