From 18852f647e7b67258d5a84b1bbf01a2420ed324c Mon Sep 17 00:00:00 2001 From: James Frowen Date: Fri, 21 Jul 2023 13:21:21 +0100 Subject: [PATCH] feat: adding attribute that allows for combination of checks Allows people to do a combination of condition with OR checks. This is not possible with current attributes because they will be combine using AND, [Server, HasAuthority] would require both server and hasAuthority to be true for the method to go away. Where as new [NetworkMethod(NetworkFlags.Server | NetworkFlags.HasAuthority)] would allow method to run if either server or HasAuthority to be true. --- Assets/Mirage/Runtime/CustomAttributes.cs | 41 ++++ .../Weaver/Processors/AttributeProcessor.cs | 89 ++++++- .../Runtime/ExampleGuards_NetworkMethod.cs | 223 ++++++++++++++++++ .../ExampleGuards_NetworkMethod.cs.meta | 11 + Assets/Tests/Runtime/GuardsTests.cs | 1 - 5 files changed, 358 insertions(+), 7 deletions(-) create mode 100644 Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs create mode 100644 Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs.meta diff --git a/Assets/Mirage/Runtime/CustomAttributes.cs b/Assets/Mirage/Runtime/CustomAttributes.cs index a2c2b4462f..58937d384f 100644 --- a/Assets/Mirage/Runtime/CustomAttributes.cs +++ b/Assets/Mirage/Runtime/CustomAttributes.cs @@ -184,6 +184,47 @@ public class LocalPlayerAttribute : Attribute public bool error = true; } + /// + /// Prevents this method from running unless the NetworkFlags match the current state + /// Can only be used inside a NetworkBehaviour + /// + [AttributeUsage(AttributeTargets.Method)] + public class NetworkMethodAttribute : Attribute + { + /// + /// If true, if called incorrectly method will throw.
+ /// If false, no error is thrown, but the method won't execute.
+ /// + /// useful for unity built in methods such as Await, Update, Start, etc. + /// + ///
+ public bool error = true; + + public NetworkMethodAttribute(NetworkFlags flags) { } + } + + [Flags] + public enum NetworkFlags + { + // note: NotActive can't be 0 as it needs its own flag + // This is so that people can check for (Server | NotActive) + /// + /// If both server and client are not active. Can be used to check for singleplayer or unspawned object + /// + NotActive = 1, + Server = 2, + Client = 4, + /// + /// If either Server or Client is active. + /// + /// Note this will not check host mode. For host mode you need to use and + /// + /// + Active = Server | Client, + HasAuthority = 8, + LocalOwner = 16, + } + /// /// Converts a string property into a Scene property in the inspector /// diff --git a/Assets/Mirage/Weaver/Processors/AttributeProcessor.cs b/Assets/Mirage/Weaver/Processors/AttributeProcessor.cs index 39bd6f7f7a..66448ad0e3 100644 --- a/Assets/Mirage/Weaver/Processors/AttributeProcessor.cs +++ b/Assets/Mirage/Weaver/Processors/AttributeProcessor.cs @@ -99,6 +99,7 @@ private void ProcessMethodAttributes(MethodDefinition md, FoundType foundType) InjectGuard(md, foundType, IsClient, "[Client] function '{0}' called when client not active"); InjectGuard(md, foundType, HasAuthority, "[Has Authority] function '{0}' called on player without authority"); InjectGuard(md, foundType, IsLocalPlayer, "[Local Player] function '{0}' called on nonlocal player"); + InjectNetworkMethodGuard(md, foundType); CheckAttribute(md, foundType); CheckAttribute(md, foundType); } @@ -115,32 +116,39 @@ private void CheckAttribute(MethodDefinition md, FoundType foundType } } - private void InjectGuard(MethodDefinition md, FoundType foundType, MethodReference predicate, string format) + private bool TryGetAttribte(MethodDefinition md, FoundType foundType, out CustomAttribute attribute) { - var attribute = md.GetCustomAttribute(); + attribute = md.GetCustomAttribute(); if (attribute == null) - return; + return false; if (md.IsAbstract) { logger.Error($"{typeof(TAttribute)} can't be applied to abstract method. Apply to override methods instead.", md); - return; + return false; } if (!foundType.IsNetworkBehaviour) { logger.Error($"{attribute.AttributeType.Name} method {md.Name} must be declared in a NetworkBehaviour", md); - return; + return false; } if (md.Name == "Awake" && !md.HasParameters) { logger.Error($"{attribute.AttributeType.Name} will not work on the Awake method.", md); - return; + return false; } // dont need to set modified for errors, so we set it here when we start doing ILProcessing modified = true; + return true; + } + + private void InjectGuard(MethodDefinition md, FoundType foundType, MethodReference predicate, string format) + { + if (!TryGetAttribte(md, foundType, out var attribute)) + return; var throwError = attribute.GetField("error", true); var worker = md.Body.GetILProcessor(); @@ -149,6 +157,7 @@ private void InjectGuard(MethodDefinition md, FoundType foundType, M worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); worker.InsertBefore(top, worker.Create(OpCodes.Call, predicate)); worker.InsertBefore(top, worker.Create(OpCodes.Brtrue, top)); + if (throwError) { var message = string.Format(format, md.Name); @@ -165,6 +174,74 @@ private void InjectGuard(MethodDefinition md, FoundType foundType, M } } + private void InjectNetworkMethodGuard(MethodDefinition md, FoundType foundType) + { + if (!TryGetAttribte(md, foundType, out var attribute)) + return; + + // Get the required flags from the attribute constructor argument + var requiredFlagsValue = (NetworkFlags)attribute.ConstructorArguments[0].Value; + var throwError = attribute.GetField("error", true); + var worker = md.Body.GetILProcessor(); + var top = md.Body.Instructions[0]; + + // check for each flag + // if true, then jump to start of code + // this should act as an OR check + if (requiredFlagsValue.HasFlag(NetworkFlags.Server)) + { + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, IsServer)); + worker.InsertBefore(top, worker.Create(OpCodes.Brtrue, top)); + } + if (requiredFlagsValue.HasFlag(NetworkFlags.Client)) + { + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, IsClient)); + worker.InsertBefore(top, worker.Create(OpCodes.Brtrue, top)); + } + if (requiredFlagsValue.HasFlag(NetworkFlags.HasAuthority)) + { + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, HasAuthority)); + worker.InsertBefore(top, worker.Create(OpCodes.Brtrue, top)); + } + if (requiredFlagsValue.HasFlag(NetworkFlags.LocalOwner)) + { + // Check if the object is the local player's + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, IsLocalPlayer)); + worker.InsertBefore(top, worker.Create(OpCodes.Brtrue, top)); + } + + if (requiredFlagsValue.HasFlag(NetworkFlags.NotActive)) + { + // Check if neither Server nor Clients are active + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, IsServer)); + worker.InsertBefore(top, worker.Create(OpCodes.Ldarg_0)); + worker.InsertBefore(top, worker.Create(OpCodes.Call, IsClient)); + worker.InsertBefore(top, worker.Create(OpCodes.Or)); + worker.InsertBefore(top, worker.Create(OpCodes.Brfalse, top)); + } + + if (throwError) + { + var message = $"Method '{md.Name}' cannot be executed as {nameof(NetworkFlags)} condition is not met."; + worker.InsertBefore(top, worker.Create(OpCodes.Ldstr, message)); + worker.InsertBefore(top, worker.Create(OpCodes.Newobj, () => new MethodInvocationException(""))); + worker.InsertBefore(top, worker.Create(OpCodes.Throw)); + } + else + { + // dont need to set param or return if we throw + InjectGuardParameters(md, worker, top); + InjectGuardReturnValue(md, worker, top); + worker.InsertBefore(top, worker.Create(OpCodes.Ret)); + } + } + + // this is required to early-out from a function with "out" parameters private static void InjectGuardParameters(MethodDefinition md, ILProcessor worker, Instruction top) { diff --git a/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs b/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs new file mode 100644 index 0000000000..06f2687ada --- /dev/null +++ b/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs @@ -0,0 +1,223 @@ +using System.Collections.Generic; +using NUnit.Framework; +using UnityEngine; + +namespace Mirage.Tests.Runtime.GuardTests +{ + public class ExampleGuards_NetworkMethod : NetworkBehaviour + { + public const int RETURN_VALUE = 10; + public const int OUT_VALUE_1 = 20; + public const int OUT_VALUE_2 = 20; + + // Define a list to keep track of all method calls + public readonly List Calls = new List(); + + [NetworkMethod(NetworkFlags.NotActive)] + public void CallNotActive() + { + Calls.Add(nameof(CallNotActive)); + } + [NetworkMethod(NetworkFlags.Server)] + public void CallServer() + { + Calls.Add(nameof(CallServer)); + } + [NetworkMethod(NetworkFlags.Server, error = false)] + public void CallServerCallback() + { + Calls.Add(nameof(CallServerCallback)); + } + [NetworkMethod(NetworkFlags.Client)] + public void CallClient() + { + Calls.Add(nameof(CallClient)); + } + + [NetworkMethod(NetworkFlags.Active)] + public void CallActive() + { + Calls.Add(nameof(CallActive)); + } + + [NetworkMethod(NetworkFlags.HasAuthority)] + public void CallHasAuthority() + { + Calls.Add(nameof(CallHasAuthority)); + } + + [NetworkMethod(NetworkFlags.LocalOwner)] + public void CallLocalOwner() + { + Calls.Add(nameof(CallLocalOwner)); + } + + [NetworkMethod(NetworkFlags.Server | NetworkFlags.HasAuthority)] + public void CallServerOrHasAuthority() + { + Calls.Add(nameof(CallServerOrHasAuthority)); + } + + [NetworkMethod(NetworkFlags.Server | NetworkFlags.NotActive)] + public void CallServerOrNotActive() + { + Calls.Add(nameof(CallServerOrNotActive)); + } + } + + public class GuardsTests_NetworkMethod : ClientServerSetup + { + [Test] + public void CanCallServerAsServer() + { + serverComponent.CallServer(); + Assert.That(serverComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(serverComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServer))); + } + + [Test] + public void CanCallServerCallbackAsServer() + { + serverComponent.CallServerCallback(); + Assert.That(serverComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(serverComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServerCallback))); + } + + [Test] + public void CanCallActiveAsActive() + { + serverComponent.CallActive(); + Assert.That(serverComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(serverComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallActive))); + + clientComponent.CallActive(); + Assert.That(clientComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(clientComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallActive))); + } + + [Test] + public void CannotCallActiveAsActive() + { + var guardedComponent = CreateBehaviour(); + Assert.Throws(() => + { + guardedComponent.CallActive(); + }); + Assert.That(guardedComponent.Calls, Is.Empty); + } + + [Test] + public void CanCallAuthorityWithAuthority() + { + clientComponent.CallHasAuthority(); + Assert.That(clientComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(clientComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallHasAuthority))); + } + + [Test] + public void CannotCallAuthorityWithoutAuthority() + { + Assert.Throws(() => + { + serverComponent.CallHasAuthority(); + }); + Assert.That(serverComponent.Calls, Is.Empty); + } + + [Test] + public void CanCallLocalPlayerAsLocalPlayer() + { + clientComponent.CallLocalOwner(); + Assert.That(clientComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(clientComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallLocalOwner))); + } + + [Test] + public void CannotCallLocalPlayerAsNonLocalPlayer() + { + Assert.Throws(() => + { + serverComponent.CallLocalOwner(); + }); + Assert.That(serverComponent.Calls, Is.Empty); + } + + [Test] + public void CannotCallNotActiveAsServer() + { + Assert.Throws(() => + { + serverComponent.CallNotActive(); + }); + Assert.That(serverComponent.Calls, Is.Empty); + } + + [Test] + public void CannotCallNotActiveAsClient() + { + Assert.Throws(() => + { + clientComponent.CallNotActive(); + }); + Assert.That(clientComponent.Calls, Is.Empty); + } + + [Test] + public void CannotCallNotActiveAsUnspawned() + { + var guardedComponent = CreateBehaviour(); + Debug.Assert(guardedComponent.NetId == 0); + + guardedComponent.CallNotActive(); + Assert.That(guardedComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(guardedComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallNotActive))); + } + + [Test] + public void CanCallServerOrHasAuthority() + { + serverComponent.CallServerOrHasAuthority(); + Assert.That(serverComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(serverComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServerOrHasAuthority))); + + clientComponent.CallServerOrHasAuthority(); + Assert.That(clientComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(clientComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServerOrHasAuthority))); + } + + [Test] + public void CannotCallServerOrHasAuthorityAsUnspawned() + { + var guardedComponent = CreateBehaviour(); + Assert.Throws(() => + { + guardedComponent.CallServerOrHasAuthority(); + }); + Assert.That(guardedComponent.Calls, Is.Empty); + } + + [Test] + public void CanCallServerOrNotActive() + { + serverComponent.CallServerOrNotActive(); + Assert.That(serverComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(serverComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServerOrNotActive))); + + var guardedComponent = CreateBehaviour(); + Debug.Assert(guardedComponent.NetId == 0); + guardedComponent.CallServerOrNotActive(); + Assert.That(guardedComponent.Calls, Has.Count.EqualTo(1)); + Assert.That(guardedComponent.Calls, Does.Contain(nameof(ExampleGuards_NetworkMethod.CallServerOrNotActive))); + } + + [Test] + public void CannotCallServerOrNotActive() + { + Assert.Throws(() => + { + clientComponent.CallServerOrNotActive(); + }); + Assert.That(clientComponent.Calls, Is.Empty); + } + } +} diff --git a/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs.meta b/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs.meta new file mode 100644 index 0000000000..3fda1b69b5 --- /dev/null +++ b/Assets/Tests/Runtime/ExampleGuards_NetworkMethod.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ae4dee919ad36ad44a0e95e14770d1af +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Tests/Runtime/GuardsTests.cs b/Assets/Tests/Runtime/GuardsTests.cs index 6efa195ab2..81a92e685f 100644 --- a/Assets/Tests/Runtime/GuardsTests.cs +++ b/Assets/Tests/Runtime/GuardsTests.cs @@ -459,6 +459,5 @@ public void GuardLocalPlayerNoError() guardedComponent.CallLocalPlayerNoError(); Assert.That(guardedComponent.Calls, Does.Not.Contain(nameof(ExampleGuards.CallLocalPlayerNoError))); } - } }