Skip to content

Commit

Permalink
Improve detection of IStructId namespace from compilation
Browse files Browse the repository at this point in the history
The existing approach of using analyzer config options breaks if there are transitive project references involved, since the options will contain potentially the wrong namespace (the one for the project being built rather than the one where IStructId actually exists).

This changes the approach to looking up the types by their metadata name plus a codegen attribute which we assume users won't be using for their own code (even if they do happen to use `IStructId` for some other purpose).
  • Loading branch information
kzu committed Dec 22, 2024
1 parent a727008 commit 513689a
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 62 deletions.
17 changes: 11 additions & 6 deletions src/StructId.Analyzer/AnalysisExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public static CSharpParseOptions GetParseOptions(this Compilation compilation)
=> (CSharpParseOptions?)compilation.SyntaxTrees.FirstOrDefault()?.Options ??
CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest);

public static bool IsGeneratedByStructId(this ISymbol symbol)
=> symbol.GetAttributes().Any(a
=> a.AttributeClass?.Name == "GeneratedCodeAttribute" &&
a.ConstructorArguments.Select(c => c.Value).OfType<string>().Any(v => v == nameof(StructId)));

/// <summary>
/// Checks whether the <paramref name="this"/> type inherits or implements the
/// <paramref name="baseTypeOrInterface"/> type, even if it's a generic type.
Expand Down Expand Up @@ -62,12 +67,6 @@ @this is INamedTypeSymbol namedActual &&
return Is(@this.BaseType, baseTypeOrInterface, looseGenerics);
}

public static string GetStructIdNamespace(this AnalyzerConfigOptions options)
=> options.TryGetValue("build_property.StructIdNamespace", out var ns) && !string.IsNullOrEmpty(ns) ? ns : "StructId";

public static IncrementalValueProvider<string> GetStructIdNamespace(this IncrementalValueProvider<AnalyzerConfigOptionsProvider> options)
=> options.Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId");

public static bool ImplementsExplicitly(this INamedTypeSymbol namedTypeSymbol, INamedTypeSymbol interfaceTypeSymbol)
{
if (interfaceTypeSymbol.IsUnboundGenericType && interfaceTypeSymbol.TypeParameters.Length == 1)
Expand Down Expand Up @@ -156,13 +155,19 @@ public static string ToFileName(this ITypeSymbol type)

public static bool IsStructId(this ITypeSymbol type) => type.AllInterfaces.Any(x => x.Name == "IStructId");

public static bool IsValueTemplate(this INamedTypeSymbol symbol)
=> symbol.GetAttributes().Any(IsValueTemplate);

public static bool IsValueTemplate(this AttributeData attribute)
=> attribute.AttributeClass?.Name == "TValue" ||
attribute.AttributeClass?.Name == "TValueAttribute";

public static bool IsValueTemplate(this AttributeSyntax attribute)
=> attribute.Name.ToString() == "TValue" || attribute.Name.ToString() == "TValueAttribute";

public static bool IsStructIdTemplate(this INamedTypeSymbol symbol)
=> symbol.GetAttributes().Any(IsStructIdTemplate);

public static bool IsStructIdTemplate(this AttributeData attribute)
=> attribute.AttributeClass?.Name == "TStructId" ||
attribute.AttributeClass?.Name == "TStructIdAttribute";
Expand Down
5 changes: 1 addition & 4 deletions src/StructId.Analyzer/BaseGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,8 @@ protected record struct TemplateArgs(INamedTypeSymbol TSelf, INamedTypeSymbol TV

public virtual void Initialize(IncrementalGeneratorInitializationContext context)
{
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();

var known = context.CompilationProvider
.Combine(structIdNamespace)
.Select((x, _) => new KnownTypes(x.Left, x.Right));
.Select((x, _) => new KnownTypes(x));

// Locate the required type
var types = context.CompilationProvider
Expand Down
40 changes: 32 additions & 8 deletions src/StructId.Analyzer/CodeTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ public static string Apply(string template, string structIdType, string valueTyp

public static string Apply(string template, string valueType, bool normalizeWhitespace = false)
{
var applied = ApplyImpl(Parse(template), valueType);
var applied = ApplyValueImpl(Parse(template), valueType);

return normalizeWhitespace ?
applied.NormalizeWhitespace().ToFullString().Trim() :
applied.ToFullString().Trim();
}

public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyImpl(node, valueType.ToFullName());
public static SyntaxNode ApplyValue(this SyntaxNode node, INamedTypeSymbol valueType) => ApplyValueImpl(node, valueType.ToFullName());

public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId)
{
Expand All @@ -59,7 +59,7 @@ public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId)
return ApplyImpl(root, structId.Name, tid, targetNamespace, corens);
}

static SyntaxNode ApplyImpl(this SyntaxNode node, string valueType)
static SyntaxNode ApplyValueImpl(this SyntaxNode node, string valueType)
{
var root = node.SyntaxTree.GetCompilationUnitRoot();
if (root == null)
Expand Down Expand Up @@ -194,7 +194,7 @@ bool IsFileLocal(TypeDeclarationSyntax node) =>
!node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsValueTemplate()));
}

class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter
class TemplateRewriter(string tself, string tvalue) : CSharpSyntaxRewriter
{
public override SyntaxNode? VisitRecordDeclaration(RecordDeclarationSyntax node)
{
Expand Down Expand Up @@ -282,8 +282,20 @@ class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter
return IdentifierName(tself)
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
.WithTrailingTrivia(node.Identifier.TrailingTrivia);
else if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue")
return IdentifierName(tid)

if (node.Identifier.Text.StartsWith("TSelf_"))
return IdentifierName(node.Identifier.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_"))
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
.WithTrailingTrivia(node.Identifier.TrailingTrivia);

// TODO: remove TId as it's legacy
if (node.Identifier.Text == "TId" || node.Identifier.Text == "TValue")
return IdentifierName(tvalue)
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
.WithTrailingTrivia(node.Identifier.TrailingTrivia);

if (node.Identifier.Text.StartsWith("TValue_"))
return IdentifierName(node.Identifier.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_"))
.WithLeadingTrivia(node.Identifier.LeadingTrivia)
.WithTrailingTrivia(node.Identifier.TrailingTrivia);

Expand All @@ -297,8 +309,20 @@ public override SyntaxToken VisitToken(SyntaxToken token)
return Identifier(tself)
.WithLeadingTrivia(token.LeadingTrivia)
.WithTrailingTrivia(token.TrailingTrivia);
else if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue"))
return Identifier(tid)

if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TSelf_"))
return Identifier(token.Text.Replace("TSelf_", tvalue.Replace('.', '_') + "_"))
.WithLeadingTrivia(token.LeadingTrivia)
.WithTrailingTrivia(token.TrailingTrivia);

// TODO: remove TId as it's legacy
if (token.IsKind(SyntaxKind.IdentifierToken) && (token.Text == "TId" || token.Text == "TValue"))
return Identifier(tvalue)
.WithLeadingTrivia(token.LeadingTrivia)
.WithTrailingTrivia(token.TrailingTrivia);

if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text.StartsWith("TValue_"))
return Identifier(token.Text.Replace("TValue_", tvalue.Replace('.', '_') + "_"))
.WithLeadingTrivia(token.LeadingTrivia)
.WithTrailingTrivia(token.TrailingTrivia);

Expand Down
22 changes: 13 additions & 9 deletions src/StructId.Analyzer/KnownTypes.cs
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
using Microsoft.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis;

namespace StructId;

/// <summary>
/// Provides access to some common types and properties used in the compilation.
/// </summary>
/// <param name="Compilation">The compilation used to resolve the known types.</param>
/// <param name="StructIdNamespace">The namespace for StructId types.</param>
public record KnownTypes(Compilation Compilation, string StructIdNamespace)
public record KnownTypes(Compilation Compilation)
{
public string StructIdNamespace => IStructId?.ContainingNamespace.ToFullName() ?? "StructId";

/// <summary>
/// System.String
/// </summary>
public INamedTypeSymbol String { get; } = Compilation.GetTypeByMetadataName("System.String")!;

/// <summary>
/// StructId.IStructId
/// </summary>
public INamedTypeSymbol? IStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId");
public INamedTypeSymbol? IStructId { get; } = Compilation
.GetAllTypes(true)
.FirstOrDefault(x => x.MetadataName == "IStructId" && x.IsGeneratedByStructId());

/// <summary>
/// StructId.IStructId{T}
/// </summary>
public INamedTypeSymbol? IStructIdT { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.IStructId`1");
/// <summary>
/// StructId.TStructIdAttribute
/// </summary>
public INamedTypeSymbol? TStructId { get; } = Compilation.GetTypeByMetadataName($"{StructIdNamespace}.TStructIdAttribute");
public INamedTypeSymbol? IStructIdT { get; } = Compilation
.GetAllTypes(true)
.FirstOrDefault(x => x.MetadataName == "IStructId`1" && x.IsGeneratedByStructId());
}
14 changes: 8 additions & 6 deletions src/StructId.Analyzer/NewtonsoftJsonGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ public override void Initialize(IncrementalGeneratorInitializationContext contex
{
base.Initialize(context);

var source = context.CompilationProvider
.Select((x, _) => (new KnownTypes(x), x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1")));

context.RegisterSourceOutput(
context.CompilationProvider
.Select((x, _) => x.GetTypeByMetadataName("Newtonsoft.Json.JsonConverter`1"))
.Combine(context.AnalyzerConfigOptionsProvider.GetStructIdNamespace()),
source,
(context, source) =>
{
if (source.Left == null)
(var known, var converter) = source;
if (converter == null)
return;

context.AddSource("NewtonsoftJsonConverter.cs", SourceText.From(
ThisAssembly.Resources.Templates.NewtonsoftJsonConverter_1.Text
.Replace("namespace StructId;", $"namespace {source.Right};")
.Replace("using StructId;", $"using {source.Right};"),
.Replace("namespace StructId;", $"namespace {known.StructIdNamespace};")
.Replace("using StructId;", $"using {known.StructIdNamespace};"),
Encoding.UTF8));
});
}
Expand Down
6 changes: 3 additions & 3 deletions src/StructId.Analyzer/RecordAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ public override void Initialize(AnalysisContext context)

static void Analyze(SyntaxNodeAnalysisContext context)
{
var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace();
var known = new KnownTypes(context.Compilation);

if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
context.Compilation.GetTypeByMetadataName($"{ns}.IStructId`1") is not { } structIdTypeOfT ||
context.Compilation.GetTypeByMetadataName($"{ns}.IStructId") is not { } structIdType)
known.IStructIdT is not { } structIdTypeOfT ||
known.IStructId is not { } structIdType)
return;

var symbol = context.SemanticModel.GetDeclaredSymbol(typeDeclaration);
Expand Down
2 changes: 0 additions & 2 deletions src/StructId.Analyzer/TemplateAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ public override void Initialize(AnalysisContext context)

static void Analyze(SyntaxNodeAnalysisContext context)
{
var ns = context.Options.AnalyzerConfigOptionsProvider.GlobalOptions.GetStructIdNamespace();

if (context.Node is not TypeDeclarationSyntax typeDeclaration ||
!typeDeclaration.AttributeLists.Any(list => list.Attributes.Any(attr => attr.IsStructIdTemplate())))
return;
Expand Down
30 changes: 14 additions & 16 deletions src/StructId.Analyzer/TemplatedGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,8 @@ public bool AppliesTo(INamedTypeSymbol valueType)

public void Initialize(IncrementalGeneratorInitializationContext context)
{
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();

var known = context.CompilationProvider
.Combine(structIdNamespace)
.Select((x, _) => new KnownTypes(x.Left, x.Right));
.Select((x, _) => new KnownTypes(x));

var templates = context.CompilationProvider
.SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType<INamedTypeSymbol>())
Expand All @@ -99,38 +96,39 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Combine(known)
.Select((x, cancellation) =>
{
var (structId, known) = x;
var (tself, known) = x;
// We infer the idType from the required primary constructor Value parameter type
var idType = (INamedTypeSymbol)structId.GetMembers().OfType<IPropertySymbol>().First(p => p.Name == "Value").Type;
var attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId));
var tvalue = (INamedTypeSymbol)tself.GetMembers().OfType<IPropertySymbol>().First(p => p.Name == "Value").Type;
var attribute = tself.GetAttributes().First(a => a.IsStructIdTemplate());

// The id type isn't declared in the same file, so we don't do anything fancy with it.
if (idType.DeclaringSyntaxReferences.Length == 0)
return new Template(structId, idType, attribute, known);
if (tvalue.DeclaringSyntaxReferences.Length == 0)
return new Template(tself, tvalue, attribute, known);

// Otherwise, the idType is a file-local type with a single interface
var type = idType.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax;
var type = tvalue.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax;
var iface = type?.BaseList?.Types.FirstOrDefault()?.Type;
if (type == null || iface == null)
return new Template(structId, idType, attribute, known) { OriginalTValue = idType };
return new Template(tself, tvalue, attribute, known) { OriginalTValue = tvalue };

if (x.Right.Compilation.GetSemanticModel(type.SyntaxTree).GetSymbolInfo(iface).Symbol is not INamedTypeSymbol ifaceType)
return new Template(structId, idType, attribute, known);
return new Template(tself, tvalue, attribute, known);

// if the interface is a generic type with a single type argument that is the same as the idType
// make it an unbound generic type. We'll bind it to the actual idType later at template render time.
if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(idType, SymbolEqualityComparer.Default))
if (ifaceType.IsGenericType && ifaceType.TypeArguments.Length == 1 && ifaceType.TypeArguments[0].Equals(tvalue, SymbolEqualityComparer.Default))
ifaceType = ifaceType.ConstructUnboundGenericType();

return new Template(structId, ifaceType, attribute, known)
return new Template(tself, ifaceType, attribute, known)
{
OriginalTValue = idType
OriginalTValue = tvalue
};
})
.Collect();

var ids = context.CompilationProvider
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>())
.SelectMany((x, _) => x.Assembly.GetAllTypes().OfType<INamedTypeSymbol>()
.Where(t => !t.IsValueTemplate() && !t.IsStructIdTemplate()))
.Where(x => x.IsRecord && x.IsValueType && x.IsPartial())
.Combine(known)
.Where(x => x.Left.Is(x.Right.IStructId) || x.Left.Is(x.Right.IStructIdT))
Expand Down
5 changes: 1 addition & 4 deletions src/StructId.Analyzer/TemplatizedTValueExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,8 @@ static class TemplatizedTValueExtensions
/// </summary>
public static IncrementalValuesProvider<TemplatizedTValue> SelectTemplatizedValues(this IncrementalGeneratorInitializationContext context)
{
var structIdNamespace = context.AnalyzerConfigOptionsProvider.GetStructIdNamespace();

var known = context.CompilationProvider
.Combine(structIdNamespace)
.Select((x, _) => new KnownTypes(x.Left, x.Right));
.Select((x, _) => new KnownTypes(x));

var templates = context.CompilationProvider
.SelectMany((x, _) => x.GetAllTypes(includeReferenced: true).OfType<INamedTypeSymbol>())
Expand Down
4 changes: 0 additions & 4 deletions src/StructId.Package/StructId.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
<Project>

<ItemGroup>
<CompilerVisibleProperty Include="StructIdNamespace" />
</ItemGroup>

</Project>
3 changes: 3 additions & 0 deletions src/StructId/IStructId.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// <auto-generated />

using System.CodeDom.Compiler;

namespace StructId;

/// <summary>
/// Interface for string-based identifiers.
/// </summary>
[GeneratedCode("StructId", default)]
public partial interface IStructId
{
/// <summary>
Expand Down
3 changes: 3 additions & 0 deletions src/StructId/IStructIdT.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
// <auto-generated />

using System.CodeDom.Compiler;

namespace StructId;

/// <summary>
/// Interface for struct-based identifiers.
/// </summary>
/// <typeparam name="TValue">The struct type for the inner <see cref="Value"/> of the identifier.</typeparam>
[GeneratedCode("StructId", default)]
public partial interface IStructId<TValue> where TValue : struct
{
/// <summary>
Expand Down

0 comments on commit 513689a

Please sign in to comment.