Skip to content

Commit

Permalink
Sort rules by size, add them in ascending order
Browse files Browse the repository at this point in the history
  • Loading branch information
ninehusky committed Jan 21, 2025
1 parent cd9f425 commit 7c85309
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 8 deletions.
54 changes: 46 additions & 8 deletions src/chomper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use ruler::{

const UNIVERSAL_RELATION: &str = "universe";

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rule {
pub condition: Option<Sexp>,
pub lhs: Sexp,
Expand All @@ -36,6 +36,43 @@ impl Display for Rule {
}
}

impl Rule {
pub fn cost(&self) -> usize {
fn cost(sexp: &Sexp) -> usize {
match sexp {
Sexp::Atom(a) => {
// prioritize variables over other things.
if a == "Var" {
1
} else {
2
}
}
Sexp::List(l) => l.iter().map(cost).sum(),
}
}

cost(&self.lhs)
+ cost(&self.rhs)
+ match &self.condition {
None => 0,
Some(cond) => cost(cond),
}
}
}

impl PartialOrd for Rule {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Ord for Rule {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.cost().cmp(&other.cost())
}
}

/// Chompers manage the state of the e-graph.
pub trait Chomper {
type Constant: Display + Debug + Clone + PartialEq;
Expand Down Expand Up @@ -230,13 +267,13 @@ pub trait Chomper {
// self.run_rewrites(&mut egraph, None);
let result =
egraph.parse_and_run_program(None, &format!("(check (= {l_sexpr} {r_sexpr}))"));
if result.is_ok() {
// the existing ruleset was able to derive the equality.
return true;
if !result.is_ok() {
// the existing ruleset was unable to derive the equality.
return false;
}
}
// the existing ruleset failed to derive the equality on all the given examples.
false
// the existing ruleset was able to derive the equality on all the given examples.
true
}

fn add_rewrite(&self, egraph: &mut EGraph, rule: &Rule) {
Expand Down Expand Up @@ -336,7 +373,7 @@ pub trait Chomper {
self.run_rewrites(&mut egraph, None);
info!("i'm done running rewrites");

let candidates = self
let mut candidates = self
.cvec_match(&mut egraph, &pvecs, &env)
.into_iter()
.filter(all_variables_bound)
Expand All @@ -353,6 +390,7 @@ pub trait Chomper {
for rule in rules.iter() {
self.add_rewrite(&mut just_rewrite_egraph, rule);
}
candidates.sort_by(|a, b| a.cmp(b));
for rule in &candidates[..] {
let valid = language.validate_rule(rule);
let rule = language.generalize_rule(&rule.clone());
Expand All @@ -363,7 +401,7 @@ pub trait Chomper {
if valid == ValidationResult::Valid && !derivable {
let rule = language.generalize_rule(&rule.clone());
info!("rule: {}", rule);
println!("rule: {}", rule);
println!("RULE IS: {}", rule);
rules.push(rule.clone());
self.add_rewrite(&mut egraph, &rule);
self.add_rewrite(&mut just_rewrite_egraph, &rule);
Expand Down
1 change: 1 addition & 0 deletions src/language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ impl ChompyLanguage for MathLang {
concretized_rules.push((subst(&rule.lhs, &env), subst(&rule.rhs, &env)));
}
}
info!("concretized rules for {}: {:?}", rule, concretized_rules);
concretized_rules
}

Expand Down

0 comments on commit 7c85309

Please sign in to comment.