Skip to content

Commit

Permalink
Enforce function returns at compile time (#2788)
Browse files Browse the repository at this point in the history
* Enforce function returns at compile time

* Raise cost of dynamic functions

* Revert previous change

Forgot const optimization isn't currently being used. Needs an analyzer step.

Should really not have included that code in the rewrite but it's fine being there for now.

* Add test cases

* Simplify

No need for `returned` field, can just check if scope is dead.

* Add switch case return logic

Adds logic for switch case to be the last statement in a function and be detected for return values.
  • Loading branch information
Vurv78 authored Oct 13, 2023
1 parent 56ee9b9 commit 59ca86e
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## SHOULD_FAIL:COMPILE

function string nothing() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
## SHOULD_PASS:COMPILE

function string nothing() {
return "something"
}

function number deadcase() {
if (1) {
return 2158129
} else {
return 2321515
}
}

function number switchcase() {
switch (5) {
case 5,
return 2
default,
return 5
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## SHOULD_FAIL:COMPILE

function string failure() {
switch (5) {
case 2,
break
default,
break

# 'break' does not return a value or cause a runtime error, just early returns switch.
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## SHOULD_FAIL:COMPILE

function string failure() {
switch (5) {
case 2,
return "boowomp"
# no default case, compiler can't guarantee that this always runs, fails to compile.
}
}
69 changes: 42 additions & 27 deletions lua/entities/gmod_wire_expression2/base/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ cvars.AddChangeCallback("wire_expression2_quotatick", function(_, old, new)
end, "compiler_quota_check")

---@class ScopeData
---@field dead boolean?
---@field dead "ret"|true?
---@field loop boolean?
---@field switch_case boolean?
---@field function { [1]: string, [2]: EnvFunction}?
Expand Down Expand Up @@ -281,8 +281,10 @@ local CompileVisitors = {
---@param data { [1]: Node?, [2]: Node }[]
[NodeVariant.If] = function (self, trace, data)
local chain = {} ---@type { [1]: RuntimeOperator?, [2]: RuntimeOperator }[]
local dead, els = true, false

for i, ifeif in ipairs(data) do
self:Scope(function()
self:Scope(function(scope)
if ifeif[1] then -- if or elseif
local expr, expr_ty = self:CompileExpr(ifeif[1])

Expand All @@ -301,11 +303,19 @@ local CompileVisitors = {
self:CompileStmt(ifeif[2])
}
end

dead = dead and scope.data.dead
else -- else block
chain[i] = { nil, self:CompileStmt(ifeif[2]) }
dead, els = dead and scope.data.dead, true
end
end)
end

if els and dead then -- if (0) { return } else { return } mark any code after as dead
self.scope.data.dead = "ret"
end

return function(state) ---@param state RuntimeContext
for _, data in ipairs(chain) do
local cond, block = data[1], data[2]
Expand Down Expand Up @@ -512,14 +522,16 @@ local CompileVisitors = {
---@param data { [1]: Node, [2]: {[1]: Node, [2]: Node}[], [3]: Node? }
[NodeVariant.Switch] = function (self, trace, data)
local expr, expr_ty = self:CompileExpr(data[1])
local dead = true

local cases = {} ---@type { [1]: RuntimeOperator, [2]: RuntimeOperator }[]
for i, case in ipairs(data[2]) do
local cond, cond_ty = self:CompileExpr(case[1])
local block
self:Scope(function(scope)
local block = self:Scope(function(scope)
scope.data.switch_case = true
block = self:CompileStmt(case[2])
local b = self:CompileStmt(case[2])
dead = dead and scope.data.dead == "ret"
return b
end)

local eq = self:GetOperator("eq", { expr_ty, cond_ty }, case[1].trace)
Expand All @@ -531,7 +543,16 @@ local CompileVisitors = {
}
end

local default = data[3] and self:Scope(function() return self:CompileStmt(data[3]) end)
local default = data[3] and self:Scope(function(scope)
local b = self:CompileStmt(data[3])
dead = dead and scope.data.dead == "ret"
return b
end)

if dead and default then -- if all cases dead and has default case, mark scope as dead.
self.scope.data.dead = true
end

local ncases = #cases

return function(state) ---@param state RuntimeContext
Expand Down Expand Up @@ -668,7 +689,7 @@ local CompileVisitors = {
end
end

local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = 20, attrs = {} }
local fn = { args = param_types, returns = return_type and { return_type }, meta = meta_type, cost = variadic_ty and 25 or 10, attrs = {} }
local sig = table.concat(param_types, "", 1, #param_types - 1) .. ((variadic_ty and ".." or "") .. (param_types[#param_types] or ""))

if meta_type then
Expand Down Expand Up @@ -727,12 +748,8 @@ local CompileVisitors = {

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function return at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
else -- table
function fn.op(state, args, arg_types) ---@param state RuntimeContext
Expand All @@ -758,12 +775,8 @@ local CompileVisitors = {

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function return at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
end
else -- Todo: Make this output a different function when it doesn't early return, and/or has no parameters as an optimization.
Expand All @@ -784,23 +797,23 @@ local CompileVisitors = {

state.Scopes, state.ScopeID, state.Scope = s_scopes, s_scopeid, s_scope

if state.__return__ then
state.__return__ = false
return state.__returnval__
elseif return_type then
state:forceThrow("Expected function function at runtime of type (" .. return_type .. ")")
end
state.__return__ = false
return state.__returnval__
end
end

block = self:IsolatedScope(function (scope)
self:IsolatedScope(function (scope)
for i, type in ipairs(param_types) do
scope:DeclVar(param_names[i], { type = type, trace_if_unused = data[4][i] and data[4][i].name.trace or trace, initialized = true })
end

scope.data["function"] = { name.value, fn }

return self:CompileStmt(data[5])
block = self:CompileStmt(data[5])

if return_type then -- Ensure function either returns or errors
self:Assert(scope.data.dead, "This function marked to return '" .. data[1].value .. "' must return a value", data[1].trace)
end
end)

self:Assert((fn.returns and fn.returns[1]) == return_type, "Function " .. name.value .. " expects to return type (" .. (return_type or "void") .. ") but got type (" .. ((fn.returns and fn.returns[1]) or "void") .. ")", trace)
Expand Down Expand Up @@ -888,6 +901,8 @@ local CompileVisitors = {
local fn = self.scope:ResolveData("function")
self:Assert(fn, "Cannot use `return` outside of a function", trace)

self.scope.data.dead = "ret"

local retval, ret_ty
if data then
retval, ret_ty = self:CompileExpr(data)
Expand Down

0 comments on commit 59ca86e

Please sign in to comment.