Skip to content
This repository has been archived by the owner on Oct 1, 2024. It is now read-only.

Commit

Permalink
Move partial unmap handler to the native signal handler (#3437)
Browse files Browse the repository at this point in the history
* Initial commit with a lot of testing stuff.

* Partial Unmap Cleanup Part 1

* Fix some minor issues, hopefully windows tests.

* Disable partial unmap tests on macos for now

Weird issue.

* Goodbye magic number

* Add COMPlus_EnableAlternateStackCheck for tests

`COMPlus_EnableAlternateStackCheck` is needed for NullReferenceException handling to work on linux after registering the signal handler, due to how dotnet registers its own signal handler.

* Address some feedback

* Force retry when memory is mapped in memory tracking

This case existed before, but returning `false` no longer retries, so it would crash immediately after unprotecting the memory... Now, we return `true` to deliberately retry.

This case existed before (was just broken by this change) and I don't really want to look into fixing the issue right now. Technically, this means that on guest code partial unmaps will retry _due to this_ rather than hitting the handler. I don't expect this to cause any issues.

This should fix random crashes in Xenoblade Chronicles 2.

* Use IsRangeMapped

* Suppress MockMemoryManager.UnmapEvent warning

This event is not signalled by the mock memory manager.

* Remove 4kb mapping
  • Loading branch information
riperiperi authored Jul 29, 2022
1 parent 952d013 commit 14ce9e1
Show file tree
Hide file tree
Showing 24 changed files with 1,356 additions and 392 deletions.
21 changes: 19 additions & 2 deletions ARMeilleure/Signal/NativeSignalHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,29 @@ private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr sig
// Only call tracking if in range.
context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);

context.Copy(inRegionLocal, Const(1));
Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));

// Call the tracking action, with the pointer's relative offset to the base address.
Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));

context.Copy(inRegionLocal, Const(0));

Operand skipActionLabel = Label();

// Tracking action should be non-null to call it, otherwise assume false return.
context.BranchIfFalse(skipActionLabel, trackingActionPtr);
Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
context.Copy(inRegionLocal, result);

context.MarkLabel(skipActionLabel);

// If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
if (OperatingSystem.IsWindows())
{
context.BranchIfTrue(endLabel, inRegionLocal);

context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
}

context.Branch(endLabel);

Expand Down
84 changes: 84 additions & 0 deletions ARMeilleure/Signal/TestMethods.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using System;

using static ARMeilleure.IntermediateRepresentation.Operand.Factory;

namespace ARMeilleure.Signal
{
public struct NativeWriteLoopState
{
public int Running;
public int Error;
}

public static class TestMethods
{
public delegate bool DebugPartialUnmap();
public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState);
public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr);

public static DebugPartialUnmap GenerateDebugPartialUnmap()
{
EmitterContext context = new EmitterContext();

var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context);

context.Return(result);

// Compile and return the function.

ControlFlowGraph cfg = context.GetControlFlowGraph();

OperandType[] argTypes = new OperandType[] { OperandType.I64 };

return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugPartialUnmap>();
}

public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr)
{
EmitterContext context = new EmitterContext();

var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1));

context.Return(result);

// Compile and return the function.

ControlFlowGraph cfg = context.GetControlFlowGraph();

OperandType[] argTypes = new OperandType[] { OperandType.I64 };

return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<DebugThreadLocalMapGetOrReserve>();
}

public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop()
{
EmitterContext context = new EmitterContext();

// Loop a write to the target address until "running" is false.

Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0));
Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1));

Operand loopLabel = Label();
context.MarkLabel(loopLabel);

context.Store(writePtr, Const(12345));

Operand running = context.Load(OperandType.I32, structPtr);

context.BranchIfTrue(loopLabel, running);

context.Return();

// Compile and return the function.

ControlFlowGraph cfg = context.GetControlFlowGraph();

OperandType[] argTypes = new OperandType[] { OperandType.I64 };

return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<DebugNativeWriteLoop>();
}
}
}
186 changes: 186 additions & 0 deletions ARMeilleure/Signal/WindowsPartialUnmapHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using Ryujinx.Common.Memory.PartialUnmaps;
using System;

using static ARMeilleure.IntermediateRepresentation.Operand.Factory;

namespace ARMeilleure.Signal
{
/// <summary>
/// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods.
/// </summary>
internal static class WindowsPartialUnmapHandler
{
public static Operand EmitRetryFromAccessViolation(EmitterContext context)
{
IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState;
IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset);

// Get the lock first.
EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));

IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc();
Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32);
Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0));

Operand endLabel = Label();
Operand retry = context.AllocateLocal(OperandType.I32);
Operand threadIndexValidLabel = Label();

context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1)));

context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated.

context.Branch(endLabel);

context.MarkLabel(threadIndexValidLabel);

Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex);
Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr);
Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset)));

context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount));

Operand noRetryLabel = Label();

context.BranchIfFalse(noRetryLabel, retry);

// if (retry) {

context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount);

context.Branch(endLabel);

context.MarkLabel(noRetryLabel);

// }

context.MarkLabel(endLabel);

// Finally, release the lock and return the retry value.
EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));

return retry;
}

public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState)
{
Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));

Operand i = context.AllocateLocal(OperandType.I32);

context.Copy(i, Const(0));

// (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate)

Operand endLabel = Label();

Operand loopLabel = Label();
context.MarkLabel(loopLabel);

Operand offset = context.Multiply(i, Const(sizeof(int)));
Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));

// Check that this slot has the thread ID.
Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId);

// If it was already the thread ID, then we just need to return i.
context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId));

context.Copy(i, context.Add(i, Const(1)));

context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));

// (Loop 2) Try take a slot that is 0 with our Thread ID.

context.Copy(i, Const(0)); // Reset i.

Operand loop2Label = Label();
context.MarkLabel(loop2Label);

Operand offset2 = context.Multiply(i, Const(sizeof(int)));
Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2));

// Try and swap in the thread id on top of 0.
Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId);

Operand idNot0Label = Label();

// If it was 0, then we need to initialize the struct entry and return i.
context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0)));

Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2));
context.Store(structPtr, initialState);

context.Branch(endLabel);

context.MarkLabel(idNot0Label);

context.Copy(i, context.Add(i, Const(1)));

context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));

context.Copy(i, Const(-1)); // Could not place the thread in the list.

context.MarkLabel(endLabel);

return context.Copy(i);
}

private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index)
{
Operand offset = context.Multiply(index, Const(sizeof(int)));
Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));

return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset));
}

private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index)
{
Operand offset = context.Multiply(index, Const(sizeof(int)));
Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));

context.CompareAndSwap(idPtr, threadId, Const(0));
}

private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive)
{
Operand loop = Label();
context.MarkLabel(loop);

Operand initial = context.Load(OperandType.I32, ptr);
Operand newValue = context.Add(initial, additive);

Operand replaced = context.CompareAndSwap(ptr, initial, newValue);

context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced));
}

private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr)
{
Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset));

// Spin until we can acquire the write lock.
Operand spinLabel = Label();
context.MarkLabel(spinLabel);

// Old value must be 0 to continue (we gained the write lock)
context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1)));

// Increment reader count.
EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1));

// Release write lock.
context.CompareAndSwap(writeLockPtr, Const(1), Const(0));
}

private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr)
{
// Decrement reader count.
EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1));
}
}
}
23 changes: 22 additions & 1 deletion ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace ARMeilleure.Signal
{
class WindowsSignalHandlerRegistration
unsafe class WindowsSignalHandlerRegistration
{
[DllImport("kernel32.dll")]
private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler);

[DllImport("kernel32.dll")]
private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle);

[DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Ansi)]
static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName);

[DllImport("kernel32.dll", CharSet = CharSet.Ansi, ExactSpelling = true, SetLastError = true)]
private static extern IntPtr GetProcAddress(IntPtr hModule, string procName);

private static IntPtr _getCurrentThreadIdPtr;

public static IntPtr RegisterExceptionHandler(IntPtr action)
{
return AddVectoredExceptionHandler(1, action);
Expand All @@ -20,5 +29,17 @@ public static bool RemoveExceptionHandler(IntPtr handle)
{
return RemoveVectoredExceptionHandler(handle) != 0;
}

public static IntPtr GetCurrentThreadIdFunc()
{
if (_getCurrentThreadIdPtr == IntPtr.Zero)
{
IntPtr handle = LoadLibrary("kernel32.dll");

_getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId");
}

return _getCurrentThreadIdPtr;
}
}
}
Loading

0 comments on commit 14ce9e1

Please sign in to comment.