Reorganize EquivalenceClasses to use more efficient algorithms #30082

Open
wants to merge 3 commits into
base: main
289 changes: 215 additions & 74 deletions src/transform/src/analysis/equivalences.rs
@@ -294,6 +294,13 @@ pub struct EquivalenceClasses {
/// can be replaced by.
/// These classes are unified whenever possible, to minimize the number of classes.
pub classes: Vec<Vec<MirScalarExpr>>,

/// An expression simplification map.
///
/// This map reflects an equivalence relation based on a prior version of `self.classes`.
/// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
/// in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
}

impl EquivalenceClasses {
@@ -324,7 +331,10 @@ impl EquivalenceClasses {
}
}

/// Sorts and deduplicates each class, and the classes themselves.
/// Sorts and deduplicates each class, removing literal errors.
///
/// This method does not ensure equivalence relation structure, but instead performs
/// only minimal structural clean-up.
fn tidy(&mut self) {
for class in self.classes.iter_mut() {
// Remove all literal errors, as they cannot be equated to other things.
@@ -337,14 +347,75 @@ impl EquivalenceClasses {
self.classes.dedup();
}

/// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
///
/// This method takes roughly linear time, and returns true iff `remap` has changed.
fn refresh(&mut self) -> bool {
self.tidy();

// remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
// If it contains the same number of expressions as `self.classes`, and for every expression in
// `self.classes` the two agree on the representative, they are identical.
if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
&& self
.classes
.iter()
.all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
{
// No change, so return false.
return false;
}

// Optimistically build the `remap` we would want.
// Note if any unions would be required, in which case we have further work to do,
// including re-forming `self.classes`.
let mut union_find = BTreeMap::default();
let mut dirtied = false;
for class in self.classes.iter() {
for expr in class.iter() {
if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
// A merge is required, but have the more complex expression point at the simpler one.
// This allows `union_find` to end as the `remap` for the new `classes` we form, with
// the only required work being compressing all the paths.
if Self::mir_scalar_expr_complexity(&other, &class[0])
== std::cmp::Ordering::Less
{
union_find.union(&class[0], &other);
} else {
union_find.union(&other, &class[0]);
}
dirtied = true;
}
}
}
if dirtied {
let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
for class in self.classes.drain(..) {
for expr in class {
let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
classes.entry(root).or_default().push(expr);
}
}
self.classes = classes.into_values().collect();
self.tidy();
}

let changed = self.remap != union_find;
self.remap = union_find;
changed
}

/// Update `self` to maintain the same equivalences while potentially reducing along `Ord::le`.
///
/// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
pub fn minimize(&mut self, columns: &Option<Vec<ColumnType>>) {
// Repeatedly, we reduce each of the classes themselves, then unify the classes.
// This should strictly reduce complexity, and reach a fixed point.
// Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
self.tidy();

// Ensure `self.classes` and `self.remap` are equivalence relations.
// Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
self.refresh();

// We should not rely on nullability information present in `column_types`. (Doing this
// every time just before calling `reduce` was found to be a bottleneck during incident-217,
@@ -364,81 +435,32 @@ impl EquivalenceClasses {
// An expression can be simplified, a duplication found, or two classes unified.
let mut stable = false;
while !stable {
stable = self.minimize_once(&columns);
stable = !self.minimize_once(&columns);
}

// TODO: remove these measures once we are more confident about idempotence.
let prev = self.clone();
self.minimize_once(&columns);
mz_ore::soft_assert_eq_or_log!(self, &prev, "Equivalences::minimize() not idempotent");
Contributor

I'm feeling a bit uneasy about removing this check just in the same PR where the stability check in minimize_once is also being changed.

Contributor

Ok, this is fine after checking the new stability check in more detail, see #30082 (comment)

}
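
The loop in `minimize` relies on `minimize_once` now reporting whether anything changed, so stability is the negation of its result. Below is a minimal sketch of that pattern on a toy map of integers. The names `compress_once` and `compress_to_fixed_point` are hypothetical and not part of this PR; the step here simply repoints each entry at its target's target.

```rust
use std::collections::BTreeMap;

/// One pass: repoint every value at its own target in the map.
/// Returns true iff the map changed (hypothetical helper, not from the PR).
fn compress_once(remap: &mut BTreeMap<u32, u32>) -> bool {
    // Read from a snapshot so that one pass only follows a single hop.
    let snapshot = remap.clone();
    let mut changed = false;
    for target in remap.values_mut() {
        if let Some(next) = snapshot.get(&*target) {
            if *next != *target {
                *target = *next;
                changed = true;
            }
        }
    }
    changed
}

/// Repeats `compress_once` until it reports no change; returns the number of passes.
fn compress_to_fixed_point(remap: &mut BTreeMap<u32, u32>) -> usize {
    let mut passes = 0;
    let mut stable = false;
    while !stable {
        // The step reports "changed", so "stable" is its negation.
        stable = !compress_once(remap);
        passes += 1;
    }
    passes
}
```

Because the final pass is the one that reports no change, the loop itself confirms that the result is a fixed point of the step.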

/// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
///
/// This invocation should take roughly linear time.
/// It starts with equivalence class invariants maintained (closed under transitivity), and then
/// 1. Performs per-expression reduction, including the class structure to replace subexpressions.
/// 2. Applies idiom detection to e.g. unpack expressions equivalent to literal true or false.
/// 3. Restores the equivalence class invariants.
fn minimize_once(&mut self, columns: &Option<Vec<ColumnType>>) -> bool {
// We are complete unless we experience an expression simplification, or an equivalence class unification.
let mut stable = true;

// 0. Reduce each expression
// 1. Reduce each expression
//
// This is optional in that `columns` may not be provided (`reduce` requires type information).
if let Some(columns) = columns {
for class in self.classes.iter_mut() {
for expr in class.iter_mut() {
let prev_expr = expr.clone();
// This reduction first looks for subexpression substitutions that can be performed,
// and then applies expression reduction if column type information is provided.
for class in self.classes.iter_mut() {
for expr in class.iter_mut() {
self.remap.reduce_child(expr);
if let Some(columns) = columns {
expr.reduce(columns);
if &prev_expr != expr {
stable = false;
}
}
}
}

// 1. Reduce each class.
// Each class can be reduced in the context of *other* classes, which are available for substitution.
for class_index in 0..self.classes.len() {
for index in 0..self.classes[class_index].len() {
let mut cloned = self.classes[class_index][index].clone();
// Use `reduce_child` rather than `reduce_expr` to avoid entire expression replacement.
let reduced = self.reduce_child(&mut cloned);
if reduced {
self.classes[class_index][index] = cloned;
stable = false;
}
}
}

// 2. Unify classes.
// If the same expression is in two classes, we can unify the classes.
// This element may not be the representative.
// TODO: If all lists are sorted, this could be a linear merge among all.
// They stop being sorted as soon as we make any modification, though.
// But, it would be a fast rejection when faced with lots of data.
// `expr_to_class_index` tells us for each expression the class index where we last saw
// it.
// `to_merge` has pairs of classes to be merged. The first element of the pair should be
// an earlier class than the second, and `to_merge` should be in sorted order. (These
// invariants are important when handling expressions that appear in more than two
// classes: we'll first merge the first two into the second, and then all this into the
// third, and so on.)
let mut expr_to_class_index = BTreeMap::new();
let mut to_merge = Vec::new();
for (index, class) in self.classes.iter().enumerate() {
for expr in class {
if let Some(other_index) = expr_to_class_index.get(expr) {
to_merge.push((*other_index, index));
}
}
for expr in class {
expr_to_class_index.insert(expr, index);
}
}
for (from, to) in to_merge {
let prior = std::mem::take(&mut self.classes[from]);
self.classes[to].extend(prior);
stable = false;
}

// 3. Identify idioms
// 2. Identify idioms
// E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
let mut to_add = Vec::new();
for class in self.classes.iter_mut() {
@@ -459,7 +481,6 @@ impl EquivalenceClasses {
expr1.clone().call_is_null(),
expr2.clone().call_is_null(),
]);
stable = false;
}
}
// Remove the more complex form of the expression.
@@ -482,7 +503,6 @@ impl EquivalenceClasses {
} = expr
{
to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
stable = false;
}
}
class.retain(|expr| {
@@ -506,7 +526,6 @@ impl EquivalenceClasses {
} = expr
{
to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
stable = false;
}
}
class.retain(|expr| {
@@ -525,9 +544,8 @@ impl EquivalenceClasses {
self.classes.extend(to_add);

// Tidy up classes, restore representative.
self.tidy();

stable
// Specifically, we want to remove literal errors before restoring the equivalence class structure.
self.refresh()
}

/// Produce the equivalences present in both inputs.
@@ -539,6 +557,7 @@ impl EquivalenceClasses {
// For each pair of equivalence classes, their intersection.
let mut equivalences = EquivalenceClasses {
classes: Vec::new(),
remap: Default::default(),
};
for class1 in self.classes.iter() {
for class2 in other.classes.iter() {
@@ -664,15 +683,15 @@ impl EquivalenceClasses {
}

/// Perform any simplification, report if effective.
pub fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
let mut simplified = false;
simplified = simplified || self.reduce_child(expr);
simplified = simplified || self.replace(expr);
simplified
}

/// Perform any simplification on children, report if effective.
pub fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
let mut simplified = false;
match expr {
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
@@ -714,4 +733,126 @@ impl EquivalenceClasses {
}
false
}

/// Returns a map that can be used to replace (sub-)expressions.
pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
&self.remap
}
}

/// A type capable of simplifying `MirScalarExpr`s.
pub trait ExpressionReducer {
/// Attempt to replace `expr` itself with another expression.
/// Returns true if it does so.
fn replace(&self, expr: &mut MirScalarExpr) -> bool;
/// Attempt to replace any subexpressions of `expr` with other expressions.
/// Returns true if it does so.
fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
let mut simplified = false;
simplified = simplified || self.reduce_child(expr);
simplified = simplified || self.replace(expr);
simplified
}
/// Attempt to replace any subexpressions of `expr`'s children with other expressions.
/// Returns true if it does so.
fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
Contributor

This seems to be a duplicate code fragment: there is another, identical reduce_child in the same file. It's not on the trait, though, so self is different, so maybe this is intentional? I'm not sure.

Contributor Author

@frankmcsherry Oct 21, 2024

Right, so it is different, and it is intentional. This method, and the others in the trait, are meant to sub in for those methods on Equivalences. However, we haven't pivoted the whole codebase over from Equivalences::replace to ExpressionReducer::replace. Internally, minimize_once uses the ExpressionReducer version, but externally (e.g. in EquivalencePropagation) the interface hasn't changed yet (the first goal is to improve minimize rather than uses of Equivalences).

let mut simplified = false;
match expr {
MirScalarExpr::CallBinary { expr1, expr2, .. } => {
simplified = self.reduce_expr(expr1) || simplified;
simplified = self.reduce_expr(expr2) || simplified;
}
MirScalarExpr::CallUnary { expr, .. } => {
simplified = self.reduce_expr(expr) || simplified;
}
MirScalarExpr::CallVariadic { exprs, .. } => {
for expr in exprs.iter_mut() {
simplified = self.reduce_expr(expr) || simplified;
}
}
MirScalarExpr::If { cond: _, then, els } => {
// Do not simplify `cond`, as we cannot ensure the simplification
// continues to hold as expressions migrate around.
simplified = self.reduce_expr(then) || simplified;
simplified = self.reduce_expr(els) || simplified;
}
_ => {}
}
simplified
}
}

impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
/// Perform any exact replacement for `expr`, report if it had an effect.
fn replace(&self, expr: &mut MirScalarExpr) -> bool {
if let Some(other) = self.get(expr) {
if other != &expr {
expr.clone_from(other);
return true;
}
}
false
}
}

impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
/// Perform any exact replacement for `expr`, report if it had an effect.
fn replace(&self, expr: &mut MirScalarExpr) -> bool {
if let Some(other) = self.get(expr) {
if other != expr {
expr.clone_from(other);
return true;
}
}
false
}
}
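
To illustrate how the trait's provided methods compose with a map-backed `replace`, here is a small sketch on a toy expression type. `Expr`, `Reducer`, and `demo_reduce` are hypothetical stand-ins for `MirScalarExpr`, `ExpressionReducer`, and a caller; only the control flow is meant to mirror the diff.

```rust
use std::collections::BTreeMap;

/// A toy expression type standing in for `MirScalarExpr` (hypothetical).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Expr {
    Column(usize),
    Literal(i64),
    Add(Box<Expr>, Box<Expr>),
}

/// Same shape as the trait in the diff: one required method, two provided ones.
trait Reducer {
    /// Replace `expr` itself, reporting whether it changed.
    fn replace(&self, expr: &mut Expr) -> bool;

    /// Reduce children first; replace the whole expression only if no child changed.
    fn reduce_expr(&self, expr: &mut Expr) -> bool {
        self.reduce_child(expr) || self.replace(expr)
    }

    /// Reduce each child, reporting whether any of them changed.
    fn reduce_child(&self, expr: &mut Expr) -> bool {
        match expr {
            Expr::Add(lhs, rhs) => {
                // Evaluate both sides before combining, so neither is short-circuited.
                let left = self.reduce_expr(lhs);
                let right = self.reduce_expr(rhs);
                left || right
            }
            _ => false,
        }
    }
}

impl Reducer for BTreeMap<Expr, Expr> {
    fn replace(&self, expr: &mut Expr) -> bool {
        if let Some(other) = self.get(&*expr) {
            if *other != *expr {
                expr.clone_from(other);
                return true;
            }
        }
        false
    }
}

/// Rewrites `#1 + 5` using a map that sends column 1 to column 0.
fn demo_reduce() -> (Expr, bool) {
    let mut remap = BTreeMap::new();
    remap.insert(Expr::Column(1), Expr::Column(0));
    let mut expr = Expr::Add(Box::new(Expr::Column(1)), Box::new(Expr::Literal(5)));
    let changed = remap.reduce_expr(&mut expr);
    (expr, changed)
}
```

Only `replace` depends on the backing store, so any type that can look up a representative gets `reduce_expr` and `reduce_child` without further code.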

trait UnionFind<T> {
/// Sets `self[x]` to the root from `x`, and returns a reference to the root.
fn find<'a>(&'a mut self, x: &T) -> Option<&'a T>;
/// Ensures that `x` and `y` have the same root.
fn union(&mut self, x: &T, y: &T);
}

impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
fn find<'a>(&'a mut self, x: &T) -> Option<&'a T> {
if !self.contains_key(x) {
None
} else {
if self[x] != self[&self[x]] {
// Path halving
let mut y = self[x].clone();
while y != self[&y] {
let grandparent = self[&self[&y]].clone();
*self.get_mut(&y).unwrap() = grandparent;
y.clone_from(&self[&y]);
}
*self.get_mut(x).unwrap() = y;
}
Some(&self[x])
}
}

fn union(&mut self, x: &T, y: &T) {
match (self.find(x).is_some(), self.find(y).is_some()) {
(true, true) => {
if self[x] != self[y] {
let root_x = self[x].clone();
let root_y = self[y].clone();
self.insert(root_x, root_y);
frankmcsherry marked this conversation as resolved.
}
}
(false, true) => {
self.insert(x.clone(), self[y].clone());
}
(true, false) => {
self.insert(y.clone(), self[x].clone());
}
(false, false) => {
self.insert(x.clone(), x.clone());
self.insert(y.clone(), x.clone());
}
}
}
}
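
As a rough, standalone illustration of the union-find idea above, the sketch below uses free functions over a `BTreeMap` instead of the `UnionFind` trait. `find_root`, `union_roots`, and `demo_classes` are hypothetical names. This simplified variant always points the root of `x` at the root of `y`, which differs from the diff when one or both elements are missing, and it returns an owned root rather than a reference.

```rust
use std::collections::BTreeMap;

/// Returns the root of `x`, halving the path on the way, or `None` if `x` is absent.
/// Assumes every value in `parents` is also a key, which `union_roots` maintains.
fn find_root<T: Clone + Ord>(parents: &mut BTreeMap<T, T>, x: &T) -> Option<T> {
    if !parents.contains_key(x) {
        return None;
    }
    let mut current = x.clone();
    loop {
        let parent = parents[&current].clone();
        if parent == current {
            return Some(current);
        }
        // Path halving: point `current` at its grandparent, then step to it.
        let grandparent = parents[&parent].clone();
        parents.insert(current, grandparent.clone());
        current = grandparent;
    }
}

/// Ensures `x` and `y` share a root, inserting either one if it is missing.
fn union_roots<T: Clone + Ord>(parents: &mut BTreeMap<T, T>, x: &T, y: &T) {
    // Missing elements start out as their own roots.
    parents.entry(x.clone()).or_insert_with(|| x.clone());
    parents.entry(y.clone()).or_insert_with(|| y.clone());
    let root_x = find_root(parents, x).expect("x was just inserted");
    let root_y = find_root(parents, y).expect("y was just inserted");
    if root_x != root_y {
        parents.insert(root_x, root_y);
    }
}

/// Groups elements by root, much as `refresh` re-forms `classes` after merges.
fn demo_classes() -> Vec<Vec<&'static str>> {
    let mut parents = BTreeMap::new();
    union_roots(&mut parents, &"b", &"a");
    union_roots(&mut parents, &"c", &"b");
    union_roots(&mut parents, &"e", &"d");
    let keys: Vec<&'static str> = parents.keys().cloned().collect();
    let mut classes: BTreeMap<&'static str, Vec<&'static str>> = BTreeMap::new();
    for key in keys {
        let root = find_root(&mut parents, &key).expect("key is present");
        classes.entry(root).or_default().push(key);
    }
    classes.into_values().collect()
}
```

Storing parents in an ordered map keyed by the elements themselves means that, once every path is compressed, the same map can serve as the lookup from each element to its representative.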