Skip to content

Commit

Permalink
checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Oct 3, 2024
1 parent f90f61f commit ecb8fb8
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 232 deletions.
15 changes: 8 additions & 7 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ impl EGraph {
}
MergeFn::Expr(merge_prog) => {
let values = [old_value, new_value];
let mut stack = vec![];
self.run_actions(&mut stack, &values, &merge_prog, true)?;
let mut stack = self.run_actions(&values, &merge_prog, true)?;
stack.pop().unwrap()
}
};
Expand All @@ -243,7 +242,7 @@ impl EGraph {
let values = [old_value, new_value];
// We need to pass a new stack instead of reusing the old one
// because Load(Stack(idx)) uses an absolute index.
self.run_actions(&mut Vec::new(), &values, &prog, true)?;
self.run_actions(&values, &prog, true)?;
}
}
} else {
Expand All @@ -252,13 +251,15 @@ impl EGraph {
Ok(())
}

/// Runs actions with the given substitution
/// Returns the resulting stack if successful
pub(crate) fn run_actions(
&mut self,
stack: &mut Vec<Value>,
subst: &[Value],
program: &Program,
make_defaults: bool,
) -> Result<(), Error> {
) -> Result<Vec<Value>, Error> {
let mut stack = vec![];
for instr in &program.0 {
match instr {
Instruction::Load(load) => match load {
Expand Down Expand Up @@ -337,7 +338,7 @@ impl EGraph {
let new_value = stack.pop().unwrap();
let new_len = stack.len() - function.schema.input.len();

self.perform_set(*f, new_value, stack)?;
self.perform_set(*f, new_value, &mut stack)?;
stack.truncate(new_len)
}
Instruction::Union(arity) => {
Expand Down Expand Up @@ -423,6 +424,6 @@ impl EGraph {
}
}
}
Ok(())
Ok(stack)
}
}
239 changes: 18 additions & 221 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod core;
mod extract;
mod function;
mod gj;
mod scheduler;
mod serialize;
pub mod sort;
mod termdag;
Expand Down Expand Up @@ -64,6 +65,7 @@ pub use value::*;
pub use function::Function;
use function::*;
use gj::*;
use scheduler::Scheduler;
use unionfind::*;
use util::*;

Expand Down Expand Up @@ -684,7 +686,6 @@ impl EGraph {
}

fn apply_merges(&mut self, func: Symbol, merges: &[DeferredMerge]) -> usize {
let mut stack = Vec::new();
let mut function = self.functions.get_mut(&func).unwrap();
let n_unions = self.unionfind.n_unions();
let merge_prog = match &function.merge.merge_vals {
Expand All @@ -694,17 +695,13 @@ impl EGraph {

for (inputs, old, new) in merges {
if let Some(prog) = function.merge.on_merge.clone() {
self.run_actions(&mut stack, &[*old, *new], &prog, true)
.unwrap();
self.run_actions(&[*old, *new], &prog, true).unwrap();
function = self.functions.get_mut(&func).unwrap();
stack.clear();
}
if let Some(prog) = &merge_prog {
// TODO: error handling?
self.run_actions(&mut stack, &[*old, *new], prog, true)
.unwrap();
let mut stack = self.run_actions(&[*old, *new], prog, true).unwrap();
let merged = stack.pop().expect("merges should produce a value");
stack.clear();
function = self.functions.get_mut(&func).unwrap();
function.insert(inputs, merged, self.timestamp);
}
Expand Down Expand Up @@ -852,38 +849,7 @@ impl EGraph {

// returns a report whose `updated` flag says whether the egraph was updated
fn run_schedule(&mut self, sched: &ResolvedSchedule) -> RunReport {
match sched {
ResolvedSchedule::Run(span, config) => self.run_rules(span, config),
ResolvedSchedule::Repeat(_span, limit, sched) => {
let mut report = RunReport::default();
for _i in 0..*limit {
let rec = self.run_schedule(sched);
report = report.union(&rec);
if !rec.updated {
break;
}
}
report
}
ResolvedSchedule::Saturate(_span, sched) => {
let mut report = RunReport::default();
loop {
let rec = self.run_schedule(sched);
report = report.union(&rec);
if !rec.updated {
break;
}
}
report
}
ResolvedSchedule::Sequence(_span, scheds) => {
let mut report = RunReport::default();
for sched in scheds {
report = report.union(&self.run_schedule(sched));
}
report
}
}
scheduler::SimpleScheduler.run_schedule(self, sched)
}

/// Extract a value to a [`TermDag`] and [`Term`]
Expand All @@ -903,180 +869,6 @@ impl EGraph {
termdag.to_string(&term)
}

fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> RunReport {
let mut report: RunReport = Default::default();

// first rebuild
let rebuild_start = Instant::now();
let updates = self.rebuild_nofail();
log::debug!("database size: {}", self.num_tuples());
log::debug!("Made {updates} updates");
// add to the rebuild time for this ruleset
report.add_ruleset_rebuild_time(config.ruleset, rebuild_start.elapsed());
self.timestamp += 1;

let GenericRunConfig { ruleset, until } = config;

if let Some(facts) = until {
if self.check_facts(span, facts).is_ok() {
log::info!(
"Breaking early because of facts:\n {}!",
ListDisplay(facts, "\n")
);
return report;
}
}

let subreport = self.step_rules(*ruleset);
report = report.union(&subreport);

log::debug!("database size: {}", self.num_tuples());
self.timestamp += 1;

if self.num_tuples() > self.node_limit {
log::warn!("Node limit reached, {} nodes. Stopping!", self.num_tuples());
}

report
}

/// Search all the rules in a ruleset.
/// Add the search results for a rule to search_results, a map indexed by rule name.
///
/// For a plain ruleset every rule's query is run against the e-graph, starting
/// from the timestamp at which that rule last ran (0 if it never ran). A
/// combined ruleset is searched by recursing into each of its sub-rulesets.
/// Per-rule and per-ruleset search times are recorded in `run_report`.
///
/// # Panics
///
/// Panics if `ruleset` does not name a known ruleset.
fn search_rules(
    &self,
    ruleset: Symbol,
    run_report: &mut RunReport,
    search_results: &mut HashMap<Symbol, SearchResult>,
) {
    let rules = self
        .rulesets
        .get(&ruleset)
        .unwrap_or_else(|| panic!("ruleset does not exist: {}", &ruleset));

    // Both kinds of ruleset report the wall-clock time of their whole search.
    let search_start = Instant::now();
    match rules {
        Ruleset::Rules(_ruleset_name, rule_names) => {
            // `self` is only borrowed immutably in this method, so the rules
            // can be iterated in place; no copy of the rule map is needed.
            for (rule_name, rule) in rule_names.iter() {
                let mut all_matches = vec![];
                let rule_search_start = Instant::now();
                let mut did_match = false;
                let timestamp = self
                    .rule_last_run_timestamp
                    .get(rule_name)
                    .copied()
                    .unwrap_or(0);
                self.run_query(&rule.query, timestamp, false, |values| {
                    // Tracked separately from `all_matches` because a query
                    // without variables can match while contributing no values.
                    did_match = true;
                    assert_eq!(values.len(), rule.query.vars.len());
                    all_matches.extend_from_slice(values);
                    Ok(())
                });
                let rule_search_time = rule_search_start.elapsed();
                log::trace!(
                    "Searched for {rule_name} in {:.3}s ({} results)",
                    rule_search_time.as_secs_f64(),
                    all_matches.len()
                );
                run_report.add_rule_search_time(*rule_name, rule_search_time);
                search_results.insert(
                    *rule_name,
                    SearchResult {
                        all_matches,
                        did_match,
                    },
                );
            }
        }
        Ruleset::Combined(_name, sub_rulesets) => {
            for sub_ruleset in sub_rulesets {
                self.search_rules(*sub_ruleset, run_report, search_results);
            }
        }
    }
    run_report.add_ruleset_search_time(ruleset, search_start.elapsed());
}

/// Applies the actions of every rule in `ruleset` to the matches previously
/// collected in `search_results`.
///
/// For a plain ruleset each rule's program is run once per match (one match is
/// a chunk of `rule.query.vars.len()` values). A rule whose query has no
/// variables is run exactly once, and only if its search reported a match.
/// Each rule's last-run timestamp is set to the current timestamp, and match
/// counts and apply times are recorded in `run_report`. A combined ruleset is
/// applied by recursing into each sub-ruleset.
///
/// # Panics
///
/// Panics if `ruleset` is unknown, if a rule has no entry in `search_results`,
/// or if running a rule's actions returns an error.
fn apply_rules(
    &mut self,
    ruleset: Symbol,
    run_report: &mut RunReport,
    search_results: &HashMap<Symbol, SearchResult>,
) {
    // TODO this clone is not efficient
    // (it is what lets `run_actions` borrow `self` mutably while we iterate)
    let rules = self.rulesets.get(&ruleset).unwrap().clone();
    let apply_start = Instant::now();
    match rules {
        Ruleset::Rules(_name, compiled_rules) => {
            // `compiled_rules` is an owned copy, so iterate it directly rather
            // than collecting the keys and looking every rule up again.
            for (rule_name, rule) in compiled_rules.iter() {
                let rule_name = *rule_name;
                let SearchResult {
                    all_matches,
                    did_match,
                } = search_results.get(&rule_name).unwrap();
                let num_vars = rule.query.vars.len();

                // make sure the query requires matches
                if num_vars != 0 {
                    run_report.add_rule_num_matches(rule_name, all_matches.len() / num_vars);
                }

                self.rule_last_run_timestamp
                    .insert(rule_name, self.timestamp);
                let rule_apply_start = Instant::now();

                let mut stack = Vec::new();

                // when there are no variables, a query can still fail to match
                // here we handle that case
                if num_vars == 0 {
                    if *did_match {
                        self.run_actions(&mut stack, &[], &rule.program, true)
                            .unwrap_or_else(|e| {
                                panic!("error while running actions for {rule_name}: {e}")
                            });
                    }
                } else {
                    for values in all_matches.chunks(num_vars) {
                        // Every match starts from an empty stack.
                        stack.clear();
                        self.run_actions(&mut stack, values, &rule.program, true)
                            .unwrap_or_else(|e| {
                                panic!("error while running actions for {rule_name}: {e}")
                            });
                    }
                }

                // add to the rule's apply time
                run_report.add_rule_apply_time(rule_name, rule_apply_start.elapsed());
            }
        }
        Ruleset::Combined(_name, sub_rulesets) => {
            for sub_ruleset in sub_rulesets {
                self.apply_rules(sub_ruleset, run_report, search_results);
            }
        }
    }
    run_report.add_ruleset_apply_time(ruleset, apply_start.elapsed());
}

fn step_rules(&mut self, ruleset: Symbol) -> RunReport {
let n_unions_before = self.unionfind.n_unions();
let mut run_report = Default::default();
let mut search_results = HashMap::<Symbol, SearchResult>::default();
self.search_rules(ruleset, &mut run_report, &mut search_results);
self.apply_rules(ruleset, &mut run_report, &search_results);
run_report.updated |=
self.did_change_tables() || n_unions_before != self.unionfind.n_unions();

run_report
}

fn did_change_tables(&self) -> bool {
for (_name, function) in &self.functions {
if function.nodes.max_ts() >= self.timestamp {
Expand All @@ -1089,11 +881,10 @@ impl EGraph {

fn add_rule_with_name(
&mut self,
name: String,
name: Symbol,
rule: ast::ResolvedRule,
ruleset: Symbol,
) -> Result<Symbol, Error> {
let name = Symbol::from(name);
let core_rule = rule.to_canonicalized_core_rule(self.type_info())?;
let (query, actions) = (core_rule.body, core_rule.head);

Expand Down Expand Up @@ -1126,8 +917,8 @@ impl EGraph {
&mut self,
rule: ast::ResolvedRule,
ruleset: Symbol,
name: Symbol,
) -> Result<Symbol, Error> {
let name = format!("{}", rule);
self.add_rule_with_name(name, rule, ruleset)
}

Expand All @@ -1140,8 +931,7 @@ impl EGraph {
let program = self
.compile_actions(&Default::default(), &actions)
.map_err(Error::TypeErrors)?;
let mut stack = vec![];
self.run_actions(&mut stack, &[], &program, true)?;
self.run_actions(&[], &program, true)?;
Ok(())
}

Expand Down Expand Up @@ -1172,8 +962,7 @@ impl EGraph {
let program = self
.compile_expr(&Default::default(), &actions, &target)
.map_err(Error::TypeErrors)?;
let mut stack = vec![];
self.run_actions(&mut stack, &[], &program, make_defaults)?;
let mut stack = self.run_actions(&[], &program, make_defaults)?;
Ok(stack.pop().unwrap())
}

Expand Down Expand Up @@ -1288,7 +1077,7 @@ impl EGraph {
rule,
name,
} => {
self.add_rule(rule, ruleset)?;
self.add_rule(rule, ruleset, name)?;
log::info!("Declared rule {name}.")
}
ResolvedNCommand::RunSchedule(sched) => {
Expand Down Expand Up @@ -1605,6 +1394,14 @@ impl EGraph {
pub(crate) fn type_info_mut(&mut self) -> &mut TypeInfo {
&mut self.type_info
}

/// Returns the e-graph's current timestamp.
///
/// The `get_` prefix is kept so existing callers keep compiling.
pub fn get_timestamp(&self) -> u32 {
    self.timestamp
}

/// Advances the e-graph's timestamp by one.
pub fn bump_timestamp(&mut self) {
self.timestamp += 1;
}
}

// Currently, only the following errors can be thrown without location information:
Expand Down
Loading

0 comments on commit ecb8fb8

Please sign in to comment.