From c4e18a931258c0b94564f31956e271cdeb81dd87 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:07:47 -0700 Subject: [PATCH 1/5] Don't allow overriding functions with `@strict` --- .../restrictions/fn_override_strict.txt | 7 +++++ .../gmod_wire_expression2/base/compiler.lua | 26 ++++++++++++------- 2 files changed, 24 insertions(+), 9 deletions(-) create mode 100644 data/expression2/tests/compiler/compiler/restrictions/fn_override_strict.txt diff --git a/data/expression2/tests/compiler/compiler/restrictions/fn_override_strict.txt b/data/expression2/tests/compiler/compiler/restrictions/fn_override_strict.txt new file mode 100644 index 0000000000..d953af68a2 --- /dev/null +++ b/data/expression2/tests/compiler/compiler/restrictions/fn_override_strict.txt @@ -0,0 +1,7 @@ +## SHOULD_FAIL:COMPILE + +@strict + +function test() {} + +function test() {} # ERROR! \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 5ad1ebe51f..0af7ce068a 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -81,6 +81,7 @@ end ---@field persist IODirective ---@field inputs IODirective ---@field outputs IODirective +---@field strict boolean local Compiler = {} Compiler.__index = Compiler @@ -100,7 +101,7 @@ end function Compiler.from(directives, dvars, includes) local global_scope = Scope.new() return setmetatable({ - persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs, + persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs, strict = directives.strict, global_scope = global_scope, scope = global_scope, warnings = {}, registered_events = {}, user_functions = {}, user_methods = {}, delta_vars = dvars or {}, includes = includes or {} }, Compiler) @@ -684,8 +685,11 @@ local CompileVisitors = { self:Assert(fn_data.returns == nil, "Cannot override function returning void with differing return type", trace) end - -- Tag function if it is ever re-declared. Used as an optimization - fn_data.const = fn_data.op == nil + if not self.strict then + self:Warning("Do not override functions. This is a hard error with @strict.", trace) + else + self:Error("Cannot override existing function '" .. name.value .. "'", trace) + end end end @@ -821,9 +825,11 @@ local CompileVisitors = { local sig = name.value .. "(" .. (meta_type and (meta_type .. ":") or "") .. sig .. ")" local fn = fn.op - return function(state) ---@param state RuntimeContext - state.funcs[sig] = fn - state.funcs_ret[sig] = return_type + if not self.strict then + return function(state) ---@param state RuntimeContext + state.funcs[sig] = fn + state.funcs_ret[sig] = return_type + end end end, @@ -1448,7 +1454,6 @@ local CompileVisitors = { self:Warning("Use of deprecated function (" .. name.value .. ") " .. (type(value) == "string" and value or ""), trace) end - self.scope.data.ops = self.scope.data.ops + ((fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0)) if fn_data.attrs["noreturn"] then self.scope.data.dead = true @@ -1457,8 +1462,9 @@ local CompileVisitors = { local nargs = #args local user_function = self.user_functions[name.value] and self.user_functions[name.value][arg_sig] if user_function then - -- Calling a user function - chance of being overridden. Also not legacy. - if user_function.const then + if self.strict then -- If @strict, functions are compile time constructs (like events). + self.scope.data.ops = self.scope.data.ops + fn_data.cost + local fn = user_function.op return function(state) local rargs = {} @@ -1468,6 +1474,8 @@ local CompileVisitors = { return fn(state, rargs, types) end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) else + self.scope.data.ops = self.scope.data.ops + 4 + ((fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0)) + local full_sig = name.value .. "(" .. arg_sig .. ")" return function(state) ---@param state RuntimeContext local rargs = {} From e7b246b78af4ee9dd74e93ce556f4b99de1327c4 Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:19:18 -0700 Subject: [PATCH 2/5] Add nested function warning --- lua/entities/gmod_wire_expression2/base/compiler.lua | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index 0af7ce068a..ba8376b428 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -672,6 +672,10 @@ local CompileVisitors = { end end + if self.strict and not self.scope:IsGlobalScope() then + self:Warning("Functions should be in the top scope, nesting them does nothing", trace) + end + local fn_data, lookup_variadic, userfunction = self:GetFunction(name.value, param_types, meta_type) if fn_data then if not userfunction then From b115c1f6926be32ccab11ff323fff44630149d5f Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:28:58 -0700 Subject: [PATCH 3/5] Lower base op cost 5 ops on `@strict`, 8 without --- lua/entities/gmod_wire_expression2/base/compiler.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index ba8376b428..e53f9877ec 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -697,7 +697,7 @@ local CompileVisitors = { end end - local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} } + local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 10 or 5 + (self.strict and 0 or 3), attrs = {} } local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or "")) if meta_type then From 9c1bcaf1236d622e44a84eb8c0a5257706f620cf Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:31:35 -0700 Subject: [PATCH 4/5] Remove extra 4 ops --- lua/entities/gmod_wire_expression2/base/compiler.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index e53f9877ec..d491cef755 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1478,7 +1478,7 @@ local CompileVisitors = { return fn(state, rargs, types) end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) else - self.scope.data.ops = self.scope.data.ops + 4 + ((fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0)) + self.scope.data.ops = self.scope.data.ops + (fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0) local full_sig = name.value .. "(" .. arg_sig .. ")" return function(state) ---@param state RuntimeContext From 64f22a6c9e07f6296070a9bd40c62a9b4f9541da Mon Sep 17 00:00:00 2001 From: Vurv <56230599+Vurv78@users.noreply.github.com> Date: Fri, 13 Oct 2023 14:42:39 -0700 Subject: [PATCH 5/5] Add tests and fix methods --- .../base/userfunctions/functions_const.txt | 117 +++++++++++++++++ .../base/userfunctions/methods_const.txt | 120 ++++++++++++++++++ .../gmod_wire_expression2/base/compiler.lua | 5 +- 3 files changed, 239 insertions(+), 3 deletions(-) create mode 100644 data/expression2/tests/runtime/base/userfunctions/functions_const.txt create mode 100644 data/expression2/tests/runtime/base/userfunctions/methods_const.txt diff --git a/data/expression2/tests/runtime/base/userfunctions/functions_const.txt b/data/expression2/tests/runtime/base/userfunctions/functions_const.txt new file mode 100644 index 0000000000..29adb2a5e7 --- /dev/null +++ b/data/expression2/tests/runtime/base/userfunctions/functions_const.txt @@ -0,0 +1,117 @@ +## SHOULD_PASS:EXECUTE + +@strict + +# Ensure functions get called in the first place + +Called = 0 +function myfunction() { + Called = 1 +} + +myfunction() + +assert(Called) + + +local X = 500 +local Y = 1000 +local Z = 5000 + +# Ensure function scoping doesn't affect outer scope + +function test(X, Y, Z) { + assert(X == 1) + assert(Y == 2) + assert(Z == 3) +} + +test(1, 2, 3) + +assert(X == 500) +assert(Y == 1000) +assert(Z == 5000) + +# Ensure functions return properly + +function number returning() { + return 5 +} + +assert(returning() == 5) + +function number returning2(X:array) { + return X[1, number] + 5 +} + +assert(returning2(array(5)) == 10) +assert(returning2(array()) == 5) + +function array returningref(X:array) { + return X +} + +local A = array() +assert(returningref(A):id() == A:id()) + +function returnvoid() { + if (1) { return } + error("unreachable") +} + +returnvoid() + +function void returnvoid2() { + return +} + +returnvoid2() + +function returnvoid3() { + return void +} + +returnvoid3() + +# Test recursion + +function number recurse(N:number) { + if (N == 1) { + return 5 + } else { + return recurse(N - 1) + 1 + } +} + +assert(recurse(10) == 14, recurse(10):toString()) + +Sentinel = -1 +function recursevoid() { + Sentinel++ + if (Sentinel == 0) { + recursevoid() + } +} + +recursevoid() + +assert(Sentinel == 1) + +function number nilInput(X, Y:ranger, Z:vector) { + assert(Z == vec(1, 2, 3)) + return 5 +} + +assert( nilInput(1, noranger(), vec(1, 2, 3)) == 5 ) + +Ran = 0 + +if (0) { + function constant() { + Ran = 1 + } +} + +constant() + +assert(Ran) \ No newline at end of file diff --git a/data/expression2/tests/runtime/base/userfunctions/methods_const.txt b/data/expression2/tests/runtime/base/userfunctions/methods_const.txt new file mode 100644 index 0000000000..101e29fc00 --- /dev/null +++ b/data/expression2/tests/runtime/base/userfunctions/methods_const.txt @@ -0,0 +1,120 @@ +## SHOULD_PASS:EXECUTE + +@strict + +# Ensure methods get called in the first place + +Called = 0 +function number:mymethod() { + Called = 1 +} + +1:mymethod() + +assert(Called) + +local This = 10 +local X = 500 +local Y = 1000 +local Z = 5000 + +# Ensure function scoping doesn't affect outer scope + +function number number:method(X, Y, Z) { + assert(This == 500) + assert(X == 1) + assert(Y == 2) + assert(Z == 4) + + return 5 +} + +assert( 500:method(1, 2, 4) == 5 ) + +assert(This == 10) +assert(X == 500) +assert(Y == 1000) +assert(Z == 5000) + +# Ensure functions return properly + +function number number:returning() { + return 5 +} + +assert(1:returning() == 5) + +function number number:returning2(X:array) { + return X[1, number] + 5 +} + +assert(1:returning2(array(5)) == 10) +assert(1:returning2(array()) == 5) + +function array number:returningref(X:array) { + return X +} + +local A = array() +assert(1:returningref(A):id() == A:id()) + +function number:returnvoid() { + if (1) { return } +} + +1:returnvoid() + +function void number:returnvoid2() { + return +} + +1:returnvoid2() + +function number:returnvoid3() { + return void +} + +1:returnvoid3() + +# Test recursion + +function number number:recurse(N:number) { + if (N == 1) { + return 5 + } else { + return This:recurse(N - 1) + 1 + } +} + +assert(1:recurse(10) == 14, 1:recurse(10):toString()) + +Sentinel = -1 +function number:recursevoid() { + Sentinel++ + if (Sentinel == 0) { + This:recursevoid() + } +} + +1:recursevoid() + +assert(Sentinel == 1) + +function number number:nilInput(X, Y:ranger, Z:vector) { + assert(Z == vec(1, 2, 3)) + return 5 +} + +assert( 1:nilInput(1, noranger(), vec(1, 2, 3)) == 5 ) + +Ran = 0 + +if (0) { + function number:constant() { + Ran = 1 + } +} + +1:constant() + +assert(Ran) \ No newline at end of file diff --git a/lua/entities/gmod_wire_expression2/base/compiler.lua b/lua/entities/gmod_wire_expression2/base/compiler.lua index d491cef755..9bbdd2f71b 100644 --- a/lua/entities/gmod_wire_expression2/base/compiler.lua +++ b/lua/entities/gmod_wire_expression2/base/compiler.lua @@ -1538,8 +1538,7 @@ local CompileVisitors = { local nargs = #args local user_method = self.user_methods[meta_type] and self.user_methods[meta_type][name.value] and self.user_methods[meta_type][name.value][arg_sig] if user_method then - -- Calling a user function - chance of being overridden. Also not legacy. - if user_method.const then + if self.strict then -- If @strict, functions are compile time constructs (like events). local fn = user_method.op return function(state) local rargs = { meta(state) } @@ -1547,7 +1546,7 @@ local CompileVisitors = { rargs[k + 1] = args[k](state) end return fn(state, rargs, types) - end + end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil) else local full_sig = name.value .. "(" .. meta_type .. ":" .. arg_sig .. ")" return function(state) ---@param state RuntimeContext