diff --git a/src/lib.rs b/src/lib.rs index 8ad01462..3dfd30b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,12 +65,18 @@ pub struct RunReport { pub rebuild_time: Duration, } +/// A report of the results of an extract action. #[derive(Debug, Clone)] -pub struct ExtractReport { - pub cost: usize, - pub term: Term, - pub variants: Vec, - pub termdag: TermDag, +pub enum ExtractReport { + Best { + termdag: TermDag, + cost: usize, + term: Term, + }, + Variants { + termdag: TermDag, + terms: Vec, + }, } impl RunReport { @@ -273,7 +279,14 @@ impl EGraph { pub fn pop(&mut self) -> Result<(), Error> { match self.egraphs.pop() { Some(e) => { + // Copy the reports and messages from the popped egraph + let extract_report = self.extract_report.clone(); + let run_report = self.run_report.clone(); + let messages = self.msgs.clone(); *self = e; + self.extract_report = extract_report.or(self.extract_report.clone()); + self.run_report = run_report.or(self.run_report.clone()); + self.msgs.extend(messages); Ok(()) } None => Err(Error::Pop), @@ -907,8 +920,6 @@ impl EGraph { fn run_command(&mut self, command: NCommand, should_run: bool) -> Result<(), Error> { let pre_rebuild = Instant::now(); - self.extract_report = None; - self.run_report = None; let rebuild_num = self.rebuild()?; if rebuild_num > 0 { log::info!( @@ -919,8 +930,6 @@ impl EGraph { self.debug_assert_invariants(); - self.extract_report = None; - self.run_report = None; match command { NCommand::SetOption { name, value } => { let str = format!("Set option {} to {}", name, value); @@ -1087,11 +1096,12 @@ impl EGraph { .create(true) .open(&filename) .map_err(|e| Error::IoError(filename.clone(), e))?; - + let mut termdag = TermDag::default(); for expr in exprs { + let (t, value) = self.eval_expr(&expr, None, true)?; + let term = self.extract(value, &mut termdag, &t).1; use std::io::Write; - let res = self.extract_expr(expr, 1)?; - writeln!(f, "{}", res.termdag.to_string(&res.term)) + writeln!(f, "{}", termdag.to_string(&term)) .map_err(|e| Error::IoError(filename.clone(), e))?; } @@ -1107,31 +1117,6 @@ impl EGraph { } } - // Extract an expression from the current state, returning the cost, the extracted expression and some number - // of other variants, if variants is not zero. - pub fn extract_expr(&mut self, e: Expr, num_variants: usize) -> Result { - let (t, value) = self.eval_expr(&e, None, true)?; - let mut termdag = TermDag::default(); - let (cost, expr) = self.extract(value, &mut termdag, &t); - let variants = match num_variants { - 0 => vec![], - 1 => vec![expr.clone()], - _ => { - if self.get_sort(&value).is_some_and(|sort| sort.is_eq_sort()) { - self.extract_variants(value, num_variants, &mut termdag) - } else { - vec![expr.clone()] - } - } - }; - Ok(ExtractReport { - cost, - term: expr, - variants, - termdag, - }) - } - pub fn process_commands( &mut self, program: Vec, diff --git a/src/typecheck.rs b/src/typecheck.rs index a0f89e88..9b294ea9 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -870,31 +870,37 @@ impl EGraph { let variants = values[1].bits as i64; if variants == 0 { - let (cost, expr) = self.extract( + let (cost, term) = self.extract( values[0], &mut termdag, self.type_info().sorts.get(&values[0].tag).unwrap(), ); - let extracted = termdag.to_string(&expr); + let extracted = termdag.to_string(&term); log::info!("extracted with cost {cost}: {}", extracted); self.print_msg(extracted); + self.extract_report = Some(ExtractReport::Best { + termdag, + cost, + term, + }); } else { if variants < 0 { panic!("Cannot extract negative number of variants"); } - let extracted = + let terms = self.extract_variants(values[0], variants as usize, &mut termdag); log::info!("extracted variants:"); let mut msg = String::default(); msg += "(\n"; - assert!(!extracted.is_empty()); - for expr in extracted { - let str = termdag.to_string(&expr); + assert!(!terms.is_empty()); + for expr in &terms { + let str = termdag.to_string(expr); log::info!(" {}", str); msg += &format!(" {}\n", str); } msg += ")"; self.print_msg(msg); + self.extract_report = Some(ExtractReport::Variants { termdag, terms }); } stack.truncate(new_len);