Skip to content

Commit

Permalink
- use infos already present GeneratorAttributeSyntaxContext to optimz…
Browse files Browse the repository at this point in the history
…e pipeline steps
  • Loading branch information
ax0l0tl committed Jul 30, 2024
1 parent 19965b2 commit c40fc8e
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 154 deletions.
17 changes: 2 additions & 15 deletions Source/FunicularSwitch.Generators/EnumTypeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,11 @@ static IEnumerable<EnumSymbolInfo> GetSymbolInfosForExtendEnumTypeAttribute(Attr
static (EnumCaseOrder caseOrder, ExtensionAccessibility visibility) GetAttributeNamedArguments(
AttributeData extendEnumTypesAttribute)
{
var caseOrder = GetEnumNamedArgument(extendEnumTypesAttribute, "CaseOrder", EnumCaseOrder.AsDeclared);
var visibility = GetEnumNamedArgument(extendEnumTypesAttribute, "Accessibility", ExtensionAccessibility.Public);
var caseOrder = extendEnumTypesAttribute.GetEnumNamedArgument("CaseOrder", EnumCaseOrder.AsDeclared);
var visibility = extendEnumTypesAttribute.GetEnumNamedArgument("Accessibility", ExtensionAccessibility.Public);
return (caseOrder, visibility);
}

static T GetEnumNamedArgument<T>(AttributeData attributeData, string name, T defaultValue) where T : struct
{
foreach (var kv in attributeData.NamedArguments)
{
if (kv.Key != name)
continue;

return (T)(object)((int)kv.Value.Value!);
}

return defaultValue;
}

static IEnumerable<EnumSymbolInfo> GetSymbolInfosForExtendEnumTypesAttribute(AttributeData extendEnumTypesAttribute)
{
var attributeSymbol = extendEnumTypesAttribute.AttributeClass!;
Expand Down
28 changes: 20 additions & 8 deletions Source/FunicularSwitch.Generators/GeneratorHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,29 @@ static class GeneratorHelper
return hasAttribute ? classDeclarationSyntax : null;
}

public static T GetNamedEnumAttributeArgument<T>(this AttributeSyntax attribute, string name, T defaultValue) where T : struct
public static T GetEnumNamedArgument<T>(this AttributeData attributeData, string name, T defaultValue) where T : struct
{
var memberAccess = attribute.ArgumentList?.Arguments
.Where(a => a.NameEquals?.Name.ToString() == name)
.Select(a => a.Expression)
.OfType<MemberAccessExpressionSyntax>()
.FirstOrDefault();
foreach (var kv in attributeData.NamedArguments)
{
if (kv.Key != name)
continue;

return (T)(object)((int)kv.Value.Value!);
}

return defaultValue;
}

public static T GetNamedArgument<T>(this AttributeData attributeData, string name, T defaultValue)
{
foreach (var kv in attributeData.NamedArguments)
{
if (kv.Key != name)
continue;

if (memberAccess == null) return defaultValue;
return (T)kv.Value.Value!;
}

return (T)Enum.Parse(typeof(T), memberAccess.Name.ToString());
return defaultValue;
}
}
10 changes: 6 additions & 4 deletions Source/FunicularSwitch.Generators/ResultType/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ static class Generator

public static IEnumerable<(string filename, string source)> Emit(
ResultTypeSchema resultTypeSchema,
SymbolWrapper<INamedTypeSymbol> defaultErrorType,
MergeMethod? mergeErrorMethod,
ExceptionToErrorMethod? exceptionToErrorMethod,
Action<Diagnostic> reportDiagnostic,
Expand All @@ -24,16 +25,17 @@ static class Generator
reportDiagnostic(Diagnostics.ResultTypeInGlobalNamespace($"Result type {resultTypeName} is placed in global namespace, this is not supported. Please put {resultTypeName} into non empty namespace.", resultTypeSchema.ResultTypeLocation?.ToLocation() ?? Location.None));
yield break;
}

var isValueType = resultTypeSchema.ErrorType.Symbol.IsValueType;
var errorTypeNamespace = resultTypeSchema.ErrorType.Symbol.GetFullNamespace();

var errorTypeSymbol = resultTypeSchema.ErrorType ?? defaultErrorType;
var isValueType = errorTypeSymbol.Symbol.IsValueType;
var errorTypeNamespace = errorTypeSymbol.Symbol.GetFullNamespace();

string Replace(string code, IReadOnlyCollection<string> additionalNamespaces, string genericTypeParameterNameForHandleExceptions)
{
code = code
.Replace($"namespace {TemplateNamespace}", $"namespace {resultTypeNamespace}")
.Replace(TemplateResultTypeName, resultTypeName)
.Replace(TemplateErrorTypeName, resultTypeSchema.ErrorType.Symbol.Name);
.Replace(TemplateErrorTypeName, errorTypeSymbol.Symbol.Name);

if (resultTypeSchema.IsInternal)
code = code
Expand Down
3 changes: 2 additions & 1 deletion Source/FunicularSwitch.Generators/ResultType/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ namespace FunicularSwitch.Generators.ResultType;

static class Parser
{
public static GenerationResult<ResultTypeSchema> GetResultTypeSchema(ClassDeclarationSyntax resultTypeClass, Compilation compilation, CancellationToken cancellationToken)
public static GenerationResult<ResultTypeSchema> GetResultTypeSchema(
ClassDeclarationSyntax resultTypeClass, Compilation compilation, CancellationToken cancellationToken)
{
var semanticModel = compilation.GetSemanticModel(resultTypeClass.SyntaxTree);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ namespace FunicularSwitch.Generators.ResultType;

sealed class ResultTypeSchema(
ClassDeclarationSyntax resultType,
INamedTypeSymbol errorType)
INamedTypeSymbol? errorType)
{
public SymbolWrapper<INamedTypeSymbol> ErrorType { get; } = new (errorType);
public SymbolWrapper<INamedTypeSymbol>? ErrorType { get; } = errorType != null ? new (errorType) : null;
public LocationInfo? ResultTypeLocation { get; } = LocationInfo.CreateFrom(resultType.GetLocation());
public bool IsInternal { get; } = !resultType.Modifiers.HasModifier(SyntaxKind.PublicKeyword);
public QualifiedTypeName ResultTypeName { get; } = resultType.QualifiedName();
public string? ResultTypeNamespace { get; } = resultType.GetContainingNamespace();

bool Equals(ResultTypeSchema other) => ErrorType.Equals(other.ErrorType) && IsInternal == other.IsInternal && ResultTypeName == other.ResultTypeName && ResultTypeNamespace == other.ResultTypeNamespace;
bool Equals(ResultTypeSchema other) => Equals(ErrorType, other.ErrorType) && IsInternal == other.IsInternal && ResultTypeName == other.ResultTypeName && ResultTypeNamespace == other.ResultTypeNamespace;

public override bool Equals(object? obj) => ReferenceEquals(this, obj) || obj is ResultTypeSchema other && Equals(other);

public override int GetHashCode()
{
unchecked
{
var hashCode = ErrorType.GetHashCode();
var hashCode = ErrorType?.GetHashCode() ?? 0;
hashCode = (hashCode * 397) ^ IsInternal.GetHashCode();
hashCode = (hashCode * 397) ^ ResultTypeName.GetHashCode();
hashCode = (hashCode * 397) ^ (ResultTypeNamespace != null ? ResultTypeNamespace.GetHashCode() : 0);
Expand Down
42 changes: 29 additions & 13 deletions Source/FunicularSwitch.Generators/ResultTypeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

var resultTypeClasses =
context.SyntaxProvider
.CreateSyntaxProvider(
predicate: static (s, _) => s.IsTypeDeclarationWithAttributes(),
.ForAttributeWithMetadataName(
ResultTypeAttribute,
predicate: static (s, _) => true,
transform: static (ctx, cancellationToken) =>
{
//TODO: support record result types one day
if (GeneratorHelper.GetSemanticTargetForGeneration(ctx, ResultTypeAttribute) is not ClassDeclarationSyntax resultTypeClass)
if (ctx.TargetSymbol is not INamedTypeSymbol n || n.IsRecord)
return GenerationResult<ResultTypeSchema>.Empty;

var schema = Parser.GetResultTypeSchema(resultTypeClass
, ctx.SemanticModel.Compilation, cancellationToken);
return schema;
var resultClass = (ClassDeclarationSyntax)ctx.TargetNode;
var errorTypeSymbol = (INamedTypeSymbol?)(!ctx.Attributes[0].NamedArguments.IsEmpty
? ctx.Attributes[0].NamedArguments[0].Value.Value!
: !ctx.Attributes[0].ConstructorArguments.IsEmpty
? ctx.Attributes[0].ConstructorArguments[0].Value
: null);

return new ResultTypeSchema(resultClass, errorTypeSymbol);
});

var compilationAndClasses = context.CompilationProvider
Expand All @@ -43,8 +49,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

return GenerationResult.Create(
(
mergeMetdhods: mergeMethods.Values.ToImmutableArray().AsEquatableArray(),
exceptionToErrorMethods: exceptionToErrorMethods.Values.ToImmutableArray().AsEquatableArray()
mergeMethods: mergeMethods.Values.ToImmutableArray().AsEquatableArray(),
exceptionToErrorMethods: exceptionToErrorMethods.Values.ToImmutableArray().AsEquatableArray(),
stringSymbol: SymbolWrapper.Create(compilation.GetTypeByMetadataName("System.String")!)
),
diagnostics.Select(d => new DiagnosticInfo(d)).ToImmutableArray(), true
);
Expand All @@ -57,7 +64,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
}

static void Execute(
GenerationResult<(EquatableArray<MergeMethod> mergeMethods, EquatableArray<ExceptionToErrorMethod> exceptionToErrorMethods)> resultTypeMethods,
GenerationResult<(EquatableArray<MergeMethod> mergeMethods, EquatableArray<ExceptionToErrorMethod> exceptionToErrorMethods, SymbolWrapper<INamedTypeSymbol> stringSymbol)> resultTypeMethods,
ImmutableArray<GenerationResult<ResultTypeSchema>> resultTypeClassesResult, SourceProductionContext context)
{
foreach (var diagnosticInfo in resultTypeClassesResult
Expand All @@ -73,10 +80,19 @@ static void Execute(
if (resultTypeSchemata.IsDefaultOrEmpty) return;

var generated = resultTypeSchemata
.SelectMany(r => Generator.Emit(r,
resultTypeMethods.Value.mergeMethods.FirstOrDefault(m => m.FullErrorTypeName == r.ErrorType.FullNameWithNamespace),
resultTypeMethods.Value.exceptionToErrorMethods.FirstOrDefault(e => e.ErrorTypeName == r.ErrorType.FullNameWithNamespace),
context.ReportDiagnostic, context.CancellationToken)).ToImmutableArray();
.SelectMany(r =>
{
var defaultErrorType = resultTypeMethods.Value.stringSymbol;
var errorTypeSymbol = r.ErrorType ?? defaultErrorType;

return Generator.Emit(r,
defaultErrorType,
resultTypeMethods.Value.mergeMethods.FirstOrDefault(m =>
m.FullErrorTypeName == errorTypeSymbol.FullNameWithNamespace),
resultTypeMethods.Value.exceptionToErrorMethods.FirstOrDefault(e =>
e.ErrorTypeName == errorTypeSymbol.FullNameWithNamespace),
context.ReportDiagnostic, context.CancellationToken);
}).ToImmutableArray();

foreach (var (filename, source) in generated) context.AddSource(filename, source);
}
Expand Down
Loading

0 comments on commit c40fc8e

Please sign in to comment.