Skip to content

Commit

Permalink
[Rgen] Implement the methods to decide if we need to use a stret call.
Browse files Browse the repository at this point in the history
  • Loading branch information
mandel-macaque committed Jan 24, 2025
1 parent 40f10cf commit a70cdd3
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/ObjCRuntime/Stret.cs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public static bool ArmNeedStret (Type returnType, Generator generator)
}
#endif // BGENERATOR

#if BGENERATOR
#if BGENERATOR || RGEN
public static bool X86NeedStret (Type returnType, Generator generator)
{
Type t = returnType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static Dictionary<string, List<AttributeData>> GetAttributeData (this ISy
var boundAttributes = symbol.GetAttributes ();
if (boundAttributes.Length == 0) {
// return an empty dictionary if there are no attributes
return new ();
return new();
}

var attributes = new Dictionary<string, List<AttributeData>> ();
Expand Down Expand Up @@ -102,7 +102,7 @@ public static bool HasAttribute (this ISymbol symbol, string attribute)
if (attrName is null)
return null;
if (!attributes.TryGetValue (attrName, out var exportAttrDataList) ||
exportAttrDataList.Count != 1)
exportAttrDataList.Count != 1)
return null;

var exportAttrData = exportAttrDataList [0];
Expand All @@ -115,7 +115,8 @@ public static bool HasAttribute (this ISymbol symbol, string attribute)
return null;
}

internal static T? GetAttribute<T> (this ISymbol symbol, string attributeName, TryParse<T> tryParse) where T : struct
internal static T? GetAttribute<T> (this ISymbol symbol, string attributeName, TryParse<T> tryParse)
where T : struct
=> GetAttribute (symbol, () => attributeName, tryParse);

/// <summary>
Expand All @@ -128,7 +129,7 @@ public static LayoutKind GetStructLayout (this ITypeSymbol symbol)
// Check for StructLayout attribute with LayoutKind.Sequential
var layoutAttribute = symbol.GetAttributes ()
.FirstOrDefault (attr =>
attr.AttributeClass?.ToString () == typeof (StructLayoutAttribute).FullName);
attr.AttributeClass?.ToString () == typeof(StructLayoutAttribute).FullName);

if (layoutAttribute is not null) {
return (LayoutKind) layoutAttribute.ConstructorArguments [0].Value!;
Expand Down Expand Up @@ -226,10 +227,11 @@ public static int GetFieldOffset (this IFieldSymbol symbol)
{
var offsetAttribute = symbol.GetAttributes ()
.FirstOrDefault (attr =>
attr.AttributeClass?.ToString () == typeof (FieldOffsetAttribute).FullName);
attr.AttributeClass?.ToString () == typeof(FieldOffsetAttribute).FullName);

return offsetAttribute is not null
? (int) offsetAttribute.ConstructorArguments [0].Value! : 0;
? (int) offsetAttribute.ConstructorArguments [0].Value!
: 0;
}

/// <summary>
Expand All @@ -241,7 +243,7 @@ public static (UnmanagedType Type, int SizeConst)? GetMarshalAs (this ISymbol sy
{
var marshalAsAttribute = symbol.GetAttributes ()
.FirstOrDefault (attr =>
attr.AttributeClass?.ToString () == typeof (MarshalAsAttribute).FullName);
attr.AttributeClass?.ToString () == typeof(MarshalAsAttribute).FullName);
if (marshalAsAttribute is null)
return null;
var type = (UnmanagedType) marshalAsAttribute.ConstructorArguments [0].Value!;
Expand Down Expand Up @@ -298,6 +300,9 @@ internal static bool TryGetBuiltInTypeSize (this ITypeSymbol symbol, bool is64bi
return result;
}

static bool TryGetBuiltInTypeSize (this ITypeSymbol type)
=> TryGetBuiltInTypeSize (type, true /* doesn't matter */, out _);

static int AlignAndAdd (int size, int add, ref int maxElementSize)
{
maxElementSize = Math.Max (maxElementSize, add);
Expand All @@ -307,7 +312,8 @@ static int AlignAndAdd (int size, int add, ref int maxElementSize)
}


static void GetValueTypeSize (this ITypeSymbol originalSymbol, ITypeSymbol type, List<ITypeSymbol> fieldSymbols, bool is64Bits, ref int size,
static void GetValueTypeSize (this ITypeSymbol originalSymbol, ITypeSymbol type, List<ITypeSymbol> fieldSymbols,
bool is64Bits, ref int size,
ref int maxElementSize)
{
// FIXME:
Expand All @@ -328,13 +334,15 @@ static void GetValueTypeSize (this ITypeSymbol originalSymbol, ITypeSymbol type,
GetValueTypeSize (originalSymbol, field.Type, fieldSymbols, is64Bits, ref size, ref maxElementSize);
continue;
}

var (marshalAsType, sizeConst) = marshalAs.Value;
var multiplier = 1;
switch (marshalAsType) {
case UnmanagedType.ByValArray:
var types = new List<ITypeSymbol> ();
var arrayTypeSymbol = (field as IArrayTypeSymbol)!;
GetValueTypeSize (originalSymbol, arrayTypeSymbol.ElementType, types, is64Bits, ref typeSize, ref maxElementSize);
GetValueTypeSize (originalSymbol, arrayTypeSymbol.ElementType, types, is64Bits, ref typeSize,
ref maxElementSize);
multiplier = sizeConst;
break;
case UnmanagedType.U1:
Expand All @@ -356,13 +364,14 @@ static void GetValueTypeSize (this ITypeSymbol originalSymbol, ITypeSymbol type,
typeSize = 8;
break;
default:
throw new Exception ($"Unhandled MarshalAs attribute: {marshalAs.Value} on field {field.ToDisplayString ()}");
throw new Exception (
$"Unhandled MarshalAs attribute: {marshalAs.Value} on field {field.ToDisplayString ()}");
}

fieldSymbols.Add (field.Type);
size = AlignAndAdd (size, typeSize, ref maxElementSize);
size += (multiplier - 1) * size;
}

}

/// <summary>
Expand Down Expand Up @@ -436,5 +445,4 @@ public static void GetInheritance (
parents = parentsBuilder.ToImmutable ();
interfaces = [.. interfacesSet];
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.Macios.Generator.Attributes;
using Microsoft.Macios.Generator.Availability;
Expand Down Expand Up @@ -149,4 +150,86 @@ public static BindingTypeData<T> GetBindingData<T> (this ISymbol symbol) where T
/// returned.</remarks>
public static FieldData<T>? GetFieldData<T> (this ISymbol symbol) where T : Enum
=> GetAttribute<FieldData<T>> (symbol, AttributesNames.GetFieldAttributeName<T>, FieldData<T>.TryParse);

public static bool X86NeedStret (ITypeSymbol returnType)
{
if (!returnType.IsValueType || returnType.SpecialType == SpecialType.System_Enum ||
returnType.TryGetBuiltInTypeSize ())
return false;

var fieldTypes = new List<ITypeSymbol> ();
var size = GetValueTypeSize (returnType, fieldTypes, false);

if (size > 8)
return true;

return fieldTypes.Count == 3;
}

public static bool X86_64NeedStret (ITypeSymbol returnType)
{
if (!returnType.IsValueType || returnType.SpecialType == SpecialType.System_Enum ||
returnType.TryGetBuiltInTypeSize ())
return false;

var fieldTypes = new List<ITypeSymbol> ();
return GetValueTypeSize (returnType, fieldTypes, true) > 16;
}

public static bool ArmNeedStret (ITypeSymbol returnType, Compilation compilation)
{
var currentPlatform = compilation.GetCurrentPlatform ();
bool has32bitArm = currentPlatform != PlatformName.TvOS && currentPlatform != PlatformName.MacOSX;
if (!has32bitArm)
return false;

ITypeSymbol t = returnType;

if (!t.IsValueType || t.SpecialType == SpecialType.System_Enum || t.TryGetBuiltInTypeSize())
return false;

var fieldTypes = new List<ITypeSymbol> ();
var size = t.GetValueTypeSize (fieldTypes, false);

bool isiOS = currentPlatform == PlatformName.iOS;

if (isiOS && size <= 4 && fieldTypes.Count == 1) {

#pragma warning disable format
return fieldTypes [0] switch {
{ Name: "nint" } => false,
{ Name: "nuint" } => false,
{ SpecialType: SpecialType.System_Char } => false,
{ SpecialType: SpecialType.System_Byte } => false,
{ SpecialType: SpecialType.System_SByte } => false,
{ SpecialType: SpecialType.System_UInt16 } => false,
{ SpecialType: SpecialType.System_Int16 } => false,
{ SpecialType: SpecialType.System_UInt32 } => false,
{ SpecialType: SpecialType.System_Int32 } => false,
{ SpecialType: SpecialType.System_IntPtr } => false,
{ SpecialType: SpecialType.System_UIntPtr } => false,
_ => true
};
#pragma warning restore format
}

return true;
}

/// <summary>
/// Return if a given ITypeSymbol requires to use the objc_MsgSend_stret variants.
/// </summary>
/// <param name="returnType">The type we are testing.</param>
/// <param name="compilation">The current compilation, used to determine the target platform.</param>
/// <returns>If the type represented by the symtol needs a stret call variant.</returns>
public static bool NeedsStret (this ITypeSymbol returnType, Compilation compilation)
{
if (X86NeedStret (returnType))
return true;

if (X86_64NeedStret (returnType))
return true;

return ArmNeedStret (returnType, compilation);
}
}

0 comments on commit a70cdd3

Please sign in to comment.