From ecb8fb8ff951b2e5bebb6de9de5f331ce0f612a0 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Wed, 2 Oct 2024 23:41:54 -0700 Subject: [PATCH] checkpoint --- src/actions.rs | 15 +-- src/lib.rs | 239 ++++------------------------------------------- src/scheduler.rs | 219 +++++++++++++++++++++++++++++++++++++++++++ src/sort/fn.rs | 5 +- 4 files changed, 246 insertions(+), 232 deletions(-) create mode 100644 src/scheduler.rs diff --git a/src/actions.rs b/src/actions.rs index a14f7448..b42c9540 100644 --- a/src/actions.rs +++ b/src/actions.rs @@ -227,8 +227,7 @@ impl EGraph { } MergeFn::Expr(merge_prog) => { let values = [old_value, new_value]; - let mut stack = vec![]; - self.run_actions(&mut stack, &values, &merge_prog, true)?; + let mut stack = self.run_actions(&values, &merge_prog, true)?; stack.pop().unwrap() } }; @@ -243,7 +242,7 @@ impl EGraph { let values = [old_value, new_value]; // We need to pass a new stack instead of reusing the old one // because Load(Stack(idx)) use absolute index. - self.run_actions(&mut Vec::new(), &values, &prog, true)?; + self.run_actions(&values, &prog, true)?; } } } else { @@ -252,13 +251,15 @@ impl EGraph { Ok(()) } + /// Runs actions with the given substitution + /// Returns the resulting stack if successful pub(crate) fn run_actions( &mut self, - stack: &mut Vec, subst: &[Value], program: &Program, make_defaults: bool, - ) -> Result<(), Error> { + ) -> Result, Error> { + let mut stack = vec![]; for instr in &program.0 { match instr { Instruction::Load(load) => match load { @@ -337,7 +338,7 @@ impl EGraph { let new_value = stack.pop().unwrap(); let new_len = stack.len() - function.schema.input.len(); - self.perform_set(*f, new_value, stack)?; + self.perform_set(*f, new_value, &mut stack)?; stack.truncate(new_len) } Instruction::Union(arity) => { @@ -423,6 +424,6 @@ impl EGraph { } } } - Ok(()) + Ok(stack) } } diff --git a/src/lib.rs b/src/lib.rs index a2a178d2..9cd49fe5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ mod core; mod extract; mod function; mod gj; +mod scheduler; mod serialize; pub mod sort; mod termdag; @@ -64,6 +65,7 @@ pub use value::*; pub use function::Function; use function::*; use gj::*; +use scheduler::Scheduler; use unionfind::*; use util::*; @@ -684,7 +686,6 @@ impl EGraph { } fn apply_merges(&mut self, func: Symbol, merges: &[DeferredMerge]) -> usize { - let mut stack = Vec::new(); let mut function = self.functions.get_mut(&func).unwrap(); let n_unions = self.unionfind.n_unions(); let merge_prog = match &function.merge.merge_vals { @@ -694,17 +695,13 @@ impl EGraph { for (inputs, old, new) in merges { if let Some(prog) = function.merge.on_merge.clone() { - self.run_actions(&mut stack, &[*old, *new], &prog, true) - .unwrap(); + self.run_actions(&[*old, *new], &prog, true).unwrap(); function = self.functions.get_mut(&func).unwrap(); - stack.clear(); } if let Some(prog) = &merge_prog { // TODO: error handling? - self.run_actions(&mut stack, &[*old, *new], prog, true) - .unwrap(); + let mut stack = self.run_actions(&[*old, *new], prog, true).unwrap(); let merged = stack.pop().expect("merges should produce a value"); - stack.clear(); function = self.functions.get_mut(&func).unwrap(); function.insert(inputs, merged, self.timestamp); } @@ -852,38 +849,7 @@ impl EGraph { // returns whether the egraph was updated fn run_schedule(&mut self, sched: &ResolvedSchedule) -> RunReport { - match sched { - ResolvedSchedule::Run(span, config) => self.run_rules(span, config), - ResolvedSchedule::Repeat(_span, limit, sched) => { - let mut report = RunReport::default(); - for _i in 0..*limit { - let rec = self.run_schedule(sched); - report = report.union(&rec); - if !rec.updated { - break; - } - } - report - } - ResolvedSchedule::Saturate(_span, sched) => { - let mut report = RunReport::default(); - loop { - let rec = self.run_schedule(sched); - report = report.union(&rec); - if !rec.updated { - break; - } - } - report - } - ResolvedSchedule::Sequence(_span, scheds) => { - let mut report = RunReport::default(); - for sched in scheds { - report = report.union(&self.run_schedule(sched)); - } - report - } - } + scheduler::SimpleScheduler.run_schedule(self, sched) } /// Extract a value to a [`TermDag`] and [`Term`] @@ -903,180 +869,6 @@ impl EGraph { termdag.to_string(&term) } - fn run_rules(&mut self, span: &Span, config: &ResolvedRunConfig) -> RunReport { - let mut report: RunReport = Default::default(); - - // first rebuild - let rebuild_start = Instant::now(); - let updates = self.rebuild_nofail(); - log::debug!("database size: {}", self.num_tuples()); - log::debug!("Made {updates} updates"); - // add to the rebuild time for this ruleset - report.add_ruleset_rebuild_time(config.ruleset, rebuild_start.elapsed()); - self.timestamp += 1; - - let GenericRunConfig { ruleset, until } = config; - - if let Some(facts) = until { - if self.check_facts(span, facts).is_ok() { - log::info!( - "Breaking early because of facts:\n {}!", - ListDisplay(facts, "\n") - ); - return report; - } - } - - let subreport = self.step_rules(*ruleset); - report = report.union(&subreport); - - log::debug!("database size: {}", self.num_tuples()); - self.timestamp += 1; - - if self.num_tuples() > self.node_limit { - log::warn!("Node limit reached, {} nodes. Stopping!", self.num_tuples()); - } - - report - } - - /// Search all the rules in a ruleset. - /// Add the search results for a rule to search_results, a map indexed by rule name. - fn search_rules( - &self, - ruleset: Symbol, - run_report: &mut RunReport, - search_results: &mut HashMap, - ) { - let rules = self - .rulesets - .get(&ruleset) - .unwrap_or_else(|| panic!("ruleset does not exist: {}", &ruleset)); - match rules { - Ruleset::Rules(_ruleset_name, rule_names) => { - let copy_rules = rule_names.clone(); - let search_start = Instant::now(); - - for (rule_name, rule) in copy_rules.iter() { - let mut all_matches = vec![]; - let rule_search_start = Instant::now(); - let mut did_match = false; - let timestamp = self.rule_last_run_timestamp.get(rule_name).unwrap_or(&0); - self.run_query(&rule.query, *timestamp, false, |values| { - did_match = true; - assert_eq!(values.len(), rule.query.vars.len()); - all_matches.extend_from_slice(values); - Ok(()) - }); - let rule_search_time = rule_search_start.elapsed(); - log::trace!( - "Searched for {rule_name} in {:.3}s ({} results)", - rule_search_time.as_secs_f64(), - all_matches.len() - ); - run_report.add_rule_search_time(*rule_name, rule_search_time); - search_results.insert( - *rule_name, - SearchResult { - all_matches, - did_match, - }, - ); - } - - let search_time = search_start.elapsed(); - run_report.add_ruleset_search_time(ruleset, search_time); - } - Ruleset::Combined(_name, sub_rulesets) => { - let start_time = Instant::now(); - for sub_ruleset in sub_rulesets { - self.search_rules(*sub_ruleset, run_report, search_results); - } - let search_time = start_time.elapsed(); - run_report.add_ruleset_search_time(ruleset, search_time); - } - } - } - - fn apply_rules( - &mut self, - ruleset: Symbol, - run_report: &mut RunReport, - search_results: &HashMap, - ) { - // TODO this clone is not efficient - let rules = self.rulesets.get(&ruleset).unwrap().clone(); - match rules { - Ruleset::Rules(_name, compiled_rules) => { - let apply_start = Instant::now(); - let rule_names = compiled_rules.keys().cloned().collect::>(); - for rule_name in rule_names { - let SearchResult { - all_matches, - did_match, - } = search_results.get(&rule_name).unwrap(); - let rule = compiled_rules.get(&rule_name).unwrap(); - let num_vars = rule.query.vars.len(); - - // make sure the query requires matches - if num_vars != 0 { - run_report.add_rule_num_matches(rule_name, all_matches.len() / num_vars); - } - - self.rule_last_run_timestamp - .insert(rule_name, self.timestamp); - let rule_apply_start = Instant::now(); - - let stack = &mut vec![]; - - // when there are no variables, a query can still fail to match - // here we handle that case - if num_vars == 0 { - if *did_match { - stack.clear(); - self.run_actions(stack, &[], &rule.program, true) - .unwrap_or_else(|e| { - panic!("error while running actions for {rule_name}: {e}") - }); - } - } else { - for values in all_matches.chunks(num_vars) { - stack.clear(); - self.run_actions(stack, values, &rule.program, true) - .unwrap_or_else(|e| { - panic!("error while running actions for {rule_name}: {e}") - }); - } - } - - // add to the rule's apply time - run_report.add_rule_apply_time(rule_name, rule_apply_start.elapsed()); - } - run_report.add_ruleset_apply_time(ruleset, apply_start.elapsed()); - } - Ruleset::Combined(_name, sub_rulesets) => { - let start_time = Instant::now(); - for sub_ruleset in sub_rulesets { - self.apply_rules(sub_ruleset, run_report, search_results); - } - let apply_time = start_time.elapsed(); - run_report.add_ruleset_apply_time(ruleset, apply_time); - } - } - } - - fn step_rules(&mut self, ruleset: Symbol) -> RunReport { - let n_unions_before = self.unionfind.n_unions(); - let mut run_report = Default::default(); - let mut search_results = HashMap::::default(); - self.search_rules(ruleset, &mut run_report, &mut search_results); - self.apply_rules(ruleset, &mut run_report, &search_results); - run_report.updated |= - self.did_change_tables() || n_unions_before != self.unionfind.n_unions(); - - run_report - } - fn did_change_tables(&self) -> bool { for (_name, function) in &self.functions { if function.nodes.max_ts() >= self.timestamp { @@ -1089,11 +881,10 @@ impl EGraph { fn add_rule_with_name( &mut self, - name: String, + name: Symbol, rule: ast::ResolvedRule, ruleset: Symbol, ) -> Result { - let name = Symbol::from(name); let core_rule = rule.to_canonicalized_core_rule(self.type_info())?; let (query, actions) = (core_rule.body, core_rule.head); @@ -1126,8 +917,8 @@ impl EGraph { &mut self, rule: ast::ResolvedRule, ruleset: Symbol, + name: Symbol, ) -> Result { - let name = format!("{}", rule); self.add_rule_with_name(name, rule, ruleset) } @@ -1140,8 +931,7 @@ impl EGraph { let program = self .compile_actions(&Default::default(), &actions) .map_err(Error::TypeErrors)?; - let mut stack = vec![]; - self.run_actions(&mut stack, &[], &program, true)?; + self.run_actions(&[], &program, true)?; Ok(()) } @@ -1172,8 +962,7 @@ impl EGraph { let program = self .compile_expr(&Default::default(), &actions, &target) .map_err(Error::TypeErrors)?; - let mut stack = vec![]; - self.run_actions(&mut stack, &[], &program, make_defaults)?; + let mut stack = self.run_actions(&[], &program, make_defaults)?; Ok(stack.pop().unwrap()) } @@ -1288,7 +1077,7 @@ impl EGraph { rule, name, } => { - self.add_rule(rule, ruleset)?; + self.add_rule(rule, ruleset, name)?; log::info!("Declared rule {name}.") } ResolvedNCommand::RunSchedule(sched) => { @@ -1605,6 +1394,14 @@ impl EGraph { pub(crate) fn type_info_mut(&mut self) -> &mut TypeInfo { &mut self.type_info } + + pub fn get_timestamp(&self) -> u32 { + return self.timestamp; + } + + pub fn bump_timestamp(&mut self) { + self.timestamp += 1; + } } // Currently, only the following errors can thrown without location information: diff --git a/src/scheduler.rs b/src/scheduler.rs new file mode 100644 index 00000000..a79f1074 --- /dev/null +++ b/src/scheduler.rs @@ -0,0 +1,219 @@ +use std::time::Instant; + +use crate::{Ruleset, Symbol}; + +use crate::{EGraph, HashMap, ListDisplay, ResolvedRunConfig, ResolvedSchedule, RunReport, SearchResult, Span}; + +pub trait Scheduler { + fn run_schedule(&mut self, egraph: &mut EGraph, sched: &ResolvedSchedule) -> RunReport { + match sched { + ResolvedSchedule::Run(span, config) => self.run_rules(egraph, span, config), + ResolvedSchedule::Repeat(_span, limit, sched) => { + let mut report = RunReport::default(); + for _i in 0..*limit { + let rec = self.run_schedule(egraph, sched); + report = report.union(&rec); + if !rec.updated { + break; + } + } + report + } + ResolvedSchedule::Saturate(_span, sched) => { + let mut report = RunReport::default(); + loop { + let rec = self.run_schedule(egraph, sched); + report = report.union(&rec); + if !rec.updated { + break; + } + } + report + } + ResolvedSchedule::Sequence(_span, scheds) => { + let mut report = RunReport::default(); + for sched in scheds { + report = report.union(&self.run_schedule(egraph, sched)); + } + report + } + } + } + + fn run_rules( + &mut self, + egraph: &mut EGraph, + span: &Span, + config: &ResolvedRunConfig, + ) -> RunReport { + let mut report: RunReport = Default::default(); + + // first rebuild + let rebuild_start = Instant::now(); + let updates = egraph.rebuild_nofail(); + log::debug!("database size: {}", egraph.num_tuples()); + log::debug!("Made {updates} updates"); + // add to the rebuild time for this ruleset + report.add_ruleset_rebuild_time(config.ruleset, rebuild_start.elapsed()); + egraph.bump_timestamp(); + + let ResolvedRunConfig { ruleset, until } = config; + + if let Some(facts) = until { + if egraph.check_facts(span, facts).is_ok() { + log::info!( + "Breaking early because of facts:\n {}!", + ListDisplay(facts, "\n") + ); + return report; + } + } + + let subreport = self.step_rules(egraph, *ruleset); + report = report.union(&subreport); + + log::debug!("database size: {}", egraph.num_tuples()); + egraph.bump_timestamp(); + + report + } + + fn step_rules(&mut self, egraph: &mut EGraph, ruleset: Symbol) -> RunReport { + let n_unions_before = egraph.unionfind.n_unions(); + let mut run_report = Default::default(); + let mut search_results = HashMap::::default(); + self.search_rules(egraph, ruleset, &mut run_report, &mut search_results); + self.apply_rules(egraph, ruleset, &mut run_report, &search_results); + run_report.updated |= + egraph.did_change_tables() || n_unions_before != egraph.unionfind.n_unions(); + + run_report + } + + /// Search all the rules in a ruleset. + /// Add the search results for a rule to search_results, a map indexed by rule name. + fn search_rules( + &self, + egraph: &EGraph, + ruleset: Symbol, + run_report: &mut RunReport, + search_results: &mut HashMap, + ) { + let rules = egraph + .rulesets + .get(&ruleset) + .unwrap_or_else(|| panic!("ruleset does not exist: {}", &ruleset)); + match rules { + Ruleset::Rules(_ruleset_name, rule_names) => { + let copy_rules = rule_names.clone(); + let search_start = Instant::now(); + + for (rule_name, rule) in copy_rules.iter() { + let mut all_matches = vec![]; + let rule_search_start = Instant::now(); + let mut did_match = false; + let timestamp = egraph.rule_last_run_timestamp.get(rule_name).unwrap_or(&0); + egraph.run_query(&rule.query, *timestamp, false, |values| { + did_match = true; + assert_eq!(values.len(), rule.query.vars.len()); + all_matches.extend_from_slice(values); + Ok(()) + }); + let rule_search_time = rule_search_start.elapsed(); + log::trace!( + "Searched for {rule_name} in {:.3}s ({} results)", + rule_search_time.as_secs_f64(), + all_matches.len() + ); + run_report.add_rule_search_time(*rule_name, rule_search_time); + search_results.insert( + *rule_name, + SearchResult { + all_matches, + did_match, + }, + ); + } + + let search_time = search_start.elapsed(); + run_report.add_ruleset_search_time(ruleset, search_time); + } + Ruleset::Combined(_name, sub_rulesets) => { + let start_time = Instant::now(); + for sub_ruleset in sub_rulesets { + self.search_rules(egraph, *sub_ruleset, run_report, search_results); + } + let search_time = start_time.elapsed(); + run_report.add_ruleset_search_time(ruleset, search_time); + } + } + } + + fn apply_rules( + &mut self, + egraph: &mut EGraph, + ruleset: Symbol, + run_report: &mut RunReport, + search_results: &HashMap, + ) { + // TODO this clone is not efficient + let rules = egraph.rulesets.get(&ruleset).unwrap().clone(); + match rules { + Ruleset::Rules(_name, compiled_rules) => { + let apply_start = Instant::now(); + let rule_names = compiled_rules.keys().cloned().collect::>(); + for rule_name in rule_names { + let SearchResult { + all_matches, + did_match, + } = search_results.get(&rule_name).unwrap(); + let rule = compiled_rules.get(&rule_name).unwrap(); + let num_vars = rule.query.vars.len(); + + // make sure the query requires matches + if num_vars != 0 { + run_report.add_rule_num_matches(rule_name, all_matches.len() / num_vars); + } + + egraph.rule_last_run_timestamp + .insert(rule_name, egraph.get_timestamp()); + let rule_apply_start = Instant::now(); + + // when there are no variables, a query can still fail to match + // here we handle that case + if num_vars == 0 { + if *did_match { + egraph.run_actions(&[], &rule.program, true) + .unwrap_or_else(|e| { + panic!("error while running actions for {rule_name}: {e}") + }); + } + } else { + for values in all_matches.chunks(num_vars) { + egraph.run_actions(values, &rule.program, true) + .unwrap_or_else(|e| { + panic!("error while running actions for {rule_name}: {e}") + }); + } + } + + // add to the rule's apply time + run_report.add_rule_apply_time(rule_name, rule_apply_start.elapsed()); + } + run_report.add_ruleset_apply_time(ruleset, apply_start.elapsed()); + } + Ruleset::Combined(_name, sub_rulesets) => { + let start_time = Instant::now(); + for sub_ruleset in sub_rulesets { + self.apply_rules(egraph, sub_ruleset, run_report, search_results); + } + let apply_time = start_time.elapsed(); + run_report.add_ruleset_apply_time(ruleset, apply_time); + } + } + } +} + +pub struct SimpleScheduler; + +impl Scheduler for SimpleScheduler {} \ No newline at end of file diff --git a/src/sort/fn.rs b/src/sort/fn.rs index 317c0d27..ddc9385b 100644 --- a/src/sort/fn.rs +++ b/src/sort/fn.rs @@ -419,9 +419,6 @@ fn call_fn(egraph: &mut EGraph, name: &Symbol, types: Vec, args: Vec