diff --git a/Source/FunicularSwitch.Generators.Common/EquatableArray.cs b/Source/FunicularSwitch.Generators.Common/EquatableArray.cs
new file mode 100644
index 0000000..a80b624
--- /dev/null
+++ b/Source/FunicularSwitch.Generators.Common/EquatableArray.cs
@@ -0,0 +1,208 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+using System.Collections.Immutable;
+using System.Runtime.CompilerServices;
+// ReSharper disable NotDisposedResourceIsReturned
+
+// ReSharper disable once CheckNamespace
+namespace CommunityToolkit.Mvvm.SourceGenerators.Helpers;
+
+///
+/// Extensions for .
+///
+public static class EquatableArray
+{
+ ///
+ /// Creates an instance from a given .
+ ///
+ /// The type of items in the input array.
+ /// The input instance.
+ /// An instance from a given .
+ public static EquatableArray AsEquatableArray(this ImmutableArray array)
+ where T : IEquatable
+ {
+ return new(array);
+ }
+}
+
+///
+/// An immutable, equatable array. This is equivalent to but with value equality support.
+///
+/// The type of values in the array.
+public readonly struct EquatableArray : IEquatable>, IEnumerable
+ where T : IEquatable
+{
+ ///
+ /// The underlying array.
+ ///
+ private readonly T[]? _array;
+
+ public bool IsDefault => _array == null;
+
+ ///
+ /// Creates a new instance.
+ ///
+ /// The input to wrap.
+ public EquatableArray(ImmutableArray array)
+ {
+ this._array = Unsafe.As, T[]?>(ref array);
+ }
+
+ ///
+ /// Gets a reference to an item at a specified position within the array.
+ ///
+ /// The index of the item to retrieve a reference to.
+ /// A reference to an item at a specified position within the array.
+ public ref readonly T this[int index]
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => ref AsImmutableArray().ItemRef(index);
+ }
+
+ ///
+ /// Gets a value indicating whether the current array is empty.
+ ///
+ public bool IsEmpty
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => AsImmutableArray().IsEmpty;
+ }
+
+ public int Length
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => AsImmutableArray().Length;
+ }
+
+ ///
+ public bool Equals(EquatableArray array)
+ {
+ return AsSpan().SequenceEqual(array.AsSpan());
+ }
+
+ ///
+ public override bool Equals(object? obj)
+ {
+ return obj is EquatableArray array && Equals(this, array);
+ }
+
+ ///
+ public override int GetHashCode()
+ {
+ if (this._array is not { } array)
+ {
+ return 0;
+ }
+
+ HashCode hashCode = default;
+
+ foreach (T item in array)
+ {
+ hashCode.Add(item);
+ }
+
+ return hashCode.ToHashCode();
+ }
+
+ ///
+ /// Gets an instance from the current .
+ ///
+ /// The from the current .
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public ImmutableArray AsImmutableArray()
+ {
+ return Unsafe.As>(ref Unsafe.AsRef(in this._array));
+ }
+
+ ///
+ /// Creates an instance from a given .
+ ///
+ /// The input instance.
+ /// An instance from a given .
+ public static EquatableArray FromImmutableArray(ImmutableArray array)
+ {
+ return new(array);
+ }
+
+ ///
+ /// Returns a wrapping the current items.
+ ///
+ /// A wrapping the current items.
+ public ReadOnlySpan AsSpan()
+ {
+ return AsImmutableArray().AsSpan();
+ }
+
+ ///
+ /// Copies the contents of this instance to a mutable array.
+ ///
+ /// The newly instantiated array.
+ public T[] ToArray()
+ {
+ return AsImmutableArray().ToArray();
+ }
+
+ ///
+ /// Gets an value to traverse items in the current array.
+ ///
+ /// An value to traverse items in the current array.
+ public ImmutableArray.Enumerator GetEnumerator()
+ {
+ return AsImmutableArray().GetEnumerator();
+ }
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return ((IEnumerable)AsImmutableArray()).GetEnumerator();
+ }
+
+ ///
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return ((IEnumerable)AsImmutableArray()).GetEnumerator();
+ }
+
+ ///
+ /// Implicitly converts an to .
+ ///
+ /// An instance from a given .
+ public static implicit operator EquatableArray(ImmutableArray array)
+ {
+ return FromImmutableArray(array);
+ }
+
+ ///
+ /// Implicitly converts an to .
+ ///
+ /// An instance from a given .
+ public static implicit operator ImmutableArray(EquatableArray array)
+ {
+ return array.AsImmutableArray();
+ }
+
+ ///
+ /// Checks whether two values are the same.
+ ///
+ /// The first value.
+ /// The second value.
+ /// Whether and are equal.
+ public static bool operator ==(EquatableArray left, EquatableArray right)
+ {
+ return left.Equals(right);
+ }
+
+ ///
+ /// Checks whether two values are not the same.
+ ///
+ /// The first value.
+ /// The second value.
+ /// Whether and are not equal.
+ public static bool operator !=(EquatableArray left, EquatableArray right)
+ {
+ return !left.Equals(right);
+ }
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators.Common/FunicularSwitch.Generators.Common.projitems b/Source/FunicularSwitch.Generators.Common/FunicularSwitch.Generators.Common.projitems
index e3ad6a0..a997895 100644
--- a/Source/FunicularSwitch.Generators.Common/FunicularSwitch.Generators.Common.projitems
+++ b/Source/FunicularSwitch.Generators.Common/FunicularSwitch.Generators.Common.projitems
@@ -9,7 +9,11 @@
FunicularSwitch.Generators.Common
+
+
+
+
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators.Common/HashCode.cs b/Source/FunicularSwitch.Generators.Common/HashCode.cs
new file mode 100644
index 0000000..06dc383
--- /dev/null
+++ b/Source/FunicularSwitch.Generators.Common/HashCode.cs
@@ -0,0 +1,188 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.ComponentModel;
+using System.Runtime.CompilerServices;
+using System.Security.Cryptography;
+
+#pragma warning disable CS0809
+// ReSharper disable once CheckNamespace
+namespace System;
+
+///
+/// A polyfill type that mirrors some methods from on .NET 6.
+///
+internal struct HashCode
+{
+ private const uint Prime1 = 2654435761U;
+ private const uint Prime2 = 2246822519U;
+ private const uint Prime3 = 3266489917U;
+ private const uint Prime4 = 668265263U;
+ private const uint Prime5 = 374761393U;
+
+ private static readonly uint Seed = GenerateGlobalSeed();
+
+ private uint _v1, _v2, _v3, _v4;
+ private uint _queue1, _queue2, _queue3;
+ private uint _length;
+
+ ///
+ /// Initializes the default seed.
+ ///
+ /// A random seed.
+ private static uint GenerateGlobalSeed()
+ {
+ byte[] bytes = new byte[4];
+
+ using (RandomNumberGenerator generator = RandomNumberGenerator.Create())
+ {
+ generator.GetBytes(bytes);
+ }
+
+ return BitConverter.ToUInt32(bytes, 0);
+ }
+
+ ///
+ /// Adds a single value to the current hash.
+ ///
+ /// The type of the value to add into the hash code.
+ /// The value to add into the hash code.
+ public void Add(T value)
+ {
+ Add(value?.GetHashCode() ?? 0);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static void Initialize(out uint v1, out uint v2, out uint v3, out uint v4)
+ {
+ v1 = Seed + Prime1 + Prime2;
+ v2 = Seed + Prime2;
+ v3 = Seed;
+ v4 = Seed - Prime1;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint Round(uint hash, uint input)
+ {
+ return RotateLeft(hash + input * Prime2, 13) * Prime1;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint QueueRound(uint hash, uint queuedValue)
+ {
+ return RotateLeft(hash + queuedValue * Prime3, 17) * Prime4;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint MixState(uint v1, uint v2, uint v3, uint v4)
+ {
+ return RotateLeft(v1, 1) + RotateLeft(v2, 7) + RotateLeft(v3, 12) + RotateLeft(v4, 18);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint MixEmptyState()
+ {
+ return Seed + Prime5;
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint MixFinal(uint hash)
+ {
+ hash ^= hash >> 15;
+ hash *= Prime2;
+ hash ^= hash >> 13;
+ hash *= Prime3;
+ hash ^= hash >> 16;
+
+ return hash;
+ }
+
+ private void Add(int value)
+ {
+ uint val = (uint)value;
+ uint previousLength = this._length++;
+ uint position = previousLength % 4;
+
+ if (position == 0)
+ {
+ this._queue1 = val;
+ }
+ else if (position == 1)
+ {
+ this._queue2 = val;
+ }
+ else if (position == 2)
+ {
+ this._queue3 = val;
+ }
+ else
+ {
+ if (previousLength == 3)
+ {
+ Initialize(out this._v1, out this._v2, out this._v3, out this._v4);
+ }
+
+ this._v1 = Round(this._v1, this._queue1);
+ this._v2 = Round(this._v2, this._queue2);
+ this._v3 = Round(this._v3, this._queue3);
+ this._v4 = Round(this._v4, val);
+ }
+ }
+
+ ///
+ /// Gets the resulting hashcode from the current instance.
+ ///
+ /// The resulting hashcode from the current instance.
+ public int ToHashCode()
+ {
+ uint length = this._length;
+ uint position = length % 4;
+ uint hash = length < 4 ? MixEmptyState() : MixState(this._v1, this._v2, this._v3, this._v4);
+
+ hash += length * 4;
+
+ if (position > 0)
+ {
+ hash = QueueRound(hash, this._queue1);
+
+ if (position > 1)
+ {
+ hash = QueueRound(hash, this._queue2);
+
+ if (position > 2)
+ {
+ hash = QueueRound(hash, this._queue3);
+ }
+ }
+ }
+
+ hash = MixFinal(hash);
+
+ return (int)hash;
+ }
+
+ ///
+ [Obsolete("HashCode is a mutable struct and should not be compared with other HashCodes. Use ToHashCode to retrieve the computed hash code.", error: true)]
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public override int GetHashCode() => throw new NotSupportedException();
+
+ ///
+ [Obsolete("HashCode is a mutable struct and should not be compared with other HashCodes.", error: true)]
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public override bool Equals(object? obj) => throw new NotSupportedException();
+
+ ///
+ /// Rotates the specified value left by the specified number of bits.
+ /// Similar in behavior to the x86 instruction ROL.
+ ///
+ /// The value to rotate.
+ /// The number of bits to rotate by.
+ /// Any value outside the range [0..31] is treated as congruent mod 32.
+ /// The rotated value.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static uint RotateLeft(uint value, int offset)
+ {
+ return (value << offset) | (value >> (32 - offset));
+ }
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators.Common/Result.cs b/Source/FunicularSwitch.Generators.Common/Result.cs
new file mode 100644
index 0000000..fa02090
--- /dev/null
+++ b/Source/FunicularSwitch.Generators.Common/Result.cs
@@ -0,0 +1,75 @@
+using System.Collections.Immutable;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Text;
+
+namespace FunicularSwitch.Generators.Common;
+
+public readonly record struct GenerationResult(T? Value, EquatableArray Diagnostics, bool HasValue)
+{
+ public static readonly GenerationResult Empty = new(default, ImmutableArray.Empty, false);
+
+ public GenerationResult AddDiagnostics(DiagnosticInfo diagnosticInfo) =>
+ this with { Diagnostics = Diagnostics.AsImmutableArray().Add(diagnosticInfo) };
+
+ public GenerationResult SetValue(T value) =>
+ this with { Value = value, HasValue = true };
+
+ public static implicit operator GenerationResult(DiagnosticInfo diagnostic) => Empty.AddDiagnostics(diagnostic);
+
+ public static implicit operator GenerationResult(EquatableArray diagnostics) => new(default, diagnostics, false);
+
+ public static implicit operator GenerationResult(T value) => Empty.SetValue(value);
+
+ public GenerationResult Bind(Func> bind)
+ {
+ if (!HasValue)
+ return Diagnostics;
+
+ var newValue = bind(Value!);
+ return newValue with { Diagnostics = Diagnostics.AsImmutableArray().AddRange(newValue.Diagnostics) };
+ }
+
+ public GenerationResult Map(Func map)
+ {
+ var newValue = !HasValue
+ ? default
+ : map(Value!);
+ return new(newValue, Diagnostics, HasValue);
+ }
+}
+
+public sealed record DiagnosticInfo
+{
+ // Explicit constructor to convert Location into LocationInfo
+ public DiagnosticInfo(DiagnosticDescriptor descriptor, Location? location)
+ {
+ Descriptor = descriptor;
+ Location = location is not null ? LocationInfo.CreateFrom(location) : null;
+ }
+
+ public DiagnosticInfo(Diagnostic diagnostic)
+ {
+ Descriptor = diagnostic.Descriptor;
+ Location = LocationInfo.CreateFrom(diagnostic.Location);
+ }
+
+ public static implicit operator DiagnosticInfo(Diagnostic diagnostic) => new(diagnostic);
+
+ public DiagnosticDescriptor Descriptor { get; }
+ public LocationInfo? Location { get; }
+}
+
+public record LocationInfo(string FilePath, TextSpan TextSpan, LinePositionSpan LineSpan)
+{
+ public Location ToLocation()
+ => Location.Create(FilePath, TextSpan, LineSpan);
+
+ public static LocationInfo? CreateFrom(SyntaxNode node)
+ => CreateFrom(node.GetLocation());
+
+ public static LocationInfo? CreateFrom(Location location) =>
+ location.SourceTree is null
+ ? null
+ : new(location.SourceTree.FilePath, location.SourceSpan, location.GetLineSpan().Span);
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs
index f04d22b..dff41bb 100644
--- a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs
+++ b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -75,15 +76,15 @@ public static bool Implements(this INamedTypeSymbol symbol, ITypeSymbol interfac
return parentNamespaces.ToSeparatedString(".");
}
- static readonly SymbolDisplayFormat s_FullTypeWithNamespaceDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);
- static readonly SymbolDisplayFormat s_FullTypeDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes);
+ static readonly SymbolDisplayFormat FullTypeWithNamespaceDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);
+ static readonly SymbolDisplayFormat FullTypeDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes);
- public static string FullTypeNameWithNamespace(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(s_FullTypeWithNamespaceDisplayFormat);
- public static string FullTypeName(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(s_FullTypeDisplayFormat);
+ public static string FullTypeNameWithNamespace(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(FullTypeWithNamespaceDisplayFormat);
+ public static string FullTypeName(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(FullTypeDisplayFormat);
public static string FullNamespace(this INamespaceSymbol namespaceSymbol) =>
- namespaceSymbol.ToDisplayString(s_FullTypeDisplayFormat);
+ namespaceSymbol.ToDisplayString(FullTypeDisplayFormat);
public static QualifiedTypeName QualifiedName(this BaseTypeDeclarationSyntax dec)
{
@@ -126,6 +127,12 @@ public static bool HasModifier(this SyntaxTokenList tokens, SyntaxKind syntaxKin
return tokens.Any(t => t.Text == token);
}
+ public static bool HasModifier(this IEnumerable tokens, SyntaxKind syntaxKind)
+ {
+ var token = SyntaxFactory.Token(syntaxKind).Text;
+ return tokens.Contains(token);
+ }
+
public static string GetFullTypeName(this Compilation compilation, SyntaxNode typeSyntax)
{
var semanticModel = compilation.GetSemanticModel(typeSyntax.SyntaxTree);
@@ -166,25 +173,30 @@ public static IEnumerable GetAllTypes(this INamespaceOrTypeSym
}
}
- public static MemberInfo ToMemberInfo(this BaseMethodDeclarationSyntax member, string name, Compilation compilation) =>
- new(name,
- member.Modifiers,
- member.ParameterList.Parameters
- .Select(p =>
- ToParameterInfo(p, compilation))
- .ToImmutableList());
+ public static MemberInfo ToMemberInfo(this BaseMethodDeclarationSyntax member, string name, Compilation compilation)
+ {
+ var modifiers = member.Modifiers;
+ return new(name,
+ ToEquatableModifiers(modifiers),
+ member.ParameterList.Parameters
+ .Select(p =>
+ ToParameterInfo(p, compilation))
+ .ToImmutableArray());
+ }
+
+ public static ImmutableArray ToEquatableModifiers(this SyntaxTokenList modifiers) => modifiers.Select(m => m.Text).ToImmutableArray();
public static ParameterInfo ToParameterInfo(this ParameterSyntax p, Compilation compilation) =>
new(
p.Identifier.Text,
- p.Modifiers,
+ p.Modifiers.ToEquatableModifiers(),
compilation.GetSemanticModel(p.SyntaxTree).GetTypeInfo(p.Type!).Type!,
- p.Default);
+ p.Default?.ToString());
}
public sealed class QualifiedTypeName : IEquatable
{
- public static QualifiedTypeName NoParents(string name) => new QualifiedTypeName(name, Enumerable.Empty());
+ public static QualifiedTypeName NoParents(string name) => new(name, []);
readonly string m_FullName;
public ImmutableArray NestingParents { get; }
@@ -217,23 +229,58 @@ public bool Equals(QualifiedTypeName? other)
public static bool operator !=(QualifiedTypeName left, QualifiedTypeName right) => !Equals(left, right);
}
-public sealed record MemberInfo(string Name, SyntaxTokenList Modifiers, IReadOnlyCollection Parameters);
+public sealed record MemberInfo(string Name, EquatableArray Modifiers, EquatableArray Parameters);
-public sealed record ParameterInfo(string Name, SyntaxTokenList Modifiers, ITypeSymbol Type, EqualsValueClauseSyntax? DefaultClause)
+public sealed record ParameterInfo
{
- public override string ToString()
+ readonly string _typeName;
+
+ public string Name { get; }
+ public EquatableArray Modifiers { get; }
+ public ITypeSymbol Type { get; }
+ public string? DefaultClause { get; }
+
+ public ParameterInfo(string name, EquatableArray modifiers, ITypeSymbol type, string? defaultClause)
+ {
+ Name = name;
+ Modifiers = modifiers;
+ Type = type;
+ DefaultClause = defaultClause;
+
+ _typeName = Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ }
+
+ public override string ToString()
{
- IEnumerable Parts()
- {
- if (Modifiers.Count > 0)
- yield return Modifiers.ToSeparatedString(" ");
- yield return Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
- yield return Name;
- if (DefaultClause != null)
- yield return DefaultClause.ToString();
- }
-
-
- return Parts().ToSeparatedString(" ");
- }
+ return Parts().ToSeparatedString(" ");
+
+ IEnumerable Parts()
+ {
+ if (Modifiers.Length > 0)
+ yield return Modifiers.ToSeparatedString(" ");
+ yield return _typeName;
+ yield return Name;
+ if (DefaultClause != null)
+ yield return DefaultClause;
+ }
+ }
+
+ public bool Equals(ParameterInfo? other)
+ {
+ if (ReferenceEquals(null, other)) return false;
+ if (ReferenceEquals(this, other)) return true;
+ return _typeName == other._typeName && Name == other.Name && Modifiers.Equals(other.Modifiers) && DefaultClause == other.DefaultClause;
+ }
+
+ public override int GetHashCode()
+ {
+ unchecked
+ {
+ var hashCode = _typeName.GetHashCode();
+ hashCode = (hashCode * 397) ^ Name.GetHashCode();
+ hashCode = (hashCode * 397) ^ Modifiers.GetHashCode();
+ hashCode = (hashCode * 397) ^ (DefaultClause != null ? DefaultClause.GetHashCode() : 0);
+ return hashCode;
+ }
+ }
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators.Common/SourceProductionContextExtension.cs b/Source/FunicularSwitch.Generators.Common/SourceProductionContextExtension.cs
new file mode 100644
index 0000000..c4737e0
--- /dev/null
+++ b/Source/FunicularSwitch.Generators.Common/SourceProductionContextExtension.cs
@@ -0,0 +1,10 @@
+using Microsoft.CodeAnalysis;
+
+namespace FunicularSwitch.Generators.Common
+{
+ public static class SourceProductionContextExtension
+ {
+ public static void ReportDiagnostic(this SourceProductionContext context, DiagnosticInfo diagnostic) =>
+ context.ReportDiagnostic(Diagnostic.Create(diagnostic.Descriptor, diagnostic.Location?.ToLocation()));
+ }
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/EnumType/EnumTypeSchema.cs b/Source/FunicularSwitch.Generators/EnumType/EnumTypeSchema.cs
index 1451b1c..20200c8 100644
--- a/Source/FunicularSwitch.Generators/EnumType/EnumTypeSchema.cs
+++ b/Source/FunicularSwitch.Generators/EnumType/EnumTypeSchema.cs
@@ -1,35 +1,9 @@
-using FunicularSwitch.Generators.Generation;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
+using FunicularSwitch.Generators.Generation;
namespace FunicularSwitch.Generators.EnumType;
-public sealed record EnumTypeSchema(string? Namespace, string TypeName, string FullTypeName, IReadOnlyCollection Cases, bool IsInternal)
-{
- public string? Namespace { get; } = Namespace;
- public string FullTypeName { get; } = FullTypeName;
- public string TypeName { get; } = TypeName;
- public IReadOnlyCollection Cases { get; } = Cases;
- public bool IsInternal { get; } = IsInternal;
-
- public bool Equals(EnumTypeSchema? other)
- {
- if (ReferenceEquals(null, other)) return false;
- if (ReferenceEquals(this, other)) return true;
- return Namespace == other.Namespace && FullTypeName == other.FullTypeName && TypeName == other.TypeName && IsInternal == other.IsInternal && Cases.SequenceEqual(other.Cases);
- }
-
- public override int GetHashCode()
- {
- unchecked
- {
- var hashCode = (Namespace != null ? Namespace.GetHashCode() : 0);
- hashCode = (hashCode * 397) ^ FullTypeName.GetHashCode();
- hashCode = (hashCode * 397) ^ TypeName.GetHashCode();
- hashCode = (hashCode * 397) ^ Cases.GetHashCodeByItems();
- hashCode = (hashCode * 397) ^ IsInternal.GetHashCode();
- return hashCode;
- }
- }
-}
+sealed record EnumTypeSchema(string? Namespace, string TypeName, string FullTypeName, EquatableArray Cases, bool IsInternal, AttributePrecedence Precedence);
public sealed record EnumCase
{
@@ -41,20 +15,4 @@ public EnumCase(string fullCaseName, string caseName)
FullCaseName = fullCaseName;
ParameterName = (caseName.Any(c => c != '_') ? caseName.TrimEnd('_') : caseName).ToParameterName();
}
-}
-
-internal static class EnumerableExtension
-{
- internal static int GetHashCodeByItems(this IEnumerable lst)
- {
- unchecked
- {
- int hash = 19;
- foreach (var item in lst)
- {
- hash = hash * 31 + (item != null ? item.GetHashCode() : 1);
- }
- return hash;
- }
- }
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/EnumType/Generator.cs b/Source/FunicularSwitch.Generators/EnumType/Generator.cs
index 1ba1d79..6d632bf 100644
--- a/Source/FunicularSwitch.Generators/EnumType/Generator.cs
+++ b/Source/FunicularSwitch.Generators/EnumType/Generator.cs
@@ -4,7 +4,7 @@
namespace FunicularSwitch.Generators.EnumType;
-public static class Generator
+static class Generator
{
const string VoidMatchMethodName = "Switch";
const string MatchMethodName = "Match";
@@ -15,6 +15,7 @@ public static (string filename, string source) Emit(EnumTypeSchema enumTypeSchem
var builder = new CSharpBuilder();
builder.WriteLine("#pragma warning disable 1591");
+ //builder.WriteLine($"//Generator runs: {RunCount.Increase(enumTypeSchema.FullTypeName)}");
void BlankLine()
{
builder.WriteLine("");
diff --git a/Source/FunicularSwitch.Generators/EnumType/Parser.cs b/Source/FunicularSwitch.Generators/EnumType/Parser.cs
index 51defe3..26bf823 100644
--- a/Source/FunicularSwitch.Generators/EnumType/Parser.cs
+++ b/Source/FunicularSwitch.Generators/EnumType/Parser.cs
@@ -8,102 +8,104 @@ namespace FunicularSwitch.Generators.EnumType;
static class Parser
{
- public static EnumSymbolInfo? GetEnumSymbolInfo(EnumDeclarationSyntax enumTypeClass, AttributeSyntax attribute, SemanticModel semanticModel)
- {
- var enumTypeSymbol = semanticModel.GetDeclaredSymbol(enumTypeClass);
- if (enumTypeSymbol == null)
- return null;
-
- var (enumCaseOrder, visibility) = GetAttributeParameters(attribute);
-
- return new(SymbolWrapper.Create(enumTypeSymbol), visibility, enumCaseOrder, AttributePrecedence.High);
- }
-
- static IEnumerable OrderEnumCases(IEnumerable enumCases, EnumCaseOrder enumCaseOrder) =>
- (enumCaseOrder == EnumCaseOrder.AsDeclared
- ? enumCases
- : enumCases.OrderBy(m => m.FullCaseName));
-
- public static IEnumerable GetAccessibleEnumTypeSymbols(INamespaceSymbol @namespace, bool includeInternalEnums)
- {
- static IEnumerable GetTypes(INamespaceOrTypeSymbol namespaceSymbol)
- {
- foreach (var namedTypeSymbol in namespaceSymbol.GetTypeMembers())
- {
- yield return namedTypeSymbol;
- foreach (var typeSymbol in GetTypes(namedTypeSymbol))
- {
- yield return typeSymbol;
- }
- }
-
- if (namespaceSymbol is INamespaceSymbol ns)
- foreach (var subNamespace in ns.GetNamespaceMembers())
- {
- foreach (var namedTypeSymbol in GetTypes(subNamespace))
- {
- yield return namedTypeSymbol;
- }
- }
- }
-
- var enumTypes = GetTypes(@namespace)
- .Where(t => t.EnumUnderlyingType != null
- && IsAccessible(t, includeInternalEnums)
- );
-
- return enumTypes;
- }
-
- static bool IsAccessible(INamedTypeSymbol t, bool includeInternalEnums)
- {
- var actualAccessibility = t.GetActualAccessibility();
-
- return actualAccessibility == Accessibility.Public ||
- includeInternalEnums && actualAccessibility == Accessibility.Internal;
- }
-
- public static EnumTypeSchema ToEnumTypeSchema(this EnumSymbolInfo symbolInfo)
- {
- var enumSymbol = symbolInfo.EnumTypeSymbol.Symbol;
-
- var fullNamespace = enumSymbol.GetFullNamespace();
- var fullTypeNameWithNamespace = enumSymbol.FullTypeNameWithNamespace();
-
- var derivedTypes = enumSymbol.GetMembers()
- .OfType()
- .Select(f => new EnumCase($"{fullTypeNameWithNamespace}.{f.Name}", f.Name));
-
- var acc = symbolInfo.EnumTypeSymbol.Symbol.GetActualAccessibility();
- var extensionAccessibility = acc is Accessibility.NotApplicable or Accessibility.Internal
- ? ExtensionAccessibility.Internal
- : symbolInfo.ExtensionAccessibility;
-
- return new(fullNamespace, enumSymbol.FullTypeName(), fullTypeNameWithNamespace,
- OrderEnumCases(derivedTypes, symbolInfo.CaseOrder)
- .ToList(),
- extensionAccessibility == ExtensionAccessibility.Internal
- );
- }
-
- public static (EnumCaseOrder caseOrder, ExtensionAccessibility visibility) GetAttributeParameters(AttributeSyntax attribute)
- {
- var caseOrder = attribute.GetNamedEnumAttributeArgument("CaseOrder", EnumCaseOrder.AsDeclared);
- var visibility = attribute.GetNamedEnumAttributeArgument("Visibility", ExtensionAccessibility.Public);
- return (caseOrder, visibility);
- }
+ public static EnumSymbolInfo? GetEnumSymbolInfo(EnumDeclarationSyntax enumTypeClass, AttributeSyntax attribute, SemanticModel semanticModel)
+ {
+ var enumTypeSymbol = semanticModel.GetDeclaredSymbol(enumTypeClass);
+ if (enumTypeSymbol == null)
+ return null;
+
+ var (enumCaseOrder, visibility) = GetAttributeParameters(attribute);
+
+ return new(SymbolWrapper.Create(enumTypeSymbol), visibility, enumCaseOrder, AttributePrecedence.High);
+ }
+
+ static IEnumerable OrderEnumCases(IEnumerable enumCases, EnumCaseOrder enumCaseOrder) =>
+ (enumCaseOrder == EnumCaseOrder.AsDeclared
+ ? enumCases
+ : enumCases.OrderBy(m => m.FullCaseName));
+
+ public static IEnumerable GetAccessibleEnumTypeSymbols(INamespaceSymbol @namespace, bool includeInternalEnums)
+ {
+ static IEnumerable GetTypes(INamespaceOrTypeSymbol namespaceSymbol)
+ {
+ foreach (var namedTypeSymbol in namespaceSymbol.GetTypeMembers())
+ {
+ yield return namedTypeSymbol;
+ foreach (var typeSymbol in GetTypes(namedTypeSymbol))
+ {
+ yield return typeSymbol;
+ }
+ }
+
+ if (namespaceSymbol is INamespaceSymbol ns)
+ foreach (var subNamespace in ns.GetNamespaceMembers())
+ {
+ foreach (var namedTypeSymbol in GetTypes(subNamespace))
+ {
+ yield return namedTypeSymbol;
+ }
+ }
+ }
+
+ var enumTypes = GetTypes(@namespace)
+ .Where(t => t.EnumUnderlyingType != null
+ && IsAccessible(t, includeInternalEnums)
+ );
+
+ return enumTypes;
+ }
+
+ static bool IsAccessible(INamedTypeSymbol t, bool includeInternalEnums)
+ {
+ var actualAccessibility = t.GetActualAccessibility();
+
+ return actualAccessibility == Accessibility.Public ||
+ includeInternalEnums && actualAccessibility == Accessibility.Internal;
+ }
+
+ public static EnumTypeSchema ToEnumTypeSchema(this EnumSymbolInfo symbolInfo)
+ {
+ var enumSymbol = symbolInfo.EnumTypeSymbol.Symbol;
+
+ var fullNamespace = enumSymbol.GetFullNamespace();
+ var fullTypeNameWithNamespace = enumSymbol.FullTypeNameWithNamespace();
+
+ var derivedTypes = enumSymbol.GetMembers()
+ .OfType()
+ .Select(f => new EnumCase($"{fullTypeNameWithNamespace}.{f.Name}", f.Name));
+
+ var acc = enumSymbol.GetActualAccessibility();
+ var extensionAccessibility = acc is Accessibility.NotApplicable or Accessibility.Internal
+ ? ExtensionAccessibility.Internal
+ : symbolInfo.ExtensionAccessibility;
+
+ return new(fullNamespace,
+ enumSymbol.FullTypeName(),
+ fullTypeNameWithNamespace,
+ OrderEnumCases(derivedTypes, symbolInfo.CaseOrder).ToImmutableArray(),
+ extensionAccessibility == ExtensionAccessibility.Internal,
+ symbolInfo.Precedence
+ );
+ }
+
+ public static (EnumCaseOrder caseOrder, ExtensionAccessibility visibility) GetAttributeParameters(AttributeSyntax attribute)
+ {
+ var caseOrder = attribute.GetNamedEnumAttributeArgument("CaseOrder", EnumCaseOrder.AsDeclared);
+ var visibility = attribute.GetNamedEnumAttributeArgument("Visibility", ExtensionAccessibility.Public);
+ return (caseOrder, visibility);
+ }
}
-record EnumSymbolInfo(
- SymbolWrapper EnumTypeSymbol,
- ExtensionAccessibility ExtensionAccessibility,
- EnumCaseOrder CaseOrder,
- AttributePrecedence Precedence
+sealed record EnumSymbolInfo(
+ SymbolWrapper EnumTypeSymbol,
+ ExtensionAccessibility ExtensionAccessibility,
+ EnumCaseOrder CaseOrder,
+ AttributePrecedence Precedence
);
enum AttributePrecedence
{
- Low,
- Medium,
- High
+ Low,
+ Medium,
+ High
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/EnumTypeGenerator.cs b/Source/FunicularSwitch.Generators/EnumTypeGenerator.cs
index 0c86a34..e51e824 100644
--- a/Source/FunicularSwitch.Generators/EnumTypeGenerator.cs
+++ b/Source/FunicularSwitch.Generators/EnumTypeGenerator.cs
@@ -1,4 +1,5 @@
using System.Collections.Immutable;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
using FunicularSwitch.Generators.Common;
using FunicularSwitch.Generators.EnumType;
using Microsoft.CodeAnalysis;
@@ -9,161 +10,177 @@ namespace FunicularSwitch.Generators;
[Generator]
public class EnumTypeGenerator : IIncrementalGenerator
{
- const string ExtendedEnumAttribute = "FunicularSwitch.Generators.ExtendedEnumAttribute";
- const string ExtendEnumsAttribute = "FunicularSwitch.Generators.ExtendEnumsAttribute";
- const string ExtendEnumAttribute = "FunicularSwitch.Generators.ExtendEnumAttribute";
-
- const string FunicularSwitchGeneratorsNamespace = "FunicularSwitch.Generators";
-
- public void Initialize(IncrementalGeneratorInitializationContext context)
- {
- context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
- "Attributes.g.cs",
- Templates.EnumTypeTemplates.StaticCode));
-
- var enumTypeClasses =
- context.SyntaxProvider
- .CreateSyntaxProvider(
- predicate: static (s, _) => s is EnumDeclarationSyntax && s.IsTypeDeclarationWithAttributes()
- || s.IsAssemblyAttribute(),
- transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)
- )
- .SelectMany(static (target, _) => target!)
- .Where(static target => target != null);
-
- context.RegisterSourceOutput(
- enumTypeClasses.Collect(),
- static (spc, source) => Execute(source!, spc));
- }
-
- static void Execute(ImmutableArray enumSymbolInfos, SourceProductionContext context)
- {
- foreach (var enumSymbolInfo in enumSymbolInfos
- .GroupBy(s => s.EnumTypeSymbol)
- .Select(g => g.OrderByDescending(s => s.Precedence).First()))
- {
- var enumSymbol = enumSymbolInfo.EnumTypeSymbol.Symbol;
-
- var acc = enumSymbol.GetActualAccessibility();
- if (acc is Accessibility.Private or Accessibility.Protected)
- {
- context.ReportDiagnostic(Diagnostics.EnumTypeIsNotAccessible($"{enumSymbol.FullTypeNameWithNamespace()} needs at least internal accessibility",
- enumSymbol.Locations.FirstOrDefault() ?? Location.None));
- continue;
- }
-
- var isFlags = enumSymbol.GetAttributes().Any(a => a.AttributeClass?.FullTypeNameWithNamespace() == "System.FlagsAttribute");
- if (isFlags)
- continue; //TODO: report diagnostics in case of explicit EnumType attribute
+ const string ExtendedEnumAttribute = "FunicularSwitch.Generators.ExtendedEnumAttribute";
+ const string ExtendEnumsAttribute = "FunicularSwitch.Generators.ExtendEnumsAttribute";
+ const string ExtendEnumAttribute = "FunicularSwitch.Generators.ExtendEnumAttribute";
+
+ const string FunicularSwitchGeneratorsNamespace = "FunicularSwitch.Generators";
+
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ context.RegisterPostInitializationOutput(ctx => ctx.AddSource(
+ "Attributes.g.cs",
+ Templates.EnumTypeTemplates.StaticCode));
+
+ var enumTypeClasses =
+ context.SyntaxProvider
+ .CreateSyntaxProvider(
+ predicate: static (s, _) => s is EnumDeclarationSyntax && s.IsTypeDeclarationWithAttributes()
+ || s.IsAssemblyAttribute(),
+ transform: static (ctx, _) => GetSemanticTargetForGeneration(ctx)
+ .Select(ToEnumTypeSchema)
+ .ToImmutableArray()
+ .AsEquatableArray()
+ );
+
+ context.RegisterSourceOutput(
+ enumTypeClasses.Collect(),
+ static (spc, source) =>
+ Execute(source.SelectMany(s => s).ToImmutableArray(), spc));
+ }
+
+ static (EnumTypeSchema? enumTypeSchema, DiagnosticInfo? error) ToEnumTypeSchema(EnumSymbolInfo enumSymbolInfo)
+ {
+ var enumSymbol = enumSymbolInfo.EnumTypeSymbol.Symbol;
+
+ var acc = enumSymbol.GetActualAccessibility();
+ if (acc is Accessibility.Private or Accessibility.Protected)
+ {
+ var diagnostic = Diagnostics.EnumTypeIsNotAccessible($"{enumSymbol.FullTypeNameWithNamespace()} needs at least internal accessibility",
+ enumSymbol.Locations.FirstOrDefault() ?? Location.None);
+ return (null, new(diagnostic));
+ }
+
+ var isFlags = enumSymbol.GetAttributes().Any(a => a.AttributeClass?.FullTypeNameWithNamespace() == "System.FlagsAttribute");
+ if (isFlags)
+ return (null, null); //TODO: report diagnostics in case of explicit EnumType attribute
#pragma warning disable RS1024
- var hasDuplicates = enumSymbol.GetMembers()
- .OfType()
- .GroupBy(f => f.ConstantValue ?? 0)
+ var hasDuplicates = enumSymbol.GetMembers()
+ .OfType()
+ .GroupBy(f => f.ConstantValue ?? 0)
#pragma warning restore RS1024
- .Any(g => g.Count() > 1);
-
- if (hasDuplicates)
- continue; //TODO: report diagnostics in case of explicit EnumType attribute
-
- var enumTypeSchema = enumSymbolInfo.ToEnumTypeSchema();
-
- var (filename, source) = Generator.Emit(enumTypeSchema, context.ReportDiagnostic, context.CancellationToken);
- context.AddSource(filename, source);
-
- }
- }
-
- static IEnumerable GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
- {
- switch (context.Node)
- {
- case EnumDeclarationSyntax enumDeclarationSyntax:
- {
- return GetSymbolInfoFromEnumDeclaration(context, enumDeclarationSyntax);
- }
- case AttributeSyntax extendEnumTypesAttribute:
- {
- var semanticModel = context.SemanticModel;
- var attributeFullName = extendEnumTypesAttribute.GetAttributeFullName(semanticModel);
-
- return attributeFullName switch
- {
- ExtendEnumsAttribute => GetSymbolInfosForExtendEnumTypesAttribute(extendEnumTypesAttribute, semanticModel),
- ExtendEnumAttribute => GetSymbolInfosForExtendEnumTypeAttribute(extendEnumTypesAttribute, semanticModel),
- _ => Enumerable.Empty()
- };
- }
- default:
- throw new ArgumentException($"Unexpected node of type {context.Node.GetType()}");
- }
- }
-
- static IEnumerable GetSymbolInfosForExtendEnumTypeAttribute(AttributeSyntax extendEnumTypesAttribute, SemanticModel semanticModel)
- {
- var typeofExpression = extendEnumTypesAttribute.ArgumentList?.Arguments
- .Select(a => a.Expression)
- .OfType()
- .FirstOrDefault();
-
- if (typeofExpression == null)
- return Enumerable.Empty();
-
- if (semanticModel.GetSymbolInfo(typeofExpression.Type).Symbol is not INamedTypeSymbol typeSymbol)
- return Enumerable.Empty();
-
- if (typeSymbol.EnumUnderlyingType == null)
- return Enumerable.Empty();
-
- var (caseOrder, visibility) = Parser.GetAttributeParameters(extendEnumTypesAttribute);
- return new[] { new EnumSymbolInfo(SymbolWrapper.Create(typeSymbol), visibility, caseOrder, AttributePrecedence.Medium) };
- }
-
- static IEnumerable GetSymbolInfosForExtendEnumTypesAttribute(AttributeSyntax extendEnumTypesAttribute, SemanticModel semanticModel)
- {
- var typeofExpression = extendEnumTypesAttribute.ArgumentList?.Arguments
- .Select(a => a.Expression)
- .OfType()
- .FirstOrDefault();
-
- var attributeSymbol = semanticModel.GetSymbolInfo(extendEnumTypesAttribute).Symbol!;
- var enumFromAssembly = typeofExpression != null
- ? semanticModel.GetSymbolInfo(typeofExpression.Type).Symbol!.ContainingAssembly
- : attributeSymbol.ContainingAssembly;
-
- var (caseOrder, visibility) = Parser.GetAttributeParameters(extendEnumTypesAttribute);
-
- return Parser.GetAccessibleEnumTypeSymbols(enumFromAssembly.GlobalNamespace,
- SymbolEqualityComparer.Default.Equals(attributeSymbol.ContainingAssembly, enumFromAssembly))
- .Where(e =>
- (e.Name != "ExtensionAccessibility" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace) &&
- (e.Name != "EnumCaseOrder" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace) &&
- (e.Name != "CaseOrder" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace))
- .Select(e => new EnumSymbolInfo(SymbolWrapper.Create(e), visibility, caseOrder, AttributePrecedence.Low));
- }
-
- static IEnumerable GetSymbolInfoFromEnumDeclaration(GeneratorSyntaxContext context,
- EnumDeclarationSyntax enumDeclarationSyntax)
- {
- AttributeSyntax? enumTypeAttribute = null;
- foreach (var attributeListSyntax in enumDeclarationSyntax.AttributeLists)
- {
- foreach (var attributeSyntax in attributeListSyntax.Attributes)
- {
- var semanticModel = context.SemanticModel;
- var attributeFullName = attributeSyntax.GetAttributeFullName(semanticModel);
- if (attributeFullName != ExtendedEnumAttribute) continue;
- enumTypeAttribute = attributeSyntax;
- goto Return;
- }
- }
-
- Return:
- if (enumTypeAttribute == null)
- return Enumerable.Empty();
-
- var schema = Parser.GetEnumSymbolInfo(enumDeclarationSyntax, enumTypeAttribute, context.SemanticModel);
-
- return new[] { schema };
- }
+ .Any(g => g.Count() > 1);
+
+ if (hasDuplicates)
+ return (null, null); //TODO: report diagnostics in case of explicit EnumType attribute
+
+ var enumTypeSchema = enumSymbolInfo.ToEnumTypeSchema();
+ return (enumTypeSchema, null);
+ }
+
+ static void Execute(IReadOnlyCollection<(EnumTypeSchema? enumTypeSchema, DiagnosticInfo? error)> enumSymbolInfos, SourceProductionContext context)
+ {
+ foreach (var (_, diagnostic) in enumSymbolInfos)
+ {
+ if (diagnostic == null)
+ continue;
+ context.ReportDiagnostic(diagnostic);
+ }
+
+ foreach (var enumSymbolInfo in enumSymbolInfos
+ .Select(s => s.enumTypeSchema)
+ .Where(s => s != null)
+ .GroupBy(s => s!.FullTypeName)
+ .Select(g => g
+ .OrderByDescending(s => s!.Precedence)
+ .First()))
+ {
+ var (filename, source) = Generator.Emit(enumSymbolInfo!, context.ReportDiagnostic, context.CancellationToken);
+ context.AddSource(filename, source);
+ }
+ }
+
+ static IEnumerable GetSemanticTargetForGeneration(GeneratorSyntaxContext context)
+ {
+ switch (context.Node)
+ {
+ case EnumDeclarationSyntax enumDeclarationSyntax:
+ {
+ return GetSymbolInfoFromEnumDeclaration(context, enumDeclarationSyntax);
+ }
+ case AttributeSyntax extendEnumTypesAttribute:
+ {
+ var semanticModel = context.SemanticModel;
+ var attributeFullName = extendEnumTypesAttribute.GetAttributeFullName(semanticModel);
+
+ return attributeFullName switch
+ {
+ ExtendEnumsAttribute => GetSymbolInfosForExtendEnumTypesAttribute(extendEnumTypesAttribute, semanticModel),
+ ExtendEnumAttribute => GetSymbolInfosForExtendEnumTypeAttribute(extendEnumTypesAttribute, semanticModel),
+ _ => []
+ };
+ }
+ default:
+ throw new ArgumentException($"Unexpected node of type {context.Node.GetType()}");
+ }
+ }
+
+ static IEnumerable GetSymbolInfosForExtendEnumTypeAttribute(AttributeSyntax extendEnumTypesAttribute, SemanticModel semanticModel)
+ {
+ var typeofExpression = extendEnumTypesAttribute.ArgumentList?.Arguments
+ .Select(a => a.Expression)
+ .OfType()
+ .FirstOrDefault();
+
+ if (typeofExpression == null)
+ return [];
+
+ if (semanticModel.GetSymbolInfo(typeofExpression.Type).Symbol is not INamedTypeSymbol typeSymbol)
+ return [];
+
+ if (typeSymbol.EnumUnderlyingType == null)
+ return [];
+
+ var (caseOrder, visibility) = Parser.GetAttributeParameters(extendEnumTypesAttribute);
+ return new[] { new EnumSymbolInfo(SymbolWrapper.Create(typeSymbol), visibility, caseOrder, AttributePrecedence.Medium) };
+ }
+
+ static IEnumerable GetSymbolInfosForExtendEnumTypesAttribute(AttributeSyntax extendEnumTypesAttribute, SemanticModel semanticModel)
+ {
+ var typeofExpression = extendEnumTypesAttribute.ArgumentList?.Arguments
+ .Select(a => a.Expression)
+ .OfType()
+ .FirstOrDefault();
+
+ var attributeSymbol = semanticModel.GetSymbolInfo(extendEnumTypesAttribute).Symbol!;
+ var enumFromAssembly = typeofExpression != null
+ ? semanticModel.GetSymbolInfo(typeofExpression.Type).Symbol!.ContainingAssembly
+ : attributeSymbol.ContainingAssembly;
+
+ var (caseOrder, visibility) = Parser.GetAttributeParameters(extendEnumTypesAttribute);
+
+ return Parser.GetAccessibleEnumTypeSymbols(enumFromAssembly.GlobalNamespace,
+ SymbolEqualityComparer.Default.Equals(attributeSymbol.ContainingAssembly, enumFromAssembly))
+ .Where(e =>
+ (e.Name != "ExtensionAccessibility" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace) &&
+ (e.Name != "EnumCaseOrder" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace) &&
+ (e.Name != "CaseOrder" || e.GetFullNamespace() != FunicularSwitchGeneratorsNamespace))
+ .Select(e => new EnumSymbolInfo(SymbolWrapper.Create(e), visibility, caseOrder, AttributePrecedence.Low));
+ }
+
+ static IEnumerable GetSymbolInfoFromEnumDeclaration(GeneratorSyntaxContext context,
+ EnumDeclarationSyntax enumDeclarationSyntax)
+ {
+ AttributeSyntax? enumTypeAttribute = null;
+ foreach (var attributeListSyntax in enumDeclarationSyntax.AttributeLists)
+ {
+ foreach (var attributeSyntax in attributeListSyntax.Attributes)
+ {
+ var semanticModel = context.SemanticModel;
+ var attributeFullName = attributeSyntax.GetAttributeFullName(semanticModel);
+ if (attributeFullName != ExtendedEnumAttribute) continue;
+ enumTypeAttribute = attributeSyntax;
+ goto Return;
+ }
+ }
+
+ Return:
+ if (enumTypeAttribute == null)
+ return [];
+
+ var schema = Parser.GetEnumSymbolInfo(enumDeclarationSyntax, enumTypeAttribute, context.SemanticModel);
+
+ return schema == null ? Enumerable.Empty() : new[] { schema };
+ }
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj b/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj
index 79d2ed6..32c27dc 100644
--- a/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj
+++ b/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj
@@ -23,7 +23,7 @@
4
- 0.0
+ 1.0
$(MajorVersion).0.0
@@ -68,11 +68,6 @@
-
-
-
-
-
diff --git a/Source/FunicularSwitch.Generators/GeneratorHelper.cs b/Source/FunicularSwitch.Generators/GeneratorHelper.cs
index 5396132..9e9da54 100644
--- a/Source/FunicularSwitch.Generators/GeneratorHelper.cs
+++ b/Source/FunicularSwitch.Generators/GeneratorHelper.cs
@@ -6,66 +6,73 @@ namespace FunicularSwitch.Generators;
static class GeneratorHelper
{
- public static BaseTypeDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context, string expectedAttributeName)
- {
- var classDeclarationSyntax = (BaseTypeDeclarationSyntax)context.Node;
- var hasAttribute = false;
- foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists)
- {
- foreach (var attributeSyntax in attributeListSyntax.Attributes)
- {
- var semanticModel = context.SemanticModel;
- var attributeFullName = attributeSyntax.GetAttributeFullName(semanticModel);
- if (attributeFullName != expectedAttributeName) continue;
- hasAttribute = true;
- goto Return;
- }
- }
+ public static BaseTypeDeclarationSyntax? GetSemanticTargetForGeneration(GeneratorSyntaxContext context, string expectedAttributeName)
+ {
+ var classDeclarationSyntax = (BaseTypeDeclarationSyntax)context.Node;
+ var hasAttribute = false;
+ foreach (var attributeListSyntax in classDeclarationSyntax.AttributeLists)
+ {
+ foreach (var attributeSyntax in attributeListSyntax.Attributes)
+ {
+ var semanticModel = context.SemanticModel;
+ var attributeFullName = attributeSyntax.GetAttributeFullName(semanticModel);
+ if (attributeFullName != expectedAttributeName) continue;
+ hasAttribute = true;
+ goto Return;
+ }
+ }
- Return:
- return hasAttribute ? classDeclarationSyntax : null;
- }
+ Return:
+ return hasAttribute ? classDeclarationSyntax : null;
+ }
- public static T GetNamedEnumAttributeArgument(this AttributeSyntax attribute, string name, T defaultValue) where T : struct
- {
- var memberAccess = attribute.ArgumentList?.Arguments
- .Where(a => a.NameEquals?.Name.ToString() == name)
- .Select(a => a.Expression)
- .OfType()
- .FirstOrDefault();
+ public static T GetNamedEnumAttributeArgument(this AttributeSyntax attribute, string name, T defaultValue) where T : struct
+ {
+ var memberAccess = attribute.ArgumentList?.Arguments
+ .Where(a => a.NameEquals?.Name.ToString() == name)
+ .Select(a => a.Expression)
+ .OfType()
+ .FirstOrDefault();
- if (memberAccess == null) return defaultValue;
+ if (memberAccess == null) return defaultValue;
- return (T)Enum.Parse(typeof(T), memberAccess.Name.ToString());
- }
+ return (T)Enum.Parse(typeof(T), memberAccess.Name.ToString());
+ }
}
public class SymbolWrapper
{
- public static SymbolWrapper Create(T symbol) where T : ISymbol => new(symbol);
+ internal static readonly SymbolDisplayFormat FullTypeWithNamespaceDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);
+
+ public static SymbolWrapper Create(T symbol) where T : ISymbol => new(symbol);
}
public class SymbolWrapper : IEquatable> where T : ISymbol
{
- public SymbolWrapper(T symbol) => Symbol = symbol;
+ public SymbolWrapper(T symbol)
+ {
+ Symbol = symbol;
+ FullNameWithNamespace = symbol.ToDisplayString(SymbolWrapper.FullTypeWithNamespaceDisplayFormat);
+ }
- public T Symbol { get; }
+ public string FullNameWithNamespace { get; }
+ public T Symbol { get; }
- public bool Equals(SymbolWrapper? other)
- {
- if (ReferenceEquals(null, other)) return false;
- if (ReferenceEquals(this, other)) return true;
- return SymbolEqualityComparer.Default.Equals(Symbol, other.Symbol);
- }
+ public bool Equals(SymbolWrapper? other)
+ {
+ if (ReferenceEquals(null, other)) return false;
+ if (ReferenceEquals(this, other)) return true;
+ return FullNameWithNamespace == other.FullNameWithNamespace;
+ }
- public override bool Equals(object? obj)
- {
- if (ReferenceEquals(null, obj)) return false;
- if (ReferenceEquals(this, obj)) return true;
- if (obj.GetType() != this.GetType()) return false;
- return Equals((SymbolWrapper)obj);
- }
+ public override bool Equals(object? obj)
+ {
+ if (ReferenceEquals(null, obj)) return false;
+ if (ReferenceEquals(this, obj)) return true;
+ if (obj.GetType() != GetType()) return false;
+ return Equals((SymbolWrapper)obj);
+ }
- public override int GetHashCode() => SymbolEqualityComparer.Default.GetHashCode(Symbol);
+ public override int GetHashCode() => FullNameWithNamespace.GetHashCode();
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/ResultTypeGenerator.cs b/Source/FunicularSwitch.Generators/ResultTypeGenerator.cs
index d53a17a..7b84af8 100644
--- a/Source/FunicularSwitch.Generators/ResultTypeGenerator.cs
+++ b/Source/FunicularSwitch.Generators/ResultTypeGenerator.cs
@@ -23,7 +23,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.CreateSyntaxProvider(
predicate: static (s, _) => s.IsTypeDeclarationWithAttributes(),
transform: static (ctx, _) => GeneratorHelper.GetSemanticTargetForGeneration(ctx, ResultTypeAttribute)
-
)
.Where(static target => target != null)
.Select(static (target, _) => target!);
diff --git a/Source/FunicularSwitch.Generators/RunCount.cs b/Source/FunicularSwitch.Generators/RunCount.cs
new file mode 100644
index 0000000..bc77f23
--- /dev/null
+++ b/Source/FunicularSwitch.Generators/RunCount.cs
@@ -0,0 +1,9 @@
+using System.Collections.Concurrent;
+
+namespace FunicularSwitch.Generators;
+
+static class RunCount
+{
+ static readonly ConcurrentDictionary Counters = new();
+ public static int Increase(string key) => Counters.AddOrUpdate(key, _ => 1, (_, i) => i + 1);
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/Templates/Resources.cs b/Source/FunicularSwitch.Generators/Templates/Resources.cs
new file mode 100644
index 0000000..1161640
--- /dev/null
+++ b/Source/FunicularSwitch.Generators/Templates/Resources.cs
@@ -0,0 +1,14 @@
+namespace FunicularSwitch.Generators.Templates;
+
+static class Resources
+{
+ static readonly string s_Namespace = $"{typeof(ResultTypeTemplates).Namespace}";
+
+ public static string ReadResource(string filename)
+ {
+ var resourcePath = $"{s_Namespace}.{filename}";
+ using var stream = typeof(ResultTypeTemplates).Assembly.GetManifestResourceStream(resourcePath);
+ using var reader = new StreamReader(stream!);
+ return reader.ReadToEnd();
+ }
+}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/Templates/ResultTypeTemplates.cs b/Source/FunicularSwitch.Generators/Templates/ResultTypeTemplates.cs
index 89926e8..5ad2b74 100644
--- a/Source/FunicularSwitch.Generators/Templates/ResultTypeTemplates.cs
+++ b/Source/FunicularSwitch.Generators/Templates/ResultTypeTemplates.cs
@@ -5,17 +5,4 @@ static class ResultTypeTemplates
public static string ResultType => Resources.ReadResource("ResultType.cs");
public static string ResultTypeWithMerge => Resources.ReadResource("ResultTypeWithMerge.cs");
public static string StaticCode => Resources.ReadResource("ResultTypeAttributes.cs");
-}
-
-static class Resources
-{
- static readonly string s_Namespace = $"{typeof(ResultTypeTemplates).Namespace}";
-
- public static string ReadResource(string filename)
- {
- var resourcePath = $"{s_Namespace}.{filename}";
- using var stream = typeof(ResultTypeTemplates).Assembly.GetManifestResourceStream(resourcePath);
- using var reader = new StreamReader(stream!);
- return reader.ReadToEnd();
- }
}
\ No newline at end of file
diff --git a/Source/FunicularSwitch.Generators/UnionType/Generator.cs b/Source/FunicularSwitch.Generators/UnionType/Generator.cs
index 0b21a87..57a1021 100644
--- a/Source/FunicularSwitch.Generators/UnionType/Generator.cs
+++ b/Source/FunicularSwitch.Generators/UnionType/Generator.cs
@@ -17,6 +17,8 @@ public static (string filename, string source) Emit(UnionTypeSchema unionTypeSch
var builder = new CSharpBuilder();
builder.WriteLine("#pragma warning disable 1591");
+ //builder.WriteLine($"//Generator runs: {RunCount.Increase(unionTypeSchema.FullTypeName)}");
+
using (unionTypeSchema.Namespace != null ? builder.Namespace(unionTypeSchema.Namespace) : null)
{
WriteMatchExtension(unionTypeSchema, builder);
@@ -82,7 +84,7 @@ static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSh
var info = unionTypeSchema.StaticFactoryInfo!;
var typeKind = unionTypeSchema.TypeKind switch { UnionTypeTypeKind.Class => "class", UnionTypeTypeKind.Interface => "interface", UnionTypeTypeKind.Record => "record", _ => throw new ArgumentException($"Unknown type kind: {unionTypeSchema.TypeKind}") };
- builder.WriteLine($"{(info.Modifiers.Select(m => m.Text).ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}");
+ builder.WriteLine($"{(info.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}");
using (builder.Scope())
{
foreach (var derivedType in unionTypeSchema.Cases)
@@ -95,13 +97,13 @@ static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSh
continue;
var constructors = derivedType.Constructors;
- if (constructors.Count == 0)
+ if (constructors.Length == 0)
constructors = new[]
{
new MemberInfo($"{derivedTypeName}",
- SyntaxTokenList.Create(SyntaxFactory.Token(SyntaxKind.PublicKeyword)),
- ImmutableList.Empty)
- };
+ ImmutableArray.Empty.Add("public"),
+ ImmutableArray.Empty)
+ }.ToImmutableArray();
foreach (var constructor in constructors)
{
diff --git a/Source/FunicularSwitch.Generators/UnionType/Parser.cs b/Source/FunicularSwitch.Generators/UnionType/Parser.cs
index d6a021e..096bb6e 100644
--- a/Source/FunicularSwitch.Generators/UnionType/Parser.cs
+++ b/Source/FunicularSwitch.Generators/UnionType/Parser.cs
@@ -9,70 +9,71 @@ namespace FunicularSwitch.Generators.UnionType;
static class Parser
{
- public static IEnumerable GetUnionTypes(Compilation compilation,
- ImmutableArray unionTypeClasses, Action reportDiagnostic,
- CancellationToken cancellationToken) =>
- unionTypeClasses
- .Select(unionTypeClass =>
- {
- var semanticModel = compilation.GetSemanticModel(unionTypeClass.SyntaxTree);
- var unionTypeSymbol = semanticModel.GetDeclaredSymbol(unionTypeClass);
-
- if (unionTypeSymbol == null) //TODO: report diagnostics
- return null!;
-
- var fullTypeName = unionTypeSymbol.FullTypeNameWithNamespace();
- var acc = unionTypeSymbol.DeclaredAccessibility;
- if (acc is Accessibility.Private or Accessibility.Protected)
- {
- reportDiagnostic(Diagnostics.UnionTypeIsNotAccessible(
- $"{fullTypeName} needs at least internal accessibility", unionTypeClass.GetLocation()));
- return null!;
- }
-
- var attribute = unionTypeClass.AttributeLists
- .Select(l => l.Attributes.First(a =>
- a.GetAttributeFullName(semanticModel) == UnionTypeGenerator.UnionTypeAttribute))
- .First();
-
- var (caseOrder, staticFactoryMethods) = TryGetCaseOrder(attribute, reportDiagnostic);
-
- var fullNamespace = unionTypeSymbol.GetFullNamespace();
-
- var derivedTypes = compilation.SyntaxTrees.SelectMany(t =>
- {
- var root = t.GetRoot(cancellationToken);
- var treeSemanticModel = t != unionTypeClass.SyntaxTree ? compilation.GetSemanticModel(t) : semanticModel;
-
- return FindConcreteDerivedTypesWalker.Get(root, unionTypeSymbol, treeSemanticModel);
- });
-
- var isPartial = unionTypeClass.Modifiers.HasModifier(SyntaxKind.PartialKeyword);
- var generateFactoryMethods = isPartial /*&& unionTypeClass is not InterfaceDeclarationSyntax*/ &&
- staticFactoryMethods;
- var cases =
- ToOrderedCases(caseOrder, derivedTypes, reportDiagnostic, compilation, generateFactoryMethods, unionTypeSymbol.Name)
- .ToImmutableArray();
-
- return new UnionTypeSchema(
- Namespace: fullNamespace,
- TypeName: unionTypeSymbol.Name,
- FullTypeName: fullTypeName,
- Cases: cases,
- IsInternal: acc is Accessibility.NotApplicable or Accessibility.Internal,
- IsPartial: isPartial,
- TypeKind: unionTypeClass switch
- {
- RecordDeclarationSyntax => UnionTypeTypeKind.Record,
- InterfaceDeclarationSyntax => UnionTypeTypeKind.Interface,
- _ => UnionTypeTypeKind.Class
- },
- StaticFactoryInfo: generateFactoryMethods
- ? BuildFactoryInfo(unionTypeClass, compilation)
- : null
- );
- })
- .Where(unionTypeClass => unionTypeClass is { Cases.Count: > 0 });
+ public static GenerationResult GetUnionTypeSchema(Compilation compilation,
+ CancellationToken cancellationToken, BaseTypeDeclarationSyntax unionTypeClass)
+ {
+ var semanticModel = compilation.GetSemanticModel(unionTypeClass.SyntaxTree);
+ var unionTypeSymbol = semanticModel.GetDeclaredSymbol(unionTypeClass);
+
+ if (unionTypeSymbol == null) //TODO: report diagnostics
+ return GenerationResult.Empty;
+
+ var fullTypeName = unionTypeSymbol.FullTypeNameWithNamespace();
+ var acc = unionTypeSymbol.DeclaredAccessibility;
+ if (acc is Accessibility.Private or Accessibility.Protected)
+ {
+ var diag = Diagnostics.UnionTypeIsNotAccessible($"{fullTypeName} needs at least internal accessibility", unionTypeClass.GetLocation());
+ return Error(diag);
+ }
+
+ var attribute = unionTypeClass.AttributeLists
+ .Select(l => l.Attributes.First(a =>
+ a.GetAttributeFullName(semanticModel) == UnionTypeGenerator.UnionTypeAttribute))
+ .First();
+
+ var caseOrderResult = TryGetCaseOrder(attribute);
+
+ return caseOrderResult.Bind(t =>
+ {
+ var fullNamespace = unionTypeSymbol.GetFullNamespace();
+
+ var derivedTypes = compilation.SyntaxTrees.SelectMany(syntaxTree =>
+ {
+ var root = syntaxTree.GetRoot(cancellationToken);
+ var treeSemanticModel = syntaxTree != unionTypeClass.SyntaxTree ? compilation.GetSemanticModel(syntaxTree) : semanticModel;
+
+ return FindConcreteDerivedTypesWalker.Get(root, unionTypeSymbol, treeSemanticModel);
+ });
+
+ var (caseOrder, staticFactoryMethods) = t;
+ var isPartial = unionTypeClass.Modifiers.HasModifier(SyntaxKind.PartialKeyword);
+ var generateFactoryMethods = isPartial /*&& unionTypeClass is not InterfaceDeclarationSyntax*/ &&
+ staticFactoryMethods;
+
+ return
+ ToOrderedCases(caseOrder, derivedTypes, compilation, generateFactoryMethods, unionTypeSymbol.Name)
+ .Map(cases =>
+ new UnionTypeSchema(
+ Namespace: fullNamespace,
+ TypeName: unionTypeSymbol.Name,
+ FullTypeName: fullTypeName,
+ Cases: cases,
+ IsInternal: acc is Accessibility.NotApplicable or Accessibility.Internal,
+ IsPartial: isPartial,
+ TypeKind: unionTypeClass switch
+ {
+ RecordDeclarationSyntax => UnionTypeTypeKind.Record,
+ InterfaceDeclarationSyntax => UnionTypeTypeKind.Interface,
+ _ => UnionTypeTypeKind.Class
+ },
+ StaticFactoryInfo: generateFactoryMethods
+ ? BuildFactoryInfo(unionTypeClass, compilation)
+ : null
+ ));
+ });
+
+ static GenerationResult Error(Diagnostic diagnostic) => GenerationResult.Empty.AddDiagnostics(diagnostic);
+ }
static (string parameterName, string methodName) DeriveParameterAndStaticMethodName(string typeName,
string baseTypeName)
@@ -104,7 +105,7 @@ static StaticFactoryMethodsInfo BuildFactoryInfo(BaseTypeDeclarationSyntax union
.OfType()
.Where(m => m.Modifiers.HasModifier(SyntaxKind.StaticKeyword))
.Select(m => m.ToMemberInfo(m.Name(), compilation))
- .ToImmutableList();
+ .ToImmutableArray();
var staticFields = unionTypeClass.ChildNodes()
.SelectMany(s => s switch
@@ -117,12 +118,12 @@ PropertyDeclarationSyntax p when p.Modifiers.HasModifier(SyntaxKind.StaticKeywor
},
_ => Array.Empty()
})
- .ToImmutableHashSet();
+ .ToImmutableArray();
- return new(staticMethods, staticFields, unionTypeClass.Modifiers);
+ return new(staticMethods, staticFields, unionTypeClass.Modifiers.ToEquatableModifiers());
}
- static (CaseOrder caseOder, bool staticFactoryMethods) TryGetCaseOrder(AttributeSyntax attribute, Action reportDiagnostics)
+ static GenerationResult<(CaseOrder caseOder, bool staticFactoryMethods)> TryGetCaseOrder(AttributeSyntax attribute)
{
var caseOrder = CaseOrder.Alphabetic;
var staticFactoryMethods = true;
@@ -130,6 +131,7 @@ PropertyDeclarationSyntax p when p.Modifiers.HasModifier(SyntaxKind.StaticKeywor
if ((attribute.ArgumentList?.Arguments.Count ?? 0) < 1)
return (caseOrder, staticFactoryMethods);
+ var errors = ImmutableArray.Empty;
foreach (var attributeArgumentSyntax in attribute.ArgumentList!.Arguments)
{
var propertyName = attributeArgumentSyntax.NameEquals?.Name.Identifier.Text;
@@ -138,19 +140,18 @@ PropertyDeclarationSyntax p when p.Modifiers.HasModifier(SyntaxKind.StaticKeywor
else if (propertyName == "StaticFactoryMethods" && attributeArgumentSyntax.Expression is LiteralExpressionSyntax lit)
staticFactoryMethods = bool.Parse(lit.Token.Text);
else
- {
- reportDiagnostics(Diagnostics.InvalidUnionTypeAttributeUsage($"Unsupported usage: {attribute}",
- attribute.GetLocation()));
- }
+ {
+ var diagnostic = Diagnostics.InvalidUnionTypeAttributeUsage($"Unsupported usage: {attribute}", attribute.GetLocation());
+ errors = errors.Add(diagnostic);
+ }
}
- return (caseOrder, staticFactoryMethods);
+ return new ((caseOrder, staticFactoryMethods), errors, true);
}
- static IEnumerable ToOrderedCases(CaseOrder caseOrder,
+ static GenerationResult> ToOrderedCases(CaseOrder caseOrder,
IEnumerable<(INamedTypeSymbol symbol, BaseTypeDeclarationSyntax node, int? caseIndex, int
- numberOfConctreteBaseTypes)> derivedTypes, Action reportDiagnostic, Compilation compilation,
- bool getConstructors, string baseTypeName)
+ numberOfConctreteBaseTypes)> derivedTypes, Compilation compilation, bool getConstructors, string baseTypeName)
{
var ordered = derivedTypes.OrderByDescending(d => d.numberOfConctreteBaseTypes);
ordered = caseOrder switch
@@ -164,6 +165,8 @@ static IEnumerable ToOrderedCases(CaseOrder caseOrder,
var result = ordered.ToImmutableArray();
+ var errors = ImmutableArray.Empty;
+
switch (caseOrder)
{
case CaseOrder.Alphabetic:
@@ -171,14 +174,15 @@ static IEnumerable ToOrderedCases(CaseOrder caseOrder,
foreach (var t in result.Where(r => r.caseIndex != null))
{
var message = $"Explicit case index on {t.node.Name()} is ignored, because CaseOrder on UnionTypeAttribute is {caseOrder}. Set it CaseOrder.Explicit for explicit ordering.";
- reportDiagnostic(Diagnostics.MisleadingCaseOrdering(message, t.node.GetLocation()));
+ var diagnostic = Diagnostics.MisleadingCaseOrdering(message, t.node.GetLocation());
+ errors = errors.Add(diagnostic);
}
break;
case CaseOrder.Explicit:
foreach (var t in result.Where(r => r.caseIndex == null))
{
var message = $"Missing case index on {t.node.Name()}. Please add UnionCaseAttribute for explicit case ordering.";
- reportDiagnostic(Diagnostics.CaseIndexNotSet(message, t.node.GetLocation()));
+ errors = errors.Add(Diagnostics.CaseIndexNotSet(message, t.node.GetLocation()));
}
foreach (var group in result.Where(r => r.caseIndex != null)
@@ -186,14 +190,14 @@ static IEnumerable ToOrderedCases(CaseOrder caseOrder,
.Where(g => g.Count() > 1))
{
var message = $"Cases {group.Select(g => g.node.Name()).ToSeparatedString()} define the same case index. Order is not guaranteed.";
- reportDiagnostic(Diagnostics.AmbiguousCaseIndex(message, group.First().node.GetLocation()));
+ errors = errors.Add(Diagnostics.AmbiguousCaseIndex(message, group.First().node.GetLocation()));
}
break;
default:
throw new ArgumentOutOfRangeException(nameof(caseOrder), caseOrder, null);
}
- return result.Select(d =>
+ var derived = result.Select(d =>
{
var qualifiedTypeName = d.node.QualifiedName();
var fullNamespace = d.symbol.GetFullNamespace();
@@ -207,8 +211,8 @@ static IEnumerable ToOrderedCases(CaseOrder caseOrder,
if (d.node is TypeDeclarationSyntax { ParameterList: not null } typeDeclaration)
constructors = constructors.Concat(new[]
{
- new MemberInfo(d.node.Name(), d.node.Modifiers, typeDeclaration.ParameterList.Parameters
- .Select(p => p.ToParameterInfo(compilation)).ToImmutableList())
+ new MemberInfo(d.node.Name(), d.node.Modifiers.ToEquatableModifiers(), typeDeclaration.ParameterList.Parameters
+ .Select(p => p.ToParameterInfo(compilation)).ToImmutableArray())
});
}
@@ -217,10 +221,12 @@ static IEnumerable ToOrderedCases(CaseOrder caseOrder,
return new DerivedType(
fullTypeName: $"{(fullNamespace != null ? $"{fullNamespace}." : "")}{qualifiedTypeName}",
- constructors: constructors?.ToImmutableList(),
+ constructors: constructors?.ToImmutableArray(),
parameterName: parameterName,
staticFactoryMethodName: staticMethodName);
- });
+ }).ToImmutableArray();
+
+ return new(derived, errors, true);
}
}
diff --git a/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs b/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs
index 90e5c56..c5deda9 100644
--- a/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs
+++ b/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs
@@ -1,14 +1,12 @@
-using System.Collections.Immutable;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
using FunicularSwitch.Generators.Common;
-using FunicularSwitch.Generators.Generation;
-using Microsoft.CodeAnalysis;
namespace FunicularSwitch.Generators.UnionType;
public sealed record UnionTypeSchema(string? Namespace,
string TypeName,
string FullTypeName,
- IReadOnlyCollection Cases,
+ EquatableArray Cases,
bool IsInternal,
bool IsPartial,
UnionTypeTypeKind TypeKind,
@@ -22,19 +20,19 @@ public enum UnionTypeTypeKind
}
public record StaticFactoryMethodsInfo(
- IReadOnlyCollection ExistingStaticMethods,
- IReadOnlyCollection ExistingStaticFields,
- SyntaxTokenList Modifiers
+ EquatableArray ExistingStaticMethods,
+ EquatableArray ExistingStaticFields,
+ EquatableArray Modifiers
);
public sealed record DerivedType
{
public string FullTypeName { get; }
- public IReadOnlyCollection Constructors { get; }
+ public EquatableArray Constructors { get; }
public string ParameterName { get; }
public string StaticFactoryMethodName { get; }
- public DerivedType(string fullTypeName, string parameterName, string staticFactoryMethodName, IReadOnlyCollection? constructors = null)
+ public DerivedType(string fullTypeName, string parameterName, string staticFactoryMethodName, EquatableArray? constructors = null)
{
FullTypeName = fullTypeName;
ParameterName = parameterName;
diff --git a/Source/FunicularSwitch.Generators/UnionTypeGenerator.cs b/Source/FunicularSwitch.Generators/UnionTypeGenerator.cs
index c47e27e..a2844d3 100644
--- a/Source/FunicularSwitch.Generators/UnionTypeGenerator.cs
+++ b/Source/FunicularSwitch.Generators/UnionTypeGenerator.cs
@@ -1,4 +1,4 @@
-using System.Collections.Immutable;
+using CommunityToolkit.Mvvm.SourceGenerators.Helpers;
using FunicularSwitch.Generators.Common;
using FunicularSwitch.Generators.UnionType;
using Microsoft.CodeAnalysis;
@@ -20,31 +20,28 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var unionTypeClasses =
context.SyntaxProvider
- .CreateSyntaxProvider(
- predicate: static (s, _) => s.IsTypeDeclarationWithAttributes(),
- transform: static (ctx, _) => GeneratorHelper.GetSemanticTargetForGeneration(ctx, UnionTypeAttribute)
- )
- .Where(static target => target != null)
- .Select(static (target, _) => target!);
+ .ForAttributeWithMetadataName(
+ UnionTypeAttribute,
+ predicate: static (_, _) => true,
+ transform: static (ctx, cancellationToken) =>
+ Parser.GetUnionTypeSchema(ctx.SemanticModel.Compilation, cancellationToken, (BaseTypeDeclarationSyntax)ctx.TargetNode))
+ .Select(static (target, _) => target);
- var compilationAndClasses = context.CompilationProvider.Combine(unionTypeClasses.Collect());
+ var compilationAndClasses = unionTypeClasses;
context.RegisterSourceOutput(
compilationAndClasses,
- static (spc, source) => Execute(source.Left, source.Right, spc));
+ static (spc, source) => Execute(source, spc));
}
- static void Execute(Compilation compilation, ImmutableArray unionTypeClasses, SourceProductionContext context)
+ static void Execute(GenerationResult target, SourceProductionContext context)
{
- if (unionTypeClasses.IsDefaultOrEmpty) return;
+ var (unionTypeSchema, errors, hasValue) = target;
+ foreach (var error in errors) context.ReportDiagnostic(error);
+
+ if (!hasValue || unionTypeSchema!.Cases.IsEmpty) return;
- var resultTypeSchemata =
- Parser.GetUnionTypes(compilation, unionTypeClasses, context.ReportDiagnostic, context.CancellationToken)
- .ToImmutableArray();
-
- var generation =
- resultTypeSchemata.Select(r => Generator.Emit(r, context.ReportDiagnostic, context.CancellationToken));
-
- foreach (var (filename, source) in generation) context.AddSource(filename, source);
+ var (filename, source) = Generator.Emit(unionTypeSchema!, context.ReportDiagnostic, context.CancellationToken);
+ context.AddSource(filename, source);
}
}
\ No newline at end of file
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer.StandardMinLangVersion/UnionTypeSpecs.cs b/Source/Tests/FunicularSwitch.Generators.Consumer.StandardMinLangVersion/UnionTypeSpecs.cs
index 2d3d7f9..fb96866 100644
--- a/Source/Tests/FunicularSwitch.Generators.Consumer.StandardMinLangVersion/UnionTypeSpecs.cs
+++ b/Source/Tests/FunicularSwitch.Generators.Consumer.StandardMinLangVersion/UnionTypeSpecs.cs
@@ -37,7 +37,6 @@ public class Case1_ : MyUnion {}
public class MyUnionCase2 : MyUnion {}
}
-
[ResultType(ErrorType = typeof(string))]
public abstract partial class Result
{
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/EnumSpecs.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/EnumSpecs.cs
index d5b4805..09dca4d 100644
--- a/Source/Tests/FunicularSwitch.Generators.Consumer/EnumSpecs.cs
+++ b/Source/Tests/FunicularSwitch.Generators.Consumer/EnumSpecs.cs
@@ -4,7 +4,7 @@
using FluentAssertions.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;
-[assembly: ExtendEnums(typeof(FluentAssertions.AtLeast), CaseOrder = EnumCaseOrder.Alphabetic, Accessibility = ExtensionAccessibility.Internal)]
+[assembly: ExtendEnums(typeof(AtLeast), CaseOrder = EnumCaseOrder.Alphabetic, Accessibility = ExtensionAccessibility.Internal)]
[assembly: ExtendEnums]
[assembly: ExtendEnum(typeof(DateTimeKind), CaseOrder = EnumCaseOrder.Alphabetic)]
@@ -14,7 +14,8 @@ namespace FunicularSwitch.Generators.Consumer;
public class EnumSpecs
{
[ExtendedEnum(CaseOrder = EnumCaseOrder.Alphabetic)] //direct EnumType attribute should have higher precedence compared to ExtendEnumTypes attribute,
- //so case oder should be Alphabetic for Match methods of PlatformIdentifier
+
+ //so case oder should be Alphabetic for Match methods of PlatformIdentifier
public enum PlatformIdentifier
{
LinuxDevice,
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.CSharpAccessModifierMatchExtension.g.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.CSharpAccessModifierMatchExtension.g.cs
deleted file mode 100644
index 0094bc9..0000000
--- a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.CSharpAccessModifierMatchExtension.g.cs
+++ /dev/null
@@ -1,108 +0,0 @@
-#pragma warning disable 1591
-using System;
-using System.Threading.Tasks;
-
-namespace FluentAssertions.Common
-{
- public static partial class CSharpAccessModifierMatchExtension
- {
- public static T Match(this FluentAssertions.Common.CSharpAccessModifier cSharpAccessModifier, Func @internal, Func invalidForCSharp, Func @private, Func privateProtected, Func @protected, Func protectedInternal, Func @public) =>
- cSharpAccessModifier switch
- {
- FluentAssertions.Common.CSharpAccessModifier.Internal => @internal(),
- FluentAssertions.Common.CSharpAccessModifier.InvalidForCSharp => invalidForCSharp(),
- FluentAssertions.Common.CSharpAccessModifier.Private => @private(),
- FluentAssertions.Common.CSharpAccessModifier.PrivateProtected => privateProtected(),
- FluentAssertions.Common.CSharpAccessModifier.Protected => @protected(),
- FluentAssertions.Common.CSharpAccessModifier.ProtectedInternal => protectedInternal(),
- FluentAssertions.Common.CSharpAccessModifier.Public => @public(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.CSharpAccessModifier: {cSharpAccessModifier.GetType().Name}")
- };
-
- public static Task Match(this FluentAssertions.Common.CSharpAccessModifier cSharpAccessModifier, Func> @internal, Func> invalidForCSharp, Func> @private, Func> privateProtected, Func> @protected, Func> protectedInternal, Func> @public) =>
- cSharpAccessModifier switch
- {
- FluentAssertions.Common.CSharpAccessModifier.Internal => @internal(),
- FluentAssertions.Common.CSharpAccessModifier.InvalidForCSharp => invalidForCSharp(),
- FluentAssertions.Common.CSharpAccessModifier.Private => @private(),
- FluentAssertions.Common.CSharpAccessModifier.PrivateProtected => privateProtected(),
- FluentAssertions.Common.CSharpAccessModifier.Protected => @protected(),
- FluentAssertions.Common.CSharpAccessModifier.ProtectedInternal => protectedInternal(),
- FluentAssertions.Common.CSharpAccessModifier.Public => @public(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.CSharpAccessModifier: {cSharpAccessModifier.GetType().Name}")
- };
-
- public static async Task Match(this Task cSharpAccessModifier, Func @internal, Func invalidForCSharp, Func @private, Func privateProtected, Func @protected, Func protectedInternal, Func @public) =>
- (await cSharpAccessModifier.ConfigureAwait(false)).Match(@internal, invalidForCSharp, @private, privateProtected, @protected, protectedInternal, @public);
-
- public static async Task Match(this Task cSharpAccessModifier, Func> @internal, Func> invalidForCSharp, Func> @private, Func> privateProtected, Func> @protected, Func> protectedInternal, Func> @public) =>
- await (await cSharpAccessModifier.ConfigureAwait(false)).Match(@internal, invalidForCSharp, @private, privateProtected, @protected, protectedInternal, @public).ConfigureAwait(false);
-
- public static void Switch(this FluentAssertions.Common.CSharpAccessModifier cSharpAccessModifier, Action @internal, Action invalidForCSharp, Action @private, Action privateProtected, Action @protected, Action protectedInternal, Action @public)
- {
- switch (cSharpAccessModifier)
- {
- case FluentAssertions.Common.CSharpAccessModifier.Internal:
- @internal();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.InvalidForCSharp:
- invalidForCSharp();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Private:
- @private();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.PrivateProtected:
- privateProtected();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Protected:
- @protected();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.ProtectedInternal:
- protectedInternal();
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Public:
- @public();
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.CSharpAccessModifier: {cSharpAccessModifier.GetType().Name}");
- }
- }
-
- public static async Task Switch(this FluentAssertions.Common.CSharpAccessModifier cSharpAccessModifier, Func @internal, Func invalidForCSharp, Func @private, Func privateProtected, Func @protected, Func protectedInternal, Func @public)
- {
- switch (cSharpAccessModifier)
- {
- case FluentAssertions.Common.CSharpAccessModifier.Internal:
- await @internal().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.InvalidForCSharp:
- await invalidForCSharp().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Private:
- await @private().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.PrivateProtected:
- await privateProtected().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Protected:
- await @protected().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.ProtectedInternal:
- await protectedInternal().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.CSharpAccessModifier.Public:
- await @public().ConfigureAwait(false);
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.CSharpAccessModifier: {cSharpAccessModifier.GetType().Name}");
- }
- }
-
- public static async Task Switch(this Task cSharpAccessModifier, Action @internal, Action invalidForCSharp, Action @private, Action privateProtected, Action @protected, Action protectedInternal, Action @public) =>
- (await cSharpAccessModifier.ConfigureAwait(false)).Switch(@internal, invalidForCSharp, @private, privateProtected, @protected, protectedInternal, @public);
-
- public static async Task Switch(this Task cSharpAccessModifier, Func @internal, Func invalidForCSharp, Func @private, Func privateProtected, Func @protected, Func protectedInternal, Func @public) =>
- await (await cSharpAccessModifier.ConfigureAwait(false)).Switch(@internal, invalidForCSharp, @private, privateProtected, @protected, protectedInternal, @public).ConfigureAwait(false);
- }
-}
-#pragma warning restore 1591
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.ValueFormatterDetectionModeMatchExtension.g.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.ValueFormatterDetectionModeMatchExtension.g.cs
deleted file mode 100644
index 2898d48..0000000
--- a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Common.ValueFormatterDetectionModeMatchExtension.g.cs
+++ /dev/null
@@ -1,76 +0,0 @@
-#pragma warning disable 1591
-using System;
-using System.Threading.Tasks;
-
-namespace FluentAssertions.Common
-{
- public static partial class ValueFormatterDetectionModeMatchExtension
- {
- public static T Match(this FluentAssertions.Common.ValueFormatterDetectionMode valueFormatterDetectionMode, Func disabled, Func scan, Func specific) =>
- valueFormatterDetectionMode switch
- {
- FluentAssertions.Common.ValueFormatterDetectionMode.Disabled => disabled(),
- FluentAssertions.Common.ValueFormatterDetectionMode.Scan => scan(),
- FluentAssertions.Common.ValueFormatterDetectionMode.Specific => specific(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.ValueFormatterDetectionMode: {valueFormatterDetectionMode.GetType().Name}")
- };
-
- public static Task Match(this FluentAssertions.Common.ValueFormatterDetectionMode valueFormatterDetectionMode, Func> disabled, Func> scan, Func> specific) =>
- valueFormatterDetectionMode switch
- {
- FluentAssertions.Common.ValueFormatterDetectionMode.Disabled => disabled(),
- FluentAssertions.Common.ValueFormatterDetectionMode.Scan => scan(),
- FluentAssertions.Common.ValueFormatterDetectionMode.Specific => specific(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.ValueFormatterDetectionMode: {valueFormatterDetectionMode.GetType().Name}")
- };
-
- public static async Task Match(this Task valueFormatterDetectionMode, Func disabled, Func scan, Func specific) =>
- (await valueFormatterDetectionMode.ConfigureAwait(false)).Match(disabled, scan, specific);
-
- public static async Task Match(this Task valueFormatterDetectionMode, Func> disabled, Func> scan, Func> specific) =>
- await (await valueFormatterDetectionMode.ConfigureAwait(false)).Match(disabled, scan, specific).ConfigureAwait(false);
-
- public static void Switch(this FluentAssertions.Common.ValueFormatterDetectionMode valueFormatterDetectionMode, Action disabled, Action scan, Action specific)
- {
- switch (valueFormatterDetectionMode)
- {
- case FluentAssertions.Common.ValueFormatterDetectionMode.Disabled:
- disabled();
- break;
- case FluentAssertions.Common.ValueFormatterDetectionMode.Scan:
- scan();
- break;
- case FluentAssertions.Common.ValueFormatterDetectionMode.Specific:
- specific();
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.ValueFormatterDetectionMode: {valueFormatterDetectionMode.GetType().Name}");
- }
- }
-
- public static async Task Switch(this FluentAssertions.Common.ValueFormatterDetectionMode valueFormatterDetectionMode, Func disabled, Func scan, Func specific)
- {
- switch (valueFormatterDetectionMode)
- {
- case FluentAssertions.Common.ValueFormatterDetectionMode.Disabled:
- await disabled().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.ValueFormatterDetectionMode.Scan:
- await scan().ConfigureAwait(false);
- break;
- case FluentAssertions.Common.ValueFormatterDetectionMode.Specific:
- await specific().ConfigureAwait(false);
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Common.ValueFormatterDetectionMode: {valueFormatterDetectionMode.GetType().Name}");
- }
- }
-
- public static async Task Switch(this Task valueFormatterDetectionMode, Action disabled, Action scan, Action specific) =>
- (await valueFormatterDetectionMode.ConfigureAwait(false)).Switch(disabled, scan, specific);
-
- public static async Task Switch(this Task valueFormatterDetectionMode, Func disabled, Func scan, Func specific) =>
- await (await valueFormatterDetectionMode.ConfigureAwait(false)).Switch(disabled, scan, specific).ConfigureAwait(false);
- }
-}
-#pragma warning restore 1591
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Data.RowMatchModeMatchExtension.g.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Data.RowMatchModeMatchExtension.g.cs
deleted file mode 100644
index cfe0134..0000000
--- a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Data.RowMatchModeMatchExtension.g.cs
+++ /dev/null
@@ -1,68 +0,0 @@
-#pragma warning disable 1591
-using System;
-using System.Threading.Tasks;
-
-namespace FluentAssertions.Data
-{
- public static partial class RowMatchModeMatchExtension
- {
- public static T Match(this FluentAssertions.Data.RowMatchMode rowMatchMode, Func index, Func primaryKey) =>
- rowMatchMode switch
- {
- FluentAssertions.Data.RowMatchMode.Index => index(),
- FluentAssertions.Data.RowMatchMode.PrimaryKey => primaryKey(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Data.RowMatchMode: {rowMatchMode.GetType().Name}")
- };
-
- public static Task Match(this FluentAssertions.Data.RowMatchMode rowMatchMode, Func> index, Func> primaryKey) =>
- rowMatchMode switch
- {
- FluentAssertions.Data.RowMatchMode.Index => index(),
- FluentAssertions.Data.RowMatchMode.PrimaryKey => primaryKey(),
- _ => throw new ArgumentException($"Unknown enum value from FluentAssertions.Data.RowMatchMode: {rowMatchMode.GetType().Name}")
- };
-
- public static async Task Match(this Task rowMatchMode, Func index, Func primaryKey) =>
- (await rowMatchMode.ConfigureAwait(false)).Match(index, primaryKey);
-
- public static async Task Match(this Task rowMatchMode, Func> index, Func> primaryKey) =>
- await (await rowMatchMode.ConfigureAwait(false)).Match(index, primaryKey).ConfigureAwait(false);
-
- public static void Switch(this FluentAssertions.Data.RowMatchMode rowMatchMode, Action index, Action primaryKey)
- {
- switch (rowMatchMode)
- {
- case FluentAssertions.Data.RowMatchMode.Index:
- index();
- break;
- case FluentAssertions.Data.RowMatchMode.PrimaryKey:
- primaryKey();
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Data.RowMatchMode: {rowMatchMode.GetType().Name}");
- }
- }
-
- public static async Task Switch(this FluentAssertions.Data.RowMatchMode rowMatchMode, Func index, Func primaryKey)
- {
- switch (rowMatchMode)
- {
- case FluentAssertions.Data.RowMatchMode.Index:
- await index().ConfigureAwait(false);
- break;
- case FluentAssertions.Data.RowMatchMode.PrimaryKey:
- await primaryKey().ConfigureAwait(false);
- break;
- default:
- throw new ArgumentException($"Unknown enum value from FluentAssertions.Data.RowMatchMode: {rowMatchMode.GetType().Name}");
- }
- }
-
- public static async Task Switch(this Task rowMatchMode, Action index, Action primaryKey) =>
- (await rowMatchMode.ConfigureAwait(false)).Switch(index, primaryKey);
-
- public static async Task Switch(this Task rowMatchMode, Func index, Func primaryKey) =>
- await (await rowMatchMode.ConfigureAwait(false)).Switch(index, primaryKey).ConfigureAwait(false);
- }
-}
-#pragma warning restore 1591
diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Equivalency.CyclicReferenceHandlingMatchExtension.g.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Equivalency.CyclicReferenceHandlingMatchExtension.g.cs
deleted file mode 100644
index 67d32d7..0000000
--- a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.EnumTypeGenerator/FluentAssertions.Equivalency.CyclicReferenceHandlingMatchExtension.g.cs
+++ /dev/null
@@ -1,68 +0,0 @@
-#pragma warning disable 1591
-using System;
-using System.Threading.Tasks;
-
-namespace FluentAssertions.Equivalency
-{
- public static partial class CyclicReferenceHandlingMatchExtension
- {
- public static T Match