Skip to content

Commit

Permalink
first try
Browse files Browse the repository at this point in the history
  • Loading branch information
yihozhang committed Oct 23, 2024
1 parent b0db068 commit 876c63f
Show file tree
Hide file tree
Showing 22 changed files with 353 additions and 206 deletions.
8 changes: 4 additions & 4 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl<'a> ActionCompiler<'a> {

fn do_prim(&mut self, prim: &SpecializedPrimitive) {
self.instructions.push(Instruction::CallPrimitive(
prim.primitive.clone(),
prim.clone(),
prim.input.len(),
));
}
Expand Down Expand Up @@ -126,7 +126,7 @@ enum Instruction {
CallFunction(Symbol, bool),
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(Primitive, usize),
CallPrimitive(SpecializedPrimitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
Change(Change, Symbol),
Expand Down Expand Up @@ -321,11 +321,11 @@ impl EGraph {
Instruction::CallPrimitive(p, arity) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
if let Some(value) = p.apply(values, Some(self)) {
if let Some(value) = p.primitive.apply(values, (&p.input, &p.output), Some(self)) {
stack.truncate(new_len);
stack.push(value);
} else {
return Err(Error::PrimitiveError(p.clone(), values.to_vec()));
return Err(Error::PrimitiveError(p.primitive.clone(), values.to_vec()));
}
}
Instruction::Set(f) => {
Expand Down
2 changes: 1 addition & 1 deletion src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ F64: OrderedFloat<f64> = {
"inf" => OrderedFloat::<f64>(f64::INFINITY),
"-inf" => OrderedFloat::<f64>(f64::NEG_INFINITY),
}
Ident: Symbol = <s:r"(([[:alpha:]][\w-]*)|([-+*/?!=<>&|^/%_]))+"> => s.parse().unwrap();
Ident: Symbol = <s:r"(([[:alpha:]][\w-]*)|([-+*/?!=<>&|^/%_#]))+"> => s.parse().unwrap();
SymString: Symbol = <String> => Symbol::from(<>);

String: String = <r#"("[^"]*")+"#> => {
Expand Down
147 changes: 99 additions & 48 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,40 @@ pub enum ImpossibleConstraint {
actual_output: ArcSort,
actual_input: Vec<ArcSort>,
},
CompileTimeConstantExpected {
span: Span,
sort: ArcSort,
},
UnboundedFunction {
head: Symbol,
span: Span,
},
}

#[derive(Debug)]
pub enum Constraint<Var, Value> {
pub enum Constraint<'a, Var, Value> {
Eq(Var, Var),
Assign(Var, Value),
And(Vec<Constraint<Var, Value>>),
And(Vec<Constraint<'a, Var, Value>>),
// Exactly one of the constraints holds
// and all others are false
Xor(Vec<Constraint<Var, Value>>),
Xor(Vec<Constraint<'a, Var, Value>>),
LazyConstraint(Var, Box<dyn Fn(&Value) -> Self + 'a>),
Impossible(ImpossibleConstraint),
}

impl<'a, Var: Debug, Value: Debug> Debug for Constraint<'a, Var, Value> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Constraint::Eq(x, y) => write!(f, "{:?} = {:?}", x, y),
Constraint::Assign(x, v) => write!(f, "{:?} = {:?}", x, v),
Constraint::And(cs) => write!(f, "And({:?})", cs),
Constraint::Xor(cs) => write!(f, "Xor({:?})", cs),
Constraint::LazyConstraint(x, _) => write!(f, "LazyConstraint({:?}, trigger=...)", x),
Constraint::Impossible(c) => write!(f, "Impossible({:?})", c),
}
}
}

pub enum ConstraintError<Var, Value> {
InconsistentConstraint(Var, Value, Value),
UnconstrainedVar(Var),
Expand Down Expand Up @@ -84,11 +105,18 @@ impl ConstraintError<AtomTerm, ArcSort> {
actual_output.clone(),
actual_input.clone(),
),
ConstraintError::ImpossibleCaseIdentified(
ImpossibleConstraint::CompileTimeConstantExpected { span, sort },
) => TypeError::CompileTimeConstantExpected(sort.clone(), span.clone()),
ConstraintError::ImpossibleCaseIdentified(ImpossibleConstraint::UnboundedFunction {
head,
span,
}) => TypeError::UnboundFunction(*head, span.clone()),
}
}
}

impl<Var, Value> Constraint<Var, Value>
impl<'a, Var, Value> Constraint<'a, Var, Value>
where
Var: Eq + PartialEq + Hash + Clone + Debug,
Value: Clone + Debug,
Expand All @@ -97,11 +125,12 @@ where
/// If there's a conflict, returns the conflicting variable, the assigned conflicting types.
/// Otherwise, return whether the assignment is updated.
fn update<K: Eq>(
&self,
&mut self,
assignment: &mut Assignment<Var, Value>,
key: impl Fn(&Value) -> K + Copy,
) -> Result<bool, ConstraintError<Var, Value>> {
match self {
let mut new_self = None;
let result = match self {
Constraint::Eq(x, y) => match (assignment.0.get(x), assignment.0.get(y)) {
(Some(value), None) => {
assignment.insert(y.clone(), value.clone());
Expand Down Expand Up @@ -198,17 +227,33 @@ where
}
Ok(updated)
}
Constraint::LazyConstraint(var, trigger) => {
if assignment.0.contains_key(var) {
//
let value = assignment.0.get(var).unwrap();
let mut constraint = trigger(value);
constraint.update(assignment, key)?;
new_self = Some(constraint);
Ok(true)
} else {
Ok(false)
}
}
};
if let Some(new_self) = new_self {
*self = new_self;
}
return result;
}
}

#[derive(Debug)]
pub struct Problem<Var, Value> {
pub constraints: Vec<Constraint<Var, Value>>,
pub struct Problem<'a, Var, Value> {
pub constraints: Vec<Constraint<'a, Var, Value>>,
pub range: HashSet<Var>,
}

impl Default for Problem<AtomTerm, ArcSort> {
impl<'a> Default for Problem<'a, AtomTerm, ArcSort> {
fn default() -> Self {
Self {
constraints: vec![],
Expand Down Expand Up @@ -422,20 +467,20 @@ impl Assignment<AtomTerm, ArcSort> {
}
}

impl<Var, Value> Problem<Var, Value>
impl<'a, Var, Value> Problem<'a, Var, Value>
where
Var: Eq + PartialEq + Hash + Clone + Debug,
Value: Clone + Debug,
{
pub(crate) fn solve<K: Eq + Debug>(
&self,
mut self,
key: impl Fn(&Value) -> K + Copy,
) -> Result<Assignment<Var, Value>, ConstraintError<Var, Value>> {
let mut assignment = Assignment(HashMap::default());
let mut changed = true;
while changed {
changed = false;
for constraint in self.constraints.iter() {
for constraint in self.constraints.iter_mut() {
changed |= constraint.update(&mut assignment, key)?;
}
}
Expand All @@ -453,11 +498,11 @@ where
}
}

impl Problem<AtomTerm, ArcSort> {
impl<'a> Problem<'a, AtomTerm, ArcSort> {
pub(crate) fn add_query(
&mut self,
query: &Query<SymbolOrEq, Symbol>,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
) -> Result<(), TypeError> {
self.constraints.extend(query.get_constraints(typeinfo)?);
self.range.extend(query.atom_terms());
Expand All @@ -467,7 +512,7 @@ impl Problem<AtomTerm, ArcSort> {
pub fn add_actions(
&mut self,
actions: &GenericCoreActions<Symbol, Symbol>,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
for action in actions.0.iter() {
Expand All @@ -491,7 +536,7 @@ impl Problem<AtomTerm, ArcSort> {
pub(crate) fn add_rule(
&mut self,
rule: &CoreRule,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<(), TypeError> {
let CoreRule {
Expand All @@ -517,18 +562,18 @@ impl Problem<AtomTerm, ArcSort> {
}

impl CoreAction {
pub(crate) fn get_constraints(
pub(crate) fn get_constraints<'a>(
&self,
typeinfo: &TypeInfo,
typeinfo: &'a TypeInfo,
symbol_gen: &mut SymbolGen,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
) -> Result<Vec<Constraint<'a, AtomTerm, ArcSort>>, TypeError> {
match self {
CoreAction::Let(span, symbol, f, args) => {
let mut args = args.clone();
args.push(AtomTerm::Var(span.clone(), *symbol));

Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(f, &args, span, typeinfo)?)
.chain(get_atom_application_constraints(f, &args, span, typeinfo))
.collect())
}
CoreAction::Set(span, head, args, rhs) => {
Expand All @@ -538,7 +583,7 @@ impl CoreAction {
Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
))
.collect())
}
CoreAction::Change(span, _change, head, args) => {
Expand All @@ -550,7 +595,7 @@ impl CoreAction {
Ok(get_literal_and_global_constraints(&args, typeinfo)
.chain(get_atom_application_constraints(
head, &args, span, typeinfo,
)?)
))
.collect())
}
CoreAction::Union(_ann, lhs, rhs) => Ok(get_literal_and_global_constraints(
Expand Down Expand Up @@ -584,10 +629,10 @@ impl CoreAction {
}

impl Atom<SymbolOrEq> {
pub fn get_constraints(
pub fn get_constraints<'a>(
&self,
type_info: &TypeInfo,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
type_info: &'a TypeInfo,
) -> Result<Vec<Constraint<'a, AtomTerm, ArcSort>>, TypeError> {
let literal_constraints = get_literal_and_global_constraints(&self.args, type_info);
match &self.head {
SymbolOrEq::Eq => {
Expand All @@ -603,25 +648,25 @@ impl Atom<SymbolOrEq> {
SymbolOrEq::Symbol(head) => Ok(literal_constraints
.chain(get_atom_application_constraints(
head, &self.args, &self.span, type_info,
)?)
))
.collect()),
}
}
}

fn get_atom_application_constraints(
pub(crate) fn get_atom_application_constraints<'a>(
head: &Symbol,
args: &[AtomTerm],
span: &Span,
type_info: &TypeInfo,
) -> Result<Vec<Constraint<AtomTerm, ArcSort>>, TypeError> {
type_info: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
// An atom can have potentially different semantics due to polymorphism
// e.g. (set-empty) can mean any empty set with some element type.
// To handle this, we collect each possible instantiations of an atom
// (where each instantiation is a vec of constraints, thus vec of vec)
// into `xor_constraints`.
// `Constraint::Xor` means one and only one of the instantiation can hold.
let mut xor_constraints: Vec<Vec<Constraint<AtomTerm, ArcSort>>> = vec![];
let mut xor_constraints: Vec<Vec<Constraint<'a, AtomTerm, ArcSort>>> = vec![];

// function atom constraints
if let Some(typ) = type_info.func_types.get(head) {
Expand Down Expand Up @@ -664,28 +709,34 @@ fn get_atom_application_constraints(
// do literal and global variable constraints first
// as they are the most "informative"
match xor_constraints.len() {
0 => Err(TypeError::UnboundFunction(*head, span.clone())),
1 => Ok(xor_constraints.pop().unwrap()),
_ => Ok(vec![Constraint::Xor(
0 => vec![Constraint::Impossible(
ImpossibleConstraint::UnboundedFunction {
head: *head,
span: span.clone(),
},
)],
// 0 => Err(TypeError::UnboundFunction(*head, span.clone())),
1 => xor_constraints.pop().unwrap(),
_ => vec![Constraint::Xor(
xor_constraints.into_iter().map(Constraint::And).collect(),
)]),
)],
}
}

fn get_literal_and_global_constraints<'a>(
fn get_literal_and_global_constraints<'a, 'b>(
args: &'a [AtomTerm],
type_info: &'a TypeInfo,
) -> impl Iterator<Item = Constraint<AtomTerm, ArcSort>> + 'a {
) -> impl Iterator<Item = Constraint<'b, AtomTerm, ArcSort>> + 'a {
args.iter().filter_map(|arg| {
match arg {
AtomTerm::Var(_, _) => None,
// Literal to type constraint
AtomTerm::Literal(_, lit) => {
let typ = crate::sort::literal_sort(lit);
let typ = crate::sort::literal_sort(&lit);
Some(Constraint::Assign(arg.clone(), typ))
}
AtomTerm::Global(_, v) => {
if let Some(typ) = type_info.lookup_global(v) {
if let Some(typ) = type_info.lookup_global(&v) {
Some(Constraint::Assign(arg.clone(), typ.clone()))
} else {
panic!("All global variables should be bound before type checking")
Expand All @@ -696,11 +747,11 @@ fn get_literal_and_global_constraints<'a>(
}

pub trait TypeConstraint {
fn get(
fn get<'a>(
&self,
arguments: &[AtomTerm],
typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>>;
typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>>;
}

/// Construct a set of `Assign` constraints that fully constrain the type of arguments
Expand All @@ -721,11 +772,11 @@ impl SimpleTypeConstraint {
}

impl TypeConstraint for SimpleTypeConstraint {
fn get(
fn get<'a>(
&self,
arguments: &[AtomTerm],
_typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>> {
_typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
if arguments.len() != self.sorts.len() {
vec![Constraint::Impossible(
ImpossibleConstraint::ArityMismatch {
Expand Down Expand Up @@ -796,11 +847,11 @@ impl AllEqualTypeConstraint {
}

impl TypeConstraint for AllEqualTypeConstraint {
fn get(
fn get<'a>(
&self,
mut arguments: &[AtomTerm],
_typeinfo: &TypeInfo,
) -> Vec<Constraint<AtomTerm, ArcSort>> {
_typeinfo: &'a TypeInfo,
) -> Vec<Constraint<'a, AtomTerm, ArcSort>> {
if arguments.is_empty() {
panic!("all arguments should have length > 0")
}
Expand Down
Loading

0 comments on commit 876c63f

Please sign in to comment.