Skip to content

Commit

Permalink
Properly re-print trait impls.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth committed Oct 25, 2024
1 parent e780f03 commit 810b9ef
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 58 deletions.
3 changes: 3 additions & 0 deletions ast/src/analyzed/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ impl<T: Display> Display for Analyzed<T> {
StatementIdentifier::ProverFunction(i) => {
writeln_indented(f, format!("{};", &self.prover_functions[*i]))?;
}
StatementIdentifier::TraitImplementation(i) => {
writeln_indented(f, format!("{}", self.trait_impls[*i]))?;
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,20 @@ pub use crate::parsed::BinaryOperator;
pub use crate::parsed::UnaryOperator;
use crate::parsed::{
self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, TraitDeclaration,
TypeDeclaration,
TraitImplementation, TypeDeclaration,
};

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub enum StatementIdentifier {
/// Either an intermediate column or a definition.
Definition(String),
PublicDeclaration(String),
/// Index into the vector of proof items.
/// Index into the vector of proof items / identities.
ProofItem(usize),
/// Index into the vector of prover functions.
ProverFunction(usize),
/// Index into the vector of trait implementations.
TraitImplementation(usize),
}

#[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)]
Expand All @@ -43,6 +45,7 @@ pub struct Analyzed<T> {
pub intermediate_columns: HashMap<String, (Symbol, Vec<AlgebraicExpression<T>>)>,
pub identities: Vec<Identity<SelectedExpressions<AlgebraicExpression<T>>>>,
pub prover_functions: Vec<Expression>,
pub trait_impls: Vec<TraitImplementation<Expression>>,
/// The order in which definitions and identities
/// appear in the source.
pub source_order: Vec<StatementIdentifier>,
Expand Down
3 changes: 2 additions & 1 deletion backend-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ fn split_by_namespace<F: FieldElement>(
},
}
}
StatementIdentifier::ProverFunction(_) => None,
StatementIdentifier::ProverFunction(_)
| StatementIdentifier::TraitImplementation(_) => None,
})
// collect into a map
.fold(Default::default(), |mut acc, (namespace, statement)| {
Expand Down
3 changes: 2 additions & 1 deletion backend/src/estark/json_exporter/expression_counter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ pub fn compute_intermediate_expression_ids<T>(analyzed: &Analyzed<T>) -> HashMap
analyzed.public_declarations[name].expression_count()
}
StatementIdentifier::ProofItem(id) => analyzed.identities[*id].expression_count(),
StatementIdentifier::ProverFunction(_) => 0,
StatementIdentifier::ProverFunction(_)
| StatementIdentifier::TraitImplementation(_) => 0,
}
}
ids
Expand Down
3 changes: 2 additions & 1 deletion backend/src/estark/json_exporter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> PIL {
}
}
}
StatementIdentifier::ProverFunction(_) => {}
StatementIdentifier::ProverFunction(_)
| StatementIdentifier::TraitImplementation(_) => {}
}
}
PIL {
Expand Down
4 changes: 3 additions & 1 deletion pil-analyzer/src/condenser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use powdr_ast::{
visitor::{AllChildren, ExpressionVisitable},
ArrayLiteral, BinaryOperation, BlockExpression, FunctionCall, FunctionKind,
LambdaExpression, LetStatementInsideBlock, Number, Pattern, SourceReference,
TypedExpression, UnaryOperation,
TraitImplementation, TypedExpression, UnaryOperation,
},
};
use powdr_number::{BigUint, FieldElement};
Expand All @@ -48,6 +48,7 @@ pub fn condense<T: FieldElement>(
solved_impls: HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>,
public_declarations: HashMap<String, PublicDeclaration>,
proof_items: &[Expression],
trait_impls: Vec<TraitImplementation<Expression>>,
source_order: Vec<StatementIdentifier>,
auto_added_symbols: HashSet<String>,
) -> Analyzed<T> {
Expand Down Expand Up @@ -194,6 +195,7 @@ pub fn condense<T: FieldElement>(
intermediate_columns,
identities: condensed_identities,
prover_functions,
trait_impls,
source_order,
auto_added_symbols,
}
Expand Down
90 changes: 44 additions & 46 deletions pil-analyzer/src/pil_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ struct PILAnalyzer {
symbol_counters: Option<Counters>,
/// Symbols from the core that were added automatically but will not be printed.
auto_added_symbols: HashSet<String>,
/// All trait implementations found, organized according to their associated trait name.
/// If a trait has no implementations, it is still present.
trait_impls: HashMap<String, Vec<TraitImplementation<Expression>>>,
/// All trait implementations found, in source order.
trait_impls: Vec<TraitImplementation<Expression>>,
}

/// Reads and parses the given path and all its imports.
Expand Down Expand Up @@ -230,15 +229,12 @@ impl PILAnalyzer {
.unwrap_or_else(|err| errors.push(err));
}

for v in self.trait_impls.values() {
for impl_ in v {
impl_
.children()
.try_for_each(|e| {
side_effect_checker::check(&self.definitions, FunctionKind::Pure, e)
})
.unwrap_or_else(|err| errors.push(err));
}
for i in &self.trait_impls {
i.children()
.try_for_each(|e| {
side_effect_checker::check(&self.definitions, FunctionKind::Pure, e)
})
.unwrap_or_else(|err| errors.push(err));
}

// for all proof items, check that they call pure or constr functions
Expand All @@ -261,30 +257,28 @@ impl PILAnalyzer {
// by the statement processor already).
// For Arrays, we also collect the inner expressions and expect them to be field elements.

for (name, trait_impls) in self.trait_impls.iter_mut() {
for trait_impl in self.trait_impls.iter_mut() {
let (_, def) = self
.definitions
.get(name)
.get(&trait_impl.name.to_string())
.expect("Trait definition not found");

let Some(FunctionValueDefinition::TraitDeclaration(trait_decl)) = def else {
unreachable!();
};
for impl_ in trait_impls {
let specialized_types: Vec<_> = impl_
.functions
.iter()
.map(|named_expr| impl_.type_of_function(trait_decl, &named_expr.name))
.collect();

for (named_expr, specialized_type) in
impl_.functions.iter_mut().zip(specialized_types)
{
expressions.push((
Arc::get_mut(&mut named_expr.body).unwrap(),
specialized_type.into(),
));
}

let specialized_types: Vec<_> = trait_impl
.functions
.iter()
.map(|named_expr| trait_impl.type_of_function(trait_decl, &named_expr.name))
.collect();

for (named_expr, specialized_type) in
trait_impl.functions.iter_mut().zip(specialized_types)
{
expressions.push((
Arc::get_mut(&mut named_expr.body).unwrap(),
specialized_type.into(),
));
}
}

Expand Down Expand Up @@ -357,7 +351,18 @@ impl PILAnalyzer {
/// Creates and returns a map for every referenced trait function with concrete type to the
/// corresponding trait implementation function.
fn resolve_trait_impls(&mut self) -> Result<SolvedTraitImpls, Vec<Error>> {
let mut trait_solver = TraitsResolver::new(&self.trait_impls);
let all_traits = self
.definitions
.iter()
.filter_map(|(name, (_, value))| {
if let Some(FunctionValueDefinition::TraitDeclaration(..)) = value {
Some(name.as_str())
} else {
None
}
})
.collect();
let mut trait_solver = TraitsResolver::new(all_traits, &self.trait_impls);

// TODO building this impl map should be different from checking that all trait references
// have an implementation.
Expand All @@ -378,10 +383,7 @@ impl PILAnalyzer {
})
.flat_map(|d| d.children());
let proof_items = self.proof_items.iter();
let trait_impls = self
.trait_impls
.values()
.flat_map(|impls| impls.iter().flat_map(|i| i.children()));
let trait_impls = self.trait_impls.iter().flat_map(|i| i.children());
let mut errors = vec![];
for expr in definitions
.chain(proof_items)
Expand Down Expand Up @@ -413,6 +415,7 @@ impl PILAnalyzer {
solved_impls,
self.public_declarations,
&self.proof_items,
self.trait_impls,
self.source_order,
self.auto_added_symbols,
))
Expand Down Expand Up @@ -472,12 +475,6 @@ impl PILAnalyzer {
match item {
PILItem::Definition(symbol, value) => {
let name = symbol.absolute_name.clone();
if matches!(value, Some(FunctionValueDefinition::TraitDeclaration(_))) {
// Ensure that `trait_impls` has an entry for every trait,
// even if it has no implementations.
// We use this to distinguish generic functions from trait functions.
self.trait_impls.entry(name.clone()).or_default();
}
let is_new = self
.definitions
.insert(name.clone(), (symbol, value))
Expand All @@ -498,11 +495,12 @@ impl PILAnalyzer {
.push(StatementIdentifier::ProofItem(index));
self.proof_items.push(item)
}
PILItem::TraitImplementation(trait_impl) => self
.trait_impls
.entry(trait_impl.name.to_string())
.or_default()
.push(trait_impl),
PILItem::TraitImplementation(trait_impl) => {
let index = self.trait_impls.len();
self.source_order
.push(StatementIdentifier::TraitImplementation(index));
self.trait_impls.push(trait_impl)
}
}
}
}
Expand Down
34 changes: 28 additions & 6 deletions pil-analyzer/src/traits_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use powdr_ast::{
TraitImplementation,
},
};
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};

use crate::type_unifier::Unifier;

Expand All @@ -15,18 +18,31 @@ pub type SolvedTraitImpls = HashMap<String, HashMap<Vec<Type>, Arc<Expression>>>
/// TraitsResolver helps to find the implementation for a given trait function
/// and concrete type arguments.
pub struct TraitsResolver<'a> {
/// All trait names, even if they have no implementation.
traits: HashSet<&'a str>,
/// List of implementations for all traits.
trait_impls: &'a HashMap<String, Vec<TraitImplementation<Expression>>>,
trait_impls: HashMap<String, Vec<&'a TraitImplementation<Expression>>>,
/// Map from trait function names and type arguments to the corresponding trait implementations.
solved_impls: SolvedTraitImpls,
}

impl<'a> TraitsResolver<'a> {
/// Creates a new instance of the resolver.
/// The trait impls need to have a key for every trait name, even if it is not implemented at all.
pub fn new(trait_impls: &'a HashMap<String, Vec<TraitImplementation<Expression>>>) -> Self {
pub fn new(
traits: HashSet<&'a str>,
trait_impls: &'a [TraitImplementation<Expression>],
) -> Self {
let mut impls_by_trait: HashMap<String, Vec<_>> = HashMap::new();
for i in trait_impls {
impls_by_trait
.entry(i.name.to_string())
.or_default()
.push(i);
}
Self {
trait_impls,
traits,
trait_impls: impls_by_trait,
solved_impls: HashMap::new(),
}
}
Expand All @@ -52,8 +68,14 @@ impl<'a> TraitsResolver<'a> {
let Some((trait_decl_name, trait_fn_name)) = reference.name.rsplit_once("::") else {
return Ok(());
};
let Some(trait_impls) = self.trait_impls.get(trait_decl_name) else {
if !self.traits.contains(trait_decl_name) {
// Not a trait function.
return Ok(());
}
let Some(trait_impls) = self.trait_impls.get(trait_decl_name) else {
return Err(format!(
"Could not find an implementation for the trait function {reference}"
));
};

match find_trait_implementation(trait_fn_name, type_args, trait_impls) {
Expand All @@ -80,7 +102,7 @@ impl<'a> TraitsResolver<'a> {
fn find_trait_implementation(
function: &str,
type_args: &[Type],
implementations: &[TraitImplementation<Expression>],
implementations: &[&TraitImplementation<Expression>],
) -> Option<Arc<Expression>> {
let tuple_args = Type::Tuple(TupleType {
items: type_args.to_vec(),
Expand Down
35 changes: 35 additions & 0 deletions pil-analyzer/tests/parse_display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,38 @@ fn typed_literals() {
let analyzed = analyze_string(input);
assert_eq!(analyzed.to_string(), expected);
}

#[test]
fn traits_and_impls() {
let input = "
trait X<T> {
f: -> T,
g: T -> T,
}
impl X<int> {
f: || 1,
g: |x| x + 1,
}
impl X<()> {
f: || (),
g: |()| (),
}
let a: int = X::f();
";
let expected = r#" trait X<T> {
f: -> T,
g: T -> T,
}
impl X<int> {
f: || 1_int,
g: |x| x + 1_int,
}
impl X<()> {
f: || (),
g: |()| (),
}
let a: int = X::f::<int>();
"#;
let analyzed = analyze_string(input);
assert_eq!(analyzed.to_string(), expected);
}

0 comments on commit 810b9ef

Please sign in to comment.