Skip to content

Commit

Permalink
Merge pull request #107 from ZvonimirMatic/master
Browse files Browse the repository at this point in the history
Source generation optimization
  • Loading branch information
koenbeuk authored Jun 12, 2024
2 parents ffbcc8c + 6a31597 commit 811e23c
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 136 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace EntityFrameworkCore.Projectables.Generator;

public class MemberDeclarationSyntaxAndCompilationEqualityComparer : IEqualityComparer<(MemberDeclarationSyntax, Compilation)>
{
public bool Equals((MemberDeclarationSyntax, Compilation) x, (MemberDeclarationSyntax, Compilation) y)
{
return GetMemberDeclarationSyntaxAndCompilationName(x.Item1, x.Item2) == GetMemberDeclarationSyntaxAndCompilationName(y.Item1, y.Item2);
}

public int GetHashCode((MemberDeclarationSyntax, Compilation) obj)
{
return GetMemberDeclarationSyntaxAndCompilationName(obj.Item1, obj.Item2).GetHashCode();
}

public static string GetMemberDeclarationSyntaxAndCompilationName(MemberDeclarationSyntax memberDeclarationSyntax, Compilation compilation)
{
return $"{compilation.AssemblyName}:{MemberDeclarationSyntaxEqualityComparer.GetMemberDeclarationSyntaxName(memberDeclarationSyntax)}";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System.Text;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace EntityFrameworkCore.Projectables.Generator;

public class MemberDeclarationSyntaxEqualityComparer : IEqualityComparer<MemberDeclarationSyntax>
{
public bool Equals(MemberDeclarationSyntax x, MemberDeclarationSyntax y)
{
return GetMemberDeclarationSyntaxName(x) == GetMemberDeclarationSyntaxName(y);
}

public int GetHashCode(MemberDeclarationSyntax obj)
{
return GetMemberDeclarationSyntaxName(obj).GetHashCode();
}

public static string GetMemberDeclarationSyntaxName(MemberDeclarationSyntax memberDeclaration)
{
var sb = new StringBuilder();

// Get the member name
if (memberDeclaration is MethodDeclarationSyntax methodDeclaration)
{
sb.Append(methodDeclaration.Identifier.Text);
}
else if (memberDeclaration is PropertyDeclarationSyntax propertyDeclaration)
{
sb.Append(propertyDeclaration.Identifier.Text);
}
else if (memberDeclaration is FieldDeclarationSyntax fieldDeclaration)
{
sb.Append(string.Join(", ", fieldDeclaration.Declaration.Variables.Select(v => v.Identifier.Text)));
}

// Traverse up the tree to get containing type names
var parent = memberDeclaration.Parent;
while (parent != null)
{
switch (parent)
{
case NamespaceDeclarationSyntax namespaceDeclaration:
sb.Insert(0, namespaceDeclaration.Name + ".");
break;
case ClassDeclarationSyntax classDeclaration:
sb.Insert(0, classDeclaration.Identifier.Text + ".");
break;
case StructDeclarationSyntax structDeclaration:
sb.Insert(0, structDeclaration.Identifier.Text + ".");
break;
case InterfaceDeclarationSyntax interfaceDeclaration:
sb.Insert(0, interfaceDeclaration.Identifier.Text + ".");
break;
case EnumDeclarationSyntax enumDeclaration:
sb.Insert(0, enumDeclaration.Identifier.Text + ".");
break;
}
parent = parent.Parent;
}

return sb.ToString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace EntityFrameworkCore.Projectables.Generator
Expand Down Expand Up @@ -41,167 +33,127 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Do a simple filter for members
IncrementalValuesProvider<MemberDeclarationSyntax> memberDeclarations = context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => s is MemberDeclarationSyntax m && m.AttributeLists.Count > 0,
transform: static (c, _) => GetSemanticTargetForGeneration(c))
.Where(static m => m is not null)!; // filter out attributed enums that we don't care about
.ForAttributeWithMetadataName(
ProjectablesAttributeName,
predicate: static (s, _) => s is MemberDeclarationSyntax,
transform: static (c, _) => (MemberDeclarationSyntax)c.TargetNode)
.WithComparer(new MemberDeclarationSyntaxEqualityComparer());

// Combine the selected enums with the `Compilation`
IncrementalValueProvider<(Compilation, ImmutableArray<MemberDeclarationSyntax>)> compilationAndEnums
= context.CompilationProvider.Combine(memberDeclarations.Collect());
IncrementalValuesProvider<(MemberDeclarationSyntax, Compilation)> compilationAndMemberPairs = memberDeclarations
.Combine(context.CompilationProvider)
.WithComparer(new MemberDeclarationSyntaxAndCompilationEqualityComparer());

// Generate the source using the compilation and enums
context.RegisterImplementationSourceOutput(compilationAndEnums,
context.RegisterImplementationSourceOutput(compilationAndMemberPairs,
static (spc, source) => Execute(source.Item1, source.Item2, spc));
}

static MemberDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
static void Execute(MemberDeclarationSyntax member, Compilation compilation, SourceProductionContext context)
{
// we know the node is a MemberDeclarationSyntax
var memberDeclarationSyntax = (MemberDeclarationSyntax)context.Node;
var projectable = ProjectableInterpreter.GetDescriptor(compilation, member, context);

// loop through all the attributes on the method
foreach (var attributeListSyntax in memberDeclarationSyntax.AttributeLists)
if (projectable is null)
{
foreach (var attributeSyntax in attributeListSyntax.Attributes)
{
if (context.SemanticModel.GetSymbolInfo(attributeSyntax).Symbol is not IMethodSymbol attributeSymbol)
{
// weird, we couldn't get the symbol, ignore it
continue;
}

var attributeContainingTypeSymbol = attributeSymbol.ContainingType;
var fullName = attributeContainingTypeSymbol.ToDisplayString();

// Is the attribute the [Projcetable] attribute?
if (fullName == ProjectablesAttributeName)
{
// return the enum
return memberDeclarationSyntax;
}
}
}

// we didn't find the attribute we were looking for
return null;
}

static void Execute(Compilation compilation, ImmutableArray<MemberDeclarationSyntax> members, SourceProductionContext context)
{
if (members.IsDefaultOrEmpty)
{
// nothing to do yet
return;
}

var projectables = members
.Select(x => ProjectableInterpreter.GetDescriptor(compilation, x, context))
.Where(x => x is not null)
.Select(x => x!);

var resultBuilder = new StringBuilder();

foreach (var projectable in projectables)
if (projectable.MemberName is null)
{
if (projectable.MemberName is null)
{
throw new InvalidOperationException("Expected a memberName here");
}
throw new InvalidOperationException("Expected a memberName here");
}

var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
)
);
)
)
);

#nullable disable

var compilationUnit = CompilationUnit();
var compilationUnit = CompilationUnit();

foreach (var usingDirective in projectable.UsingDirectives)
{
compilationUnit = compilationUnit.AddUsings(usingDirective);
}
foreach (var usingDirective in projectable.UsingDirectives)
{
compilationUnit = compilationUnit.AddUsings(usingDirective);
}

if (projectable.ClassNamespace is not null)
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(projectable.ClassNamespace)
)
);
}
if (projectable.ClassNamespace is not null)
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(projectable.ClassNamespace)
)
);
}

compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
);
);


context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));
context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));

static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
{
var lambdaTypeArguments = TypeArgumentList(
SeparatedList(
// TODO: Document where clause
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);

static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
if (projectable.ReturnTypeName is not null)
{
var lambdaTypeArguments = TypeArgumentList(
SeparatedList(
// TODO: Document where clause
projectable.ParametersList?.Parameters.Where(p => p.Type is not null).Select(p => p.Type!)
)
);

if (projectable.ReturnTypeName is not null)
{
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
}

return lambdaTypeArguments;
lambdaTypeArguments = lambdaTypeArguments.AddArguments(ParseTypeName(projectable.ReturnTypeName));
}

return lambdaTypeArguments;
}
}
}
Expand Down

0 comments on commit 811e23c

Please sign in to comment.