Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ExtractReport generation and extracting in sub-graph #205

Merged
merged 6 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 26 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,18 @@ pub struct RunReport {
pub rebuild_time: Duration,
}

/// A report of the results of an extract action.
#[derive(Debug, Clone)]
pub struct ExtractReport {
pub cost: usize,
pub expr: Term,
pub variants: Vec<Term>,
pub termdag: TermDag,
pub enum ExtractReport {
Best {
termdag: TermDag,
cost: usize,
expr: Term,
},
Variants {
termdag: TermDag,
variants: Vec<Term>,
},
}

impl RunReport {
Expand Down Expand Up @@ -275,7 +281,18 @@ impl EGraph {
pub fn pop(&mut self) -> Result<(), Error> {
match self.egraphs.pop() {
Some(e) => {
// Copy the reports and messages from the popped egraph
let extract_report = self.extract_report.clone();
let run_report = self.run_report.clone();
let messages = self.msgs.clone();
*self = e;
if let Some(report) = extract_report {
self.extract_report = Some(report);
saulshanabrook marked this conversation as resolved.
Show resolved Hide resolved
}
if let Some(report) = run_report {
self.run_report = Some(report);
}
self.msgs.extend(messages);
Ok(())
}
None => Err(Error::Pop),
Expand Down Expand Up @@ -909,8 +926,6 @@ impl EGraph {

fn run_command(&mut self, command: NCommand, should_run: bool) -> Result<(), Error> {
let pre_rebuild = Instant::now();
self.extract_report = None;
self.run_report = None;
let rebuild_num = self.rebuild()?;
if rebuild_num > 0 {
log::info!(
Expand All @@ -921,8 +936,6 @@ impl EGraph {

self.debug_assert_invariants();

self.extract_report = None;
self.run_report = None;
match command {
NCommand::SetOption { name, value } => {
let str = format!("Set option {} to {}", name, value);
Expand Down Expand Up @@ -1089,11 +1102,12 @@ impl EGraph {
.create(true)
.open(&filename)
.map_err(|e| Error::IoError(filename.clone(), e))?;

let mut termdag = TermDag::default();
for expr in exprs {
let (t, value) = self.eval_expr(&expr, None, true)?;
let expr = self.extract(value, &mut termdag, &t).1;
use std::io::Write;
let res = self.extract_expr(expr, 1)?;
writeln!(f, "{}", res.termdag.to_string(&res.expr))
writeln!(f, "{}", termdag.to_string(&expr))
.map_err(|e| Error::IoError(filename.clone(), e))?;
}

Expand All @@ -1109,31 +1123,6 @@ impl EGraph {
}
}

// Extract an expression from the current state, returning the cost, the extracted expression and some number
// of other variants, if variants is not zero.
pub fn extract_expr(&mut self, e: Expr, num_variants: usize) -> Result<ExtractReport, Error> {
Comment on lines -1110 to -1112
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function was previously used in a number of places before #176 and was now only being used in the un-tested Output command.

let (t, value) = self.eval_expr(&e, None, true)?;
let mut termdag = TermDag::default();
let (cost, expr) = self.extract(value, &mut termdag, &t);
let variants = match num_variants {
0 => vec![],
1 => vec![expr.clone()],
_ => {
if self.get_sort(&value).is_some_and(|sort| sort.is_eq_sort()) {
self.extract_variants(value, num_variants, &mut termdag)
} else {
vec![expr.clone()]
}
}
};
Ok(ExtractReport {
cost,
expr,
variants,
termdag,
})
}

pub fn process_commands(
&mut self,
program: Vec<Command>,
Expand Down
13 changes: 11 additions & 2 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,11 @@ impl EGraph {
let extracted = termdag.to_string(&expr);
log::info!("extracted with cost {cost}: {}", extracted);
self.print_msg(extracted);
self.extract_report = Some(ExtractReport::Best {
termdag,
cost,
expr,
});
} else {
if variants < 0 {
panic!("Cannot extract negative number of variants");
Expand All @@ -899,13 +904,17 @@ impl EGraph {
let mut msg = String::default();
msg += "(\n";
assert!(!extracted.is_empty());
for expr in extracted {
let str = termdag.to_string(&expr);
for expr in &extracted {
let str = termdag.to_string(expr);
log::info!(" {}", str);
msg += &format!(" {}\n", str);
}
msg += ")";
self.print_msg(msg);
self.extract_report = Some(ExtractReport::Variants {
termdag,
variants: extracted,
});
}

stack.truncate(new_len);
Expand Down