Skip to content

Commit

Permalink
More accurate capturing
Browse files Browse the repository at this point in the history
  • Loading branch information
aromaa committed Jan 3, 2024
1 parent 48139db commit bb88217
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/FunctionalInterfaces.Fody/ModuleWeaver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ private void ProcessType(TypeDefinition type)

//TODO: Attribute?

string identifierName = method.Name + "_FunctionalInterface_";
string identifierName = method.Name + "_FunctionalInterface";

MethodDefinition? functionalInterface = type.Methods.FirstOrDefault(m => m.Name.StartsWith(identifierName));
MethodDefinition? functionalInterface = type.Methods.FirstOrDefault(m => m.Name == identifierName);
if (functionalInterface is not null)
{
method.Body = functionalInterface.Body;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,14 @@ static List<ClassDeclarationSyntax> GetHierarchy(SyntaxNode? node)
}
else if (symbol is IParameterSymbol param)
{
writer.WriteLine($"public {param.Type} _{param.Name};");
if (param.Name == "this")
{
writer.WriteLine($"public {param.Type} _{param.Name};");
}
else
{
writer.WriteLine($"public {param.Type} {param.Name};");
}
}
}
}
Expand Down Expand Up @@ -316,7 +323,9 @@ SyntaxNode TransformCall(SyntaxNode original, InvocationExpressionSyntax invoke,
IMethodSymbol? methodSymbolInfo = semanticModel.GetDeclaredSymbol(declaringMethod);
if (methodSymbolInfo is not null)
{
SyntaxNode declaringMethodBody = declaringMethod.Body!.ReplaceNodes(declaringMethod.Body!.DescendantNodes(), (original, modified) =>
SyntaxNode declaringMethodBody = declaringMethod.Body!;

declaringMethodBody = declaringMethodBody.ReplaceNodes(declaringMethodBody.DescendantNodes(), (original, modified) =>
{
if (modified is InvocationExpressionSyntax inv)
{
Expand Down Expand Up @@ -357,8 +366,45 @@ SyntaxNode TransformCall(SyntaxNode original, InvocationExpressionSyntax invoke,
return modified;
});

foreach (SyntaxNode descendantNode in declaringMethodBody.DescendantNodes())
{
if (descendantNode is InvocationExpressionSyntax inv && inv.ArgumentList.Arguments
.Any(a => a.Expression is InvocationExpressionSyntax { ArgumentList.Arguments.Count: 1 } exp
&& exp.ArgumentList.Arguments[0].Expression is IdentifierNameSyntax { Identifier.Text: "__functionalInterface" }))
{
if (dataFlowAnalysis is not null && inv.Parent is ExpressionStatementSyntax expression)
{
List<ExpressionStatementSyntax> initVariables = new();
foreach (ISymbol symbol in dataFlowAnalysis.DataFlowsIn)
{
foreach (SyntaxReference reference in symbol.DeclaringSyntaxReferences)
{
SyntaxNode declaration = reference.GetSyntax();
if (declaration is VariableDeclaratorSyntax)
{
continue;
}

initVariables.Add(SyntaxFactory.ExpressionStatement(
SyntaxFactory.AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("__functionalInterface"),
SyntaxFactory.IdentifierName(symbol.Name)),
SyntaxFactory.IdentifierName(symbol.Name))));
}
}

declaringMethodBody = declaringMethodBody.InsertNodesBefore(expression, initVariables);

break;
}
}
}

writer.WriteLine();
writer.WriteLine($"public {(methodSymbolInfo.IsStatic ? "static " : string.Empty)}{methodSymbolInfo.ReturnType} {methodSymbolInfo.Name}_FunctionalInterface_{lambda.Span.Start}({string.Join(", ", methodSymbolInfo.Parameters)})");
writer.WriteLine($"private {(methodSymbolInfo.IsStatic ? "static " : string.Empty)}{methodSymbolInfo.ReturnType} {methodSymbolInfo.Name}_FunctionalInterface({string.Join(", ", methodSymbolInfo.Parameters)})");
writer.WriteLine("{");
writer.Indent++;

Expand Down
24 changes: 12 additions & 12 deletions src/FunctionalInterfaces.TestAssembly/ActionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ public static void CallActionWithCapturedIntReferenceOutside()
});
}

public static void CallActionWithCapturedIntTwoTimes()
{
int param = 50;
////public static void CallActionWithCapturedIntTwoTimes()
////{
//// int param = 50;

ActionTests.Invoke(() =>
{
Assert.Equal(50, param);
});
//// ActionTests.Invoke(() =>
//// {
//// Assert.Equal(50, param);
//// });

ActionTests.Invoke(() =>
{
Assert.Equal(50, param);
});
}
//// ActionTests.Invoke(() =>
//// {
//// Assert.Equal(50, param);
//// });
////}

public static void CallVirtualActionWithCapturedInt()
{
Expand Down

0 comments on commit bb88217

Please sign in to comment.