Skip to content

Commit

Permalink
feat: add NotNullIfNotNull attribute on generated methods, support it…
Browse files Browse the repository at this point in the history
… on user implemented methods
  • Loading branch information
latonz committed Feb 12, 2025
1 parent ca3080e commit 43ee375
Show file tree
Hide file tree
Showing 64 changed files with 521 additions and 84 deletions.
26 changes: 22 additions & 4 deletions src/Riok.Mapperly/Configuration/AttributeDataAccessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,41 @@ public bool HasAttribute<TAttribute>(ISymbol symbol)
public IEnumerable<TAttribute> Access<TAttribute>(ISymbol symbol)
where TAttribute : Attribute => Access<TAttribute, TAttribute>(symbol);

public IEnumerable<TAttribute> TryAccess<TAttribute>(IEnumerable<AttributeData> data)
where TAttribute : Attribute => TryAccess<TAttribute, TAttribute>(data);

public IEnumerable<TData> Access<TAttribute, TData>(ISymbol symbol)
where TAttribute : Attribute
where TData : notnull
{
var attrDatas = symbolAccessor.GetAttributes<TAttribute>(symbol);
return Access<TAttribute, TData>(attrDatas);
}

public IEnumerable<TData> TryAccess<TAttribute, TData>(IEnumerable<AttributeData> attributes)
where TAttribute : Attribute
where TData : notnull
{
var attrDatas = symbolAccessor.TryGetAttributes<TAttribute>(attributes);
return Access<TAttribute, TData>(attrDatas);
}

/// <summary>
/// Reads the attribute data and sets it on a newly created instance of <see cref="TData"/>.
/// If <see cref="TAttribute"/> has n type parameters,
/// <see cref="TData"/> needs to have an accessible ctor with the parameters 0 to n-1 to be of type <see cref="ITypeSymbol"/>.
/// <see cref="TData"/> needs to have exactly the same constructors as <see cref="TAttribute"/> with additional type arguments.
/// </summary>
/// <param name="symbol">The symbol on which the attributes should be read.</param>
/// <param name="attributes">The attributes data.</param>
/// <typeparam name="TAttribute">The type of the attribute.</typeparam>
/// <typeparam name="TData">The type of the data class. If no type parameters are involved, this is usually the same as <see cref="TAttribute"/>.</typeparam>
/// <returns>The attribute data.</returns>
/// <exception cref="InvalidOperationException">If a property or ctor argument of <see cref="TData"/> could not be read on the attribute.</exception>
public IEnumerable<TData> Access<TAttribute, TData>(ISymbol symbol)
public IEnumerable<TData> Access<TAttribute, TData>(IEnumerable<AttributeData> attributes)
where TAttribute : Attribute
where TData : notnull
{
var attrDatas = symbolAccessor.GetAttributes<TAttribute>(symbol);
foreach (var attrData in attrDatas)
foreach (var attrData in symbolAccessor.GetAttributes<TAttribute>(attributes))
{
yield return Access<TAttribute, TData>(attrData, symbolAccessor);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Riok.Mapperly/Descriptors/DescriptorBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public DescriptorBuilder(
MapperConfiguration defaultMapperConfiguration
)
{
_mapperDescriptor = new MapperDescriptor(mapperDeclaration, _methodNameBuilder);
_mapperDescriptor = new MapperDescriptor(mapperDeclaration, _methodNameBuilder, compilationContext.Compilation.LanguageVersion);
_symbolAccessor = symbolAccessor;
_types = compilationContext.Types;
_mappingBodyBuilder = new MappingBodyBuilder(_mappings);
Expand Down
6 changes: 5 additions & 1 deletion src/Riok.Mapperly/Descriptors/MapperDescriptor.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Descriptors.Mappings;
using Riok.Mapperly.Descriptors.UnsafeAccess;
Expand All @@ -18,10 +19,11 @@ public class MapperDescriptor

public bool Static { get; set; }

public MapperDescriptor(MapperDeclaration declaration, UniqueNameBuilder nameBuilder)
public MapperDescriptor(MapperDeclaration declaration, UniqueNameBuilder nameBuilder, LanguageVersion languageVersion)
{
_declaration = declaration;
NameBuilder = nameBuilder;
LanguageVersion = languageVersion;
Name = BuildName(declaration.Symbol);
UnsafeAccessorName = nameBuilder.New(AccessorClassName);

Expand All @@ -31,6 +33,8 @@ public MapperDescriptor(MapperDeclaration declaration, UniqueNameBuilder nameBui
}
}

public LanguageVersion LanguageVersion { get; }

public string Name { get; }

public string? Namespace { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,12 @@ public static class ConvertStaticMethodMappingBuilder
allTargetMethods,
GetTargetStaticMethodNames(ctx),
ctx.Source,
ctx.Target,
nonNullableTarget,
targetIsNullable
);

if (mapping is not null)
{
return mapping;
}

var allSourceMethods = ctx.SymbolAccessor.GetAllMethods(ctx.Source);

Expand All @@ -44,7 +41,6 @@ public static class ConvertStaticMethodMappingBuilder
allSourceMethods.ToList(),
GetSourceStaticMethodNames(ctx),
ctx.Source,
ctx.Target,
nonNullableTarget,
targetIsNullable
);
Expand Down Expand Up @@ -81,7 +77,6 @@ private static bool IsDateTimeToTimeOnlyConversion(MappingBuilderContext ctx)
List<IMethodSymbol> allMethods,
IEnumerable<string> methodNames,
ITypeSymbol sourceType,
ITypeSymbol targetType,
ITypeSymbol nonNullableTargetType,
bool targetIsNullable
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ public static class NullableMappingBuilder
return null;

var delegateMapping = ctx.BuildMapping(mappingKey, MappingBuildingOptions.KeepUserSymbol);

return delegateMapping == null ? null : BuildNullDelegateMapping(ctx, delegateMapping);
}

Expand Down
13 changes: 12 additions & 1 deletion src/Riok.Mapperly/Descriptors/MappingCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,11 @@ public void AddNamedUserMapping(string? name, TUserMapping mapping)
public MappingCollectionAddResult TryAddAsDefault(T mapping, TypeMappingConfiguration config)
{
var mappingKey = new TypeMappingKey(mapping, config);
return _defaultMappings.TryAdd(mappingKey, mapping)
var result = _defaultMappings.TryAdd(mappingKey, mapping)
? MappingCollectionAddResult.Added
: MappingCollectionAddResult.NotAddedDuplicated;
AddAdditionalMappings(mapping, config);
return result;
}

public MappingCollectionAddResult AddUserMapping(TUserMapping mapping, bool? isDefault, string? name)
Expand Down Expand Up @@ -291,7 +293,16 @@ private MappingCollectionAddResult AddDefaultUserMapping(T mapping)

_duplicatedNonDefaultUserMappings.Remove(mappingKey);
_defaultMappings[mappingKey] = mapping;
AddAdditionalMappings(mapping, TypeMappingConfiguration.Default);
return MappingCollectionAddResult.Added;
}

private void AddAdditionalMappings(T mapping, TypeMappingConfiguration config)
{
foreach (var additionalKey in mapping.BuildAdditionalMappingKeys(config))
{
_defaultMappings.TryAdd(additionalKey, mapping);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Emit.Syntax;
using Riok.Mapperly.Helpers;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.Syntax.SyntaxFactoryHelper;

Expand All @@ -11,11 +12,11 @@ namespace Riok.Mapperly.Descriptors.Mappings;
/// by implementing a type switch over known types and performs the provided mapping for each type.
/// </summary>
public class DerivedTypeSwitchMapping(ITypeSymbol sourceType, ITypeSymbol targetType, IReadOnlyCollection<INewInstanceMapping> typeMappings)
: NewInstanceMapping(sourceType, targetType)
: NewInstanceMethodMapping(sourceType, targetType)
{
private const string GetTypeMethodName = nameof(GetType);

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
public override IEnumerable<StatementSyntax> BuildBody(TypeMappingBuildContext ctx)
{
// _ => throw new ArgumentException(msg, nameof(ctx.Source)),
var sourceTypeExpr = ctx.SyntaxFactory.Invocation(MemberAccess(ctx.Source, GetTypeMethodName));
Expand All @@ -32,14 +33,15 @@ public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
// source switch { A x => MapToADto(x), B x => MapToBDto(x) }
var (typeArmContext, typeArmVariableName) = ctx.WithNewSource();
var arms = typeMappings.Select(x => BuildSwitchArm(typeArmVariableName, x.SourceType, x.Build(typeArmContext))).Append(fallbackArm);
return ctx.SyntaxFactory.Switch(ctx.Source, arms);
var switchExpression = ctx.SyntaxFactory.Switch(ctx.Source, arms);
return [ctx.SyntaxFactory.Return(switchExpression)];
}

private SwitchExpressionArmSyntax BuildSwitchArm(string typeArmVariableName, ITypeSymbol type, ExpressionSyntax mapping)
{
// A x => MapToADto(x),
var declaration = DeclarationPattern(
FullyQualifiedIdentifier(type).AddTrailingSpace(),
FullyQualifiedIdentifier(type.NonNullable()).AddTrailingSpace(),
SingleVariableDesignation(Identifier(typeArmVariableName))
);
return SwitchArm(declaration, mapping);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ protected ExistingTargetMapping(ITypeSymbol sourceType, ITypeSymbol targetType)

public virtual bool IsSynthetic => false;

public IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config) => [];

public abstract IEnumerable<StatementSyntax> Build(TypeMappingBuildContext ctx, ExpressionSyntax target);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ public class ObjectMemberExistingTargetMapping(ITypeSymbol sourceType, ITypeSymb
public ITypeSymbol TargetType { get; } = targetType;

public bool IsSynthetic => false;

public IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config) => [];
}
2 changes: 2 additions & 0 deletions src/Riok.Mapperly/Descriptors/Mappings/ITypeMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ public interface ITypeMapping : IMapping
/// Gets a value indicating whether this mapping produces any code or can be omitted completely (eg. direct assignments or delegate mappings).
/// </summary>
bool IsSynthetic { get; }

IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config);
}
11 changes: 8 additions & 3 deletions src/Riok.Mapperly/Descriptors/Mappings/MethodMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ ITypeSymbol targetType

public bool IsSynthetic => false;

public virtual IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config) => [];

public virtual ExpressionSyntax Build(TypeMappingBuildContext ctx) =>
ctx.SyntaxFactory.Invocation(
MethodName,
Expand All @@ -91,7 +93,7 @@ public virtual MethodDeclarationSyntax BuildMethod(SourceEmitterContext ctx)
SourceParameter.Name,
ReferenceHandlerParameter?.Name,
ctx.NameBuilder.NewScope(),
ctx.SyntaxFactory.AddIndentation()
ctx.SyntaxFactory
);

var parameters = BuildParameterList();
Expand All @@ -101,8 +103,8 @@ public virtual MethodDeclarationSyntax BuildMethod(SourceEmitterContext ctx)
return MethodDeclaration(returnType.AddTrailingSpace(), Identifier(MethodName))
.WithModifiers(TokenList(BuildModifiers(ctx.IsStatic)))
.WithParameterList(parameters)
.WithAttributeLists(ctx.SyntaxFactory.GeneratedCodeAttributeList())
.WithBody(ctx.SyntaxFactory.Block(BuildBody(typeMappingBuildContext)));
.WithAttributeLists(BuildAttributes(typeMappingBuildContext))
.WithBody(ctx.SyntaxFactory.Block(BuildBody(typeMappingBuildContext.AddIndentation())));
}

public abstract IEnumerable<StatementSyntax> BuildBody(TypeMappingBuildContext ctx);
Expand All @@ -121,6 +123,9 @@ internal virtual void EnableReferenceHandling(INamedTypeSymbol iReferenceHandler
);
}

protected internal virtual SyntaxList<AttributeListSyntax> BuildAttributes(TypeMappingBuildContext ctx) =>
[ctx.SyntaxFactory.GeneratedCodeAttribute()];

protected virtual ParameterListSyntax BuildParameterList() =>
ParameterList(IsExtensionMethod, [SourceParameter, ReferenceHandlerParameter, .. AdditionalSourceParameters]);

Expand Down
2 changes: 2 additions & 0 deletions src/Riok.Mapperly/Descriptors/Mappings/NewInstanceMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,7 @@ protected NewInstanceMapping(ITypeSymbol sourceType, ITypeSymbol targetType)
/// <inheritdoc cref="INewInstanceMapping.IsSynthetic"/>
public virtual bool IsSynthetic => false;

public virtual IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config) => [];

public abstract ExpressionSyntax Build(TypeMappingBuildContext ctx);
}
2 changes: 2 additions & 0 deletions src/Riok.Mapperly/Descriptors/Mappings/NoOpMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ public class NoOpMapping(ITypeSymbol sourceType, ITypeSymbol targetType) : IExis
public ITypeSymbol TargetType => targetType;
public bool IsSynthetic => true;

public IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config) => [];

public IEnumerable<StatementSyntax> Build(TypeMappingBuildContext ctx, ExpressionSyntax target) => [];
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,37 @@ public class NullDelegateMethodMapping(
NullFallbackValue nullFallbackValue
) : NewInstanceMethodMapping(nullableSourceType, nullableTargetType)
{
public override IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config)
{
// if the fallback value is not nullable,
// this mapping never returns null.
// add the following mapping keys:
// null => null (added by default)
// null => non-null
// non-null => non-null
if (!nullFallbackValue.IsNullable(TargetType))
{
return
[
new TypeMappingKey(SourceType, TargetType.NonNullable(), config),
new TypeMappingKey(SourceType.NonNullable(), TargetType.NonNullable(), config),
];
}

// this mapping never returns null for non-null input values
// and is guarded with [return: NotNullIfNotNull]
// therefore this mapping can also be used as mapping for non-null values.
return [new TypeMappingKey(delegateMapping, config)];
}

protected internal override SyntaxList<AttributeListSyntax> BuildAttributes(TypeMappingBuildContext ctx)
{
if (!TargetType.IsNullable() || !nullFallbackValue.IsNullable(TargetType))
return base.BuildAttributes(ctx);

return [.. base.BuildAttributes(ctx), ctx.SyntaxFactory.ReturnNotNullIfNotNullAttribute(ctx.Source)];
}

public override IEnumerable<StatementSyntax> BuildBody(TypeMappingBuildContext ctx)
{
var body = delegateMapping.BuildBody(ctx);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Descriptors.Mappings;

internal static class NullFallbackValueExtensions
{
public static bool IsNullable(this NullFallbackValue fallbackValue, ITypeSymbol targetType) =>
fallbackValue == NullFallbackValue.Default && targetType.IsNullable();
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Abstractions.ReferenceHandling;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;

namespace Riok.Mapperly.Descriptors.Mappings.UserMappings;
Expand All @@ -21,6 +22,8 @@ bool enableReferenceHandling

public new IMethodSymbol Method { get; } = method;

private MethodMapping? DelegateMethodMapping => _delegateMapping as MethodMapping;

public bool? Default { get; } = isDefault;

public bool IsExternal => false;
Expand All @@ -33,6 +36,18 @@ bool enableReferenceHandling

public void SetDelegateMapping(INewInstanceMapping mapping) => _delegateMapping = mapping;

public override IEnumerable<TypeMappingKey> BuildAdditionalMappingKeys(TypeMappingConfiguration config)
{
// null is never returned if the source value is not null
if (TargetType.IsNullable())
{
yield return new TypeMappingKey(SourceType.NonNullable(), TargetType.NonNullable(), config);
}
}

protected internal override SyntaxList<AttributeListSyntax> BuildAttributes(TypeMappingBuildContext ctx) =>
DelegateMethodMapping?.BuildAttributes(ctx) ?? base.BuildAttributes(ctx);

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
{
return InternalReferenceHandlingEnabled ? _delegateMapping?.Build(ctx) ?? base.Build(ctx) : base.Build(ctx);
Expand Down Expand Up @@ -67,10 +82,7 @@ public override IEnumerable<StatementSyntax> BuildBody(TypeMappingBuildContext c
return [ctx.SyntaxFactory.Return(_delegateMapping.Build(ctx))];
}

if (_delegateMapping is MethodMapping delegateMethodMapping)
return delegateMethodMapping.BuildBody(ctx);

return [ctx.SyntaxFactory.Return(_delegateMapping.Build(ctx))];
return DelegateMethodMapping?.BuildBody(ctx) ?? [ctx.SyntaxFactory.Return(_delegateMapping.Build(ctx))];
}

internal override void EnableReferenceHandling(INamedTypeSymbol iReferenceHandlerType)
Expand Down
Loading

0 comments on commit 43ee375

Please sign in to comment.