diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 0cf14a5e3..0ef32c5f0 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -154,6 +154,9 @@ impl Display for Analyzed { StatementIdentifier::ProverFunction(i) => { writeln_indented(f, format!("{};", &self.prover_functions[*i]))?; } + StatementIdentifier::TraitImplementation(i) => { + writeln_indented(f, format!("{}", self.trait_impls[*i]))?; + } } } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 46bd15d81..36371d4f0 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -21,7 +21,7 @@ pub use crate::parsed::BinaryOperator; pub use crate::parsed::UnaryOperator; use crate::parsed::{ self, ArrayExpression, EnumDeclaration, EnumVariant, NamedType, TraitDeclaration, - TypeDeclaration, + TraitImplementation, TypeDeclaration, }; #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -29,10 +29,12 @@ pub enum StatementIdentifier { /// Either an intermediate column or a definition. Definition(String), PublicDeclaration(String), - /// Index into the vector of proof items. + /// Index into the vector of proof items / identities. ProofItem(usize), /// Index into the vector of prover functions. ProverFunction(usize), + /// Index into the vector of trait implementations. + TraitImplementation(usize), } #[derive(Debug, Clone, Default, Serialize, Deserialize, JsonSchema)] @@ -43,6 +45,7 @@ pub struct Analyzed { pub intermediate_columns: HashMap>)>, pub identities: Vec>>>, pub prover_functions: Vec, + pub trait_impls: Vec>, /// The order in which definitions and identities /// appear in the source. pub source_order: Vec, diff --git a/backend-utils/src/lib.rs b/backend-utils/src/lib.rs index 2b095e193..d311b663b 100644 --- a/backend-utils/src/lib.rs +++ b/backend-utils/src/lib.rs @@ -211,7 +211,8 @@ fn split_by_namespace( }, } } - StatementIdentifier::ProverFunction(_) => None, + StatementIdentifier::ProverFunction(_) + | StatementIdentifier::TraitImplementation(_) => None, }) // collect into a map .fold(Default::default(), |mut acc, (namespace, statement)| { diff --git a/backend/src/estark/json_exporter/expression_counter.rs b/backend/src/estark/json_exporter/expression_counter.rs index 659b38054..41bea68fb 100644 --- a/backend/src/estark/json_exporter/expression_counter.rs +++ b/backend/src/estark/json_exporter/expression_counter.rs @@ -29,7 +29,8 @@ pub fn compute_intermediate_expression_ids(analyzed: &Analyzed) -> HashMap analyzed.public_declarations[name].expression_count() } StatementIdentifier::ProofItem(id) => analyzed.identities[*id].expression_count(), - StatementIdentifier::ProverFunction(_) => 0, + StatementIdentifier::ProverFunction(_) + | StatementIdentifier::TraitImplementation(_) => 0, } } ids diff --git a/backend/src/estark/json_exporter/mod.rs b/backend/src/estark/json_exporter/mod.rs index 6c806b604..0bd378a0f 100644 --- a/backend/src/estark/json_exporter/mod.rs +++ b/backend/src/estark/json_exporter/mod.rs @@ -141,7 +141,8 @@ pub fn export(analyzed: &Analyzed) -> PIL { } } } - StatementIdentifier::ProverFunction(_) => {} + StatementIdentifier::ProverFunction(_) + | StatementIdentifier::TraitImplementation(_) => {} } } PIL { diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 3bbacb216..52a786cd3 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -27,7 +27,7 @@ use powdr_ast::{ visitor::{AllChildren, ExpressionVisitable}, ArrayLiteral, BinaryOperation, BlockExpression, FunctionCall, FunctionKind, LambdaExpression, LetStatementInsideBlock, Number, Pattern, SourceReference, - TypedExpression, UnaryOperation, + TraitImplementation, TypedExpression, UnaryOperation, }, }; use powdr_number::{BigUint, FieldElement}; @@ -48,6 +48,7 @@ pub fn condense( solved_impls: HashMap, Arc>>, public_declarations: HashMap, proof_items: &[Expression], + trait_impls: Vec>, source_order: Vec, auto_added_symbols: HashSet, ) -> Analyzed { @@ -194,6 +195,7 @@ pub fn condense( intermediate_columns, identities: condensed_identities, prover_functions, + trait_impls, source_order, auto_added_symbols, } diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index 442555e61..5cb0ed9f0 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -74,9 +74,8 @@ struct PILAnalyzer { symbol_counters: Option, /// Symbols from the core that were added automatically but will not be printed. auto_added_symbols: HashSet, - /// All trait implementations found, organized according to their associated trait name. - /// If a trait has no implementations, it is still present. - trait_impls: HashMap>>, + /// All trait implementations found, in source order. + trait_impls: Vec>, } /// Reads and parses the given path and all its imports. @@ -230,15 +229,12 @@ impl PILAnalyzer { .unwrap_or_else(|err| errors.push(err)); } - for v in self.trait_impls.values() { - for impl_ in v { - impl_ - .children() - .try_for_each(|e| { - side_effect_checker::check(&self.definitions, FunctionKind::Pure, e) - }) - .unwrap_or_else(|err| errors.push(err)); - } + for i in &self.trait_impls { + i.children() + .try_for_each(|e| { + side_effect_checker::check(&self.definitions, FunctionKind::Pure, e) + }) + .unwrap_or_else(|err| errors.push(err)); } // for all proof items, check that they call pure or constr functions @@ -261,30 +257,28 @@ impl PILAnalyzer { // by the statement processor already). // For Arrays, we also collect the inner expressions and expect them to be field elements. - for (name, trait_impls) in self.trait_impls.iter_mut() { + for trait_impl in self.trait_impls.iter_mut() { let (_, def) = self .definitions - .get(name) + .get(&trait_impl.name.to_string()) .expect("Trait definition not found"); - let Some(FunctionValueDefinition::TraitDeclaration(trait_decl)) = def else { unreachable!(); }; - for impl_ in trait_impls { - let specialized_types: Vec<_> = impl_ - .functions - .iter() - .map(|named_expr| impl_.type_of_function(trait_decl, &named_expr.name)) - .collect(); - - for (named_expr, specialized_type) in - impl_.functions.iter_mut().zip(specialized_types) - { - expressions.push(( - Arc::get_mut(&mut named_expr.body).unwrap(), - specialized_type.into(), - )); - } + + let specialized_types: Vec<_> = trait_impl + .functions + .iter() + .map(|named_expr| trait_impl.type_of_function(trait_decl, &named_expr.name)) + .collect(); + + for (named_expr, specialized_type) in + trait_impl.functions.iter_mut().zip(specialized_types) + { + expressions.push(( + Arc::get_mut(&mut named_expr.body).unwrap(), + specialized_type.into(), + )); } } @@ -357,7 +351,18 @@ impl PILAnalyzer { /// Creates and returns a map for every referenced trait function with concrete type to the /// corresponding trait implementation function. fn resolve_trait_impls(&mut self) -> Result> { - let mut trait_solver = TraitsResolver::new(&self.trait_impls); + let all_traits = self + .definitions + .iter() + .filter_map(|(name, (_, value))| { + if let Some(FunctionValueDefinition::TraitDeclaration(..)) = value { + Some(name.as_str()) + } else { + None + } + }) + .collect(); + let mut trait_solver = TraitsResolver::new(all_traits, &self.trait_impls); // TODO building this impl map should be different from checking that all trait references // have an implementation. @@ -378,10 +383,7 @@ impl PILAnalyzer { }) .flat_map(|d| d.children()); let proof_items = self.proof_items.iter(); - let trait_impls = self - .trait_impls - .values() - .flat_map(|impls| impls.iter().flat_map(|i| i.children())); + let trait_impls = self.trait_impls.iter().flat_map(|i| i.children()); let mut errors = vec![]; for expr in definitions .chain(proof_items) @@ -413,6 +415,7 @@ impl PILAnalyzer { solved_impls, self.public_declarations, &self.proof_items, + self.trait_impls, self.source_order, self.auto_added_symbols, )) @@ -472,12 +475,6 @@ impl PILAnalyzer { match item { PILItem::Definition(symbol, value) => { let name = symbol.absolute_name.clone(); - if matches!(value, Some(FunctionValueDefinition::TraitDeclaration(_))) { - // Ensure that `trait_impls` has an entry for every trait, - // even if it has no implementations. - // We use this to distinguish generic functions from trait functions. - self.trait_impls.entry(name.clone()).or_default(); - } let is_new = self .definitions .insert(name.clone(), (symbol, value)) @@ -498,11 +495,12 @@ impl PILAnalyzer { .push(StatementIdentifier::ProofItem(index)); self.proof_items.push(item) } - PILItem::TraitImplementation(trait_impl) => self - .trait_impls - .entry(trait_impl.name.to_string()) - .or_default() - .push(trait_impl), + PILItem::TraitImplementation(trait_impl) => { + let index = self.trait_impls.len(); + self.source_order + .push(StatementIdentifier::TraitImplementation(index)); + self.trait_impls.push(trait_impl) + } } } } diff --git a/pil-analyzer/src/traits_resolver.rs b/pil-analyzer/src/traits_resolver.rs index c37fa8fba..e8a64ec69 100644 --- a/pil-analyzer/src/traits_resolver.rs +++ b/pil-analyzer/src/traits_resolver.rs @@ -5,7 +5,10 @@ use powdr_ast::{ TraitImplementation, }, }; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use crate::type_unifier::Unifier; @@ -15,8 +18,10 @@ pub type SolvedTraitImpls = HashMap, Arc>> /// TraitsResolver helps to find the implementation for a given trait function /// and concrete type arguments. pub struct TraitsResolver<'a> { + /// All trait names, even if they have no implementation. + traits: HashSet<&'a str>, /// List of implementations for all traits. - trait_impls: &'a HashMap>>, + trait_impls: HashMap>>, /// Map from trait function names and type arguments to the corresponding trait implementations. solved_impls: SolvedTraitImpls, } @@ -24,9 +29,20 @@ pub struct TraitsResolver<'a> { impl<'a> TraitsResolver<'a> { /// Creates a new instance of the resolver. /// The trait impls need to have a key for every trait name, even if it is not implemented at all. - pub fn new(trait_impls: &'a HashMap>>) -> Self { + pub fn new( + traits: HashSet<&'a str>, + trait_impls: &'a [TraitImplementation], + ) -> Self { + let mut impls_by_trait: HashMap> = HashMap::new(); + for i in trait_impls { + impls_by_trait + .entry(i.name.to_string()) + .or_default() + .push(i); + } Self { - trait_impls, + traits, + trait_impls: impls_by_trait, solved_impls: HashMap::new(), } } @@ -52,8 +68,14 @@ impl<'a> TraitsResolver<'a> { let Some((trait_decl_name, trait_fn_name)) = reference.name.rsplit_once("::") else { return Ok(()); }; - let Some(trait_impls) = self.trait_impls.get(trait_decl_name) else { + if !self.traits.contains(trait_decl_name) { + // Not a trait function. return Ok(()); + } + let Some(trait_impls) = self.trait_impls.get(trait_decl_name) else { + return Err(format!( + "Could not find an implementation for the trait function {reference}" + )); }; match find_trait_implementation(trait_fn_name, type_args, trait_impls) { @@ -80,7 +102,7 @@ impl<'a> TraitsResolver<'a> { fn find_trait_implementation( function: &str, type_args: &[Type], - implementations: &[TraitImplementation], + implementations: &[&TraitImplementation], ) -> Option> { let tuple_args = Type::Tuple(TupleType { items: type_args.to_vec(), diff --git a/pil-analyzer/tests/parse_display.rs b/pil-analyzer/tests/parse_display.rs index 3b5fe6647..5f380805f 100644 --- a/pil-analyzer/tests/parse_display.rs +++ b/pil-analyzer/tests/parse_display.rs @@ -937,3 +937,38 @@ fn typed_literals() { let analyzed = analyze_string(input); assert_eq!(analyzed.to_string(), expected); } + +#[test] +fn traits_and_impls() { + let input = " + trait X { + f: -> T, + g: T -> T, + } + impl X { + f: || 1, + g: |x| x + 1, + } + impl X<()> { + f: || (), + g: |()| (), + } + let a: int = X::f(); + "; + let expected = r#" trait X { + f: -> T, + g: T -> T, + } + impl X { + f: || 1_int, + g: |x| x + 1_int, + } + impl X<()> { + f: || (), + g: |()| (), + } + let a: int = X::f::(); +"#; + let analyzed = analyze_string(input); + assert_eq!(analyzed.to_string(), expected); +}