From bb88217a20ad3a3881af667daa82146346adb8c4 Mon Sep 17 00:00:00 2001 From: aromaa Date: Wed, 3 Jan 2024 20:59:48 +0200 Subject: [PATCH] More accurate capturing --- src/FunctionalInterfaces.Fody/ModuleWeaver.cs | 4 +- .../FunctionalInterfacesGenerator.cs | 52 +++++++++++++++++-- .../ActionTests.cs | 24 ++++----- 3 files changed, 63 insertions(+), 17 deletions(-) diff --git a/src/FunctionalInterfaces.Fody/ModuleWeaver.cs b/src/FunctionalInterfaces.Fody/ModuleWeaver.cs index c969fc8..4a899f1 100644 --- a/src/FunctionalInterfaces.Fody/ModuleWeaver.cs +++ b/src/FunctionalInterfaces.Fody/ModuleWeaver.cs @@ -24,9 +24,9 @@ private void ProcessType(TypeDefinition type) //TODO: Attribute? - string identifierName = method.Name + "_FunctionalInterface_"; + string identifierName = method.Name + "_FunctionalInterface"; - MethodDefinition? functionalInterface = type.Methods.FirstOrDefault(m => m.Name.StartsWith(identifierName)); + MethodDefinition? functionalInterface = type.Methods.FirstOrDefault(m => m.Name == identifierName); if (functionalInterface is not null) { method.Body = functionalInterface.Body; diff --git a/src/FunctionalInterfaces.SourceGenerator/FunctionalInterfacesGenerator.cs b/src/FunctionalInterfaces.SourceGenerator/FunctionalInterfacesGenerator.cs index eb7586d..647667b 100644 --- a/src/FunctionalInterfaces.SourceGenerator/FunctionalInterfacesGenerator.cs +++ b/src/FunctionalInterfaces.SourceGenerator/FunctionalInterfacesGenerator.cs @@ -226,7 +226,14 @@ static List GetHierarchy(SyntaxNode? node) } else if (symbol is IParameterSymbol param) { - writer.WriteLine($"public {param.Type} _{param.Name};"); + if (param.Name == "this") + { + writer.WriteLine($"public {param.Type} _{param.Name};"); + } + else + { + writer.WriteLine($"public {param.Type} {param.Name};"); + } } } } @@ -316,7 +323,9 @@ SyntaxNode TransformCall(SyntaxNode original, InvocationExpressionSyntax invoke, IMethodSymbol? methodSymbolInfo = semanticModel.GetDeclaredSymbol(declaringMethod); if (methodSymbolInfo is not null) { - SyntaxNode declaringMethodBody = declaringMethod.Body!.ReplaceNodes(declaringMethod.Body!.DescendantNodes(), (original, modified) => + SyntaxNode declaringMethodBody = declaringMethod.Body!; + + declaringMethodBody = declaringMethodBody.ReplaceNodes(declaringMethodBody.DescendantNodes(), (original, modified) => { if (modified is InvocationExpressionSyntax inv) { @@ -357,8 +366,45 @@ SyntaxNode TransformCall(SyntaxNode original, InvocationExpressionSyntax invoke, return modified; }); + foreach (SyntaxNode descendantNode in declaringMethodBody.DescendantNodes()) + { + if (descendantNode is InvocationExpressionSyntax inv && inv.ArgumentList.Arguments + .Any(a => a.Expression is InvocationExpressionSyntax { ArgumentList.Arguments.Count: 1 } exp + && exp.ArgumentList.Arguments[0].Expression is IdentifierNameSyntax { Identifier.Text: "__functionalInterface" })) + { + if (dataFlowAnalysis is not null && inv.Parent is ExpressionStatementSyntax expression) + { + List initVariables = new(); + foreach (ISymbol symbol in dataFlowAnalysis.DataFlowsIn) + { + foreach (SyntaxReference reference in symbol.DeclaringSyntaxReferences) + { + SyntaxNode declaration = reference.GetSyntax(); + if (declaration is VariableDeclaratorSyntax) + { + continue; + } + + initVariables.Add(SyntaxFactory.ExpressionStatement( + SyntaxFactory.AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + SyntaxFactory.MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + SyntaxFactory.IdentifierName("__functionalInterface"), + SyntaxFactory.IdentifierName(symbol.Name)), + SyntaxFactory.IdentifierName(symbol.Name)))); + } + } + + declaringMethodBody = declaringMethodBody.InsertNodesBefore(expression, initVariables); + + break; + } + } + } + writer.WriteLine(); - writer.WriteLine($"public {(methodSymbolInfo.IsStatic ? "static " : string.Empty)}{methodSymbolInfo.ReturnType} {methodSymbolInfo.Name}_FunctionalInterface_{lambda.Span.Start}({string.Join(", ", methodSymbolInfo.Parameters)})"); + writer.WriteLine($"private {(methodSymbolInfo.IsStatic ? "static " : string.Empty)}{methodSymbolInfo.ReturnType} {methodSymbolInfo.Name}_FunctionalInterface({string.Join(", ", methodSymbolInfo.Parameters)})"); writer.WriteLine("{"); writer.Indent++; diff --git a/src/FunctionalInterfaces.TestAssembly/ActionTests.cs b/src/FunctionalInterfaces.TestAssembly/ActionTests.cs index 5dbb97e..7c83520 100644 --- a/src/FunctionalInterfaces.TestAssembly/ActionTests.cs +++ b/src/FunctionalInterfaces.TestAssembly/ActionTests.cs @@ -28,20 +28,20 @@ public static void CallActionWithCapturedIntReferenceOutside() }); } - public static void CallActionWithCapturedIntTwoTimes() - { - int param = 50; + ////public static void CallActionWithCapturedIntTwoTimes() + ////{ + //// int param = 50; - ActionTests.Invoke(() => - { - Assert.Equal(50, param); - }); + //// ActionTests.Invoke(() => + //// { + //// Assert.Equal(50, param); + //// }); - ActionTests.Invoke(() => - { - Assert.Equal(50, param); - }); - } + //// ActionTests.Invoke(() => + //// { + //// Assert.Equal(50, param); + //// }); + ////} public static void CallVirtualActionWithCapturedInt() {