Skip to content

Commit

Permalink
Make functions a compile time construct on @strict (#2789)
Browse files Browse the repository at this point in the history
* Don't allow overriding functions with `@strict`

* Add nested function warning

* Lower base op cost

5 ops on `@strict`, 8 without

* Remove extra 4 ops

* Add tests and fix methods
  • Loading branch information
Vurv78 authored Nov 12, 2023
1 parent 10731c4 commit a4e6c4f
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## SHOULD_FAIL:COMPILE

@strict

function test() {}

function test() {} # ERROR!
117 changes: 117 additions & 0 deletions data/expression2/tests/runtime/base/userfunctions/functions_const.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
## SHOULD_PASS:EXECUTE

@strict

# Ensure functions get called in the first place

Called = 0
function myfunction() {
Called = 1
}

myfunction()

assert(Called)


local X = 500
local Y = 1000
local Z = 5000

# Ensure function scoping doesn't affect outer scope

function test(X, Y, Z) {
assert(X == 1)
assert(Y == 2)
assert(Z == 3)
}

test(1, 2, 3)

assert(X == 500)
assert(Y == 1000)
assert(Z == 5000)

# Ensure functions return properly

function number returning() {
return 5
}

assert(returning() == 5)

function number returning2(X:array) {
return X[1, number] + 5
}

assert(returning2(array(5)) == 10)
assert(returning2(array()) == 5)

function array returningref(X:array) {
return X
}

local A = array()
assert(returningref(A):id() == A:id())

function returnvoid() {
if (1) { return }
error("unreachable")
}

returnvoid()

function void returnvoid2() {
return
}

returnvoid2()

function returnvoid3() {
return void
}

returnvoid3()

# Test recursion

function number recurse(N:number) {
if (N == 1) {
return 5
} else {
return recurse(N - 1) + 1
}
}

assert(recurse(10) == 14, recurse(10):toString())

Sentinel = -1
function recursevoid() {
Sentinel++
if (Sentinel == 0) {
recursevoid()
}
}

recursevoid()

assert(Sentinel == 1)

function number nilInput(X, Y:ranger, Z:vector) {
assert(Z == vec(1, 2, 3))
return 5
}

assert( nilInput(1, noranger(), vec(1, 2, 3)) == 5 )

Ran = 0

if (0) {
function constant() {
Ran = 1
}
}

constant()

assert(Ran)
120 changes: 120 additions & 0 deletions data/expression2/tests/runtime/base/userfunctions/methods_const.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
## SHOULD_PASS:EXECUTE

@strict

# Ensure methods get called in the first place

Called = 0
function number:mymethod() {
Called = 1
}

1:mymethod()

assert(Called)

local This = 10
local X = 500
local Y = 1000
local Z = 5000

# Ensure function scoping doesn't affect outer scope

function number number:method(X, Y, Z) {
assert(This == 500)
assert(X == 1)
assert(Y == 2)
assert(Z == 4)

return 5
}

assert( 500:method(1, 2, 4) == 5 )

assert(This == 10)
assert(X == 500)
assert(Y == 1000)
assert(Z == 5000)

# Ensure functions return properly

function number number:returning() {
return 5
}

assert(1:returning() == 5)

function number number:returning2(X:array) {
return X[1, number] + 5
}

assert(1:returning2(array(5)) == 10)
assert(1:returning2(array()) == 5)

function array number:returningref(X:array) {
return X
}

local A = array()
assert(1:returningref(A):id() == A:id())

function number:returnvoid() {
if (1) { return }
}

1:returnvoid()

function void number:returnvoid2() {
return
}

1:returnvoid2()

function number:returnvoid3() {
return void
}

1:returnvoid3()

# Test recursion

function number number:recurse(N:number) {
if (N == 1) {
return 5
} else {
return This:recurse(N - 1) + 1
}
}

assert(1:recurse(10) == 14, 1:recurse(10):toString())

Sentinel = -1
function number:recursevoid() {
Sentinel++
if (Sentinel == 0) {
This:recursevoid()
}
}

1:recursevoid()

assert(Sentinel == 1)

function number number:nilInput(X, Y:ranger, Z:vector) {
assert(Z == vec(1, 2, 3))
return 5
}

assert( 1:nilInput(1, noranger(), vec(1, 2, 3)) == 5 )

Ran = 0

if (0) {
function number:constant() {
Ran = 1
}
}

1:constant()

assert(Ran)
37 changes: 24 additions & 13 deletions lua/entities/gmod_wire_expression2/base/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ end
---@field persist IODirective
---@field inputs IODirective
---@field outputs IODirective
---@field strict boolean
local Compiler = {}
Compiler.__index = Compiler

Expand All @@ -100,7 +101,7 @@ end
function Compiler.from(directives, dvars, includes)
local global_scope = Scope.new()
return setmetatable({
persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs,
persist = directives.persist, inputs = directives.inputs, outputs = directives.outputs, strict = directives.strict,
global_scope = global_scope, scope = global_scope, warnings = {}, registered_events = {}, user_functions = {}, user_methods = {},
delta_vars = dvars or {}, includes = includes or {}
}, Compiler)
Expand Down Expand Up @@ -671,6 +672,10 @@ local CompileVisitors = {
end
end

if self.strict and not self.scope:IsGlobalScope() then
self:Warning("Functions should be in the top scope, nesting them does nothing", trace)
end

local fn_data, lookup_variadic, userfunction = self:GetFunction(name.value, param_types, meta_type)
if fn_data then
if not userfunction then
Expand All @@ -684,12 +689,15 @@ local CompileVisitors = {
self:Assert(fn_data.returns == nil, "Cannot override function returning void with differing return type", trace)
end

-- Tag function if it is ever re-declared. Used as an optimization
fn_data.const = fn_data.op == nil
if not self.strict then
self:Warning("Do not override functions. This is a hard error with @strict.", trace)
else
self:Error("Cannot override existing function '" .. name.value .. "'", trace)
end
end
end

local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} }
local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 10 or 5 + (self.strict and 0 or 3), attrs = {} }
local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or ""))

if meta_type then
Expand Down Expand Up @@ -821,9 +829,11 @@ local CompileVisitors = {
local sig = name.value .. "(" .. (meta_type and (meta_type .. ":") or "") .. sig .. ")"
local fn = fn.op

return function(state) ---@param state RuntimeContext
state.funcs[sig] = fn
state.funcs_ret[sig] = return_type
if not self.strict then
return function(state) ---@param state RuntimeContext
state.funcs[sig] = fn
state.funcs_ret[sig] = return_type
end
end
end,

Expand Down Expand Up @@ -1481,7 +1491,6 @@ local CompileVisitors = {
self:Warning("Use of deprecated function (" .. name.value .. ") " .. (type(value) == "string" and value or ""), trace)
end

self.scope.data.ops = self.scope.data.ops + ((fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0))

if fn_data.attrs["noreturn"] then
self.scope.data.dead = true
Expand All @@ -1490,8 +1499,9 @@ local CompileVisitors = {
local nargs = #args
local user_function = self.user_functions[name.value] and self.user_functions[name.value][arg_sig]
if user_function then
-- Calling a user function - chance of being overridden. Also not legacy.
if user_function.const then
if self.strict then -- If @strict, functions are compile time constructs (like events).
self.scope.data.ops = self.scope.data.ops + fn_data.cost

local fn = user_function.op
return function(state)
local rargs = {}
Expand All @@ -1501,6 +1511,8 @@ local CompileVisitors = {
return fn(state, rargs, types)
end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil)
else
self.scope.data.ops = self.scope.data.ops + (fn_data.cost or 15) + (fn_data.attrs["legacy"] and 10 or 0)

local full_sig = name.value .. "(" .. arg_sig .. ")"
return function(state) ---@param state RuntimeContext
local rargs = {}
Expand Down Expand Up @@ -1559,16 +1571,15 @@ local CompileVisitors = {
local nargs = #args
local user_method = self.user_methods[meta_type] and self.user_methods[meta_type][name.value] and self.user_methods[meta_type][name.value][arg_sig]
if user_method then
-- Calling a user function - chance of being overridden. Also not legacy.
if user_method.const then
if self.strict then -- If @strict, functions are compile time constructs (like events).
local fn = user_method.op
return function(state)
local rargs = { meta(state) }
for k = 1, nargs do
rargs[k + 1] = args[k](state)
end
return fn(state, rargs, types)
end
end, fn_data.returns and (fn_data.returns[1] ~= "" and fn_data.returns[1] or nil)
else
local full_sig = name.value .. "(" .. meta_type .. ":" .. arg_sig .. ")"
return function(state) ---@param state RuntimeContext
Expand Down

0 comments on commit a4e6c4f

Please sign in to comment.