From 285096616b4297984b45f9b2e93c35e4e0ddee80 Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Mon, 25 Sep 2023 11:22:22 -0700 Subject: [PATCH] Strip pos information --- compiler/compile.go | 2 + compiler/pos.go | 135 +++++++++++++++++++++++ compiler/testdata/coroutine_durable.go | 141 ++++++++----------------- compiler/testdata/http/main_durable.go | 105 ++++++++++++++++++ 4 files changed, 288 insertions(+), 95 deletions(-) create mode 100644 compiler/pos.go create mode 100644 compiler/testdata/http/main_durable.go diff --git a/compiler/compile.go b/compiler/compile.go index a6a0d48..178b036 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -266,6 +266,8 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er continue } + clearPos(gen) + generateFunctypes(p, gen, colorsByFunc) // Find all the required imports for this file. diff --git a/compiler/pos.go b/compiler/pos.go new file mode 100644 index 0000000..f6bc85a --- /dev/null +++ b/compiler/pos.go @@ -0,0 +1,135 @@ +package compiler + +import ( + "go/ast" + "go/token" +) + +// clearPos resets the token.Pos field(s) in each type of ast.Node. +// When AST nodes are generated alongside nodes that have position +// information, it can cause formatting to produce invalid results. +// This function clears the information from all nodes so that the +// formatter produces correct results. +func clearPos(tree ast.Node) { + ast.Inspect(tree, func(node ast.Node) bool { + switch n := node.(type) { + case *ast.ArrayType: + n.Lbrack = token.NoPos + case *ast.AssignStmt: + n.TokPos = token.NoPos + case *ast.BasicLit: + n.ValuePos = token.NoPos + case *ast.BinaryExpr: + n.OpPos = token.NoPos + case *ast.BlockStmt: + n.Rbrace = token.NoPos + n.Lbrace = token.NoPos + case *ast.BranchStmt: + n.TokPos = token.NoPos + case *ast.CallExpr: + n.Ellipsis = token.NoPos + n.Lparen = token.NoPos + n.Rparen = token.NoPos + case *ast.CaseClause: + n.Colon = token.NoPos + n.Case = token.NoPos + case *ast.ChanType: + n.Arrow = token.NoPos + n.Begin = token.NoPos + case *ast.CommClause: + n.Case = token.NoPos + n.Colon = token.NoPos + case *ast.Comment: + n.Slash = token.NoPos + case *ast.CommentGroup: + case *ast.CompositeLit: + n.Lbrace = token.NoPos + n.Rbrace = token.NoPos + case *ast.DeclStmt: + case *ast.DeferStmt: + n.Defer = token.NoPos + case *ast.Ellipsis: + n.Ellipsis = token.NoPos + case *ast.EmptyStmt: + n.Semicolon = token.NoPos + case *ast.ExprStmt: + case *ast.Field: + case *ast.FieldList: + n.Closing = token.NoPos + n.Opening = token.NoPos + case *ast.File: + n.Package = token.NoPos + n.FileStart = token.NoPos + n.FileEnd = token.NoPos + case *ast.ForStmt: + n.For = token.NoPos + case *ast.FuncDecl: + case *ast.FuncLit: + case *ast.FuncType: + n.Func = token.NoPos + case *ast.GenDecl: + n.Lparen = token.NoPos + n.Rparen = token.NoPos + n.TokPos = token.NoPos + case *ast.GoStmt: + n.Go = token.NoPos + case *ast.Ident: + n.NamePos = token.NoPos + case *ast.IfStmt: + n.If = token.NoPos + case *ast.ImportSpec: + n.EndPos = token.NoPos + case *ast.IncDecStmt: + n.TokPos = token.NoPos + case *ast.IndexExpr: + n.Lbrack = token.NoPos + n.Rbrack = token.NoPos + case *ast.IndexListExpr: + n.Lbrack = token.NoPos + n.Rbrack = token.NoPos + case *ast.InterfaceType: + n.Interface = token.NoPos + case *ast.KeyValueExpr: + n.Colon = token.NoPos + case *ast.LabeledStmt: + n.Colon = token.NoPos + case *ast.MapType: + n.Map = token.NoPos + case *ast.Package: + case *ast.ParenExpr: + n.Lparen = token.NoPos + n.Rparen = token.NoPos + case *ast.RangeStmt: + n.TokPos = token.NoPos + n.For = token.NoPos + n.Range = token.NoPos + case *ast.ReturnStmt: + n.Return = token.NoPos + case *ast.SelectStmt: + n.Select = token.NoPos + case *ast.SelectorExpr: + case *ast.SendStmt: + n.Arrow = token.NoPos + case *ast.SliceExpr: + n.Lbrack = token.NoPos + n.Rbrack = token.NoPos + case *ast.StarExpr: + n.Star = token.NoPos + case *ast.StructType: + n.Struct = token.NoPos + case *ast.SwitchStmt: + n.Switch = token.NoPos + case *ast.TypeAssertExpr: + n.Lparen = token.NoPos + n.Rparen = token.NoPos + case *ast.TypeSpec: + n.Assign = token.NoPos + case *ast.TypeSwitchStmt: + n.Switch = token.NoPos + case *ast.UnaryExpr: + n.OpPos = token.NoPos + case *ast.ValueSpec: + } + return true + }) +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 5c1d4ac..b2f8496 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -14,7 +14,9 @@ import _types "github.com/stealthrocket/coroutine/types" func SomeFunctionThatShouldExistInTheCompiledFile() { } //go:noinline -func Identity(n int) { coroutine.Yield[int, any](n) } +func Identity(n int) { + coroutine.Yield[int, any](n) +} //go:noinline func SquareGenerator(n int) { _c := coroutine.LoadContext[int, any]() @@ -273,12 +275,10 @@ func FizzBuzzIfGenerator(n int) { fallthrough case _f.IP < 7: for ; _f0.X1 <= _f0.X0; _f0.X1, _f.IP = _f0.X1+1, 2 { - if _f0.X1% - 3 == 0 && _f0.X1%5 == 0 { + if _f0.X1%3 == 0 && _f0.X1%5 == 0 { coroutine.Yield[int, any](FizzBuzz) } else { - if _f0.X1% - 3 == 0 { + if _f0.X1%3 == 0 { coroutine.Yield[int, any](Fizz) } else { switch { @@ -290,7 +290,6 @@ func FizzBuzzIfGenerator(n int) { if _f0.X2 == 0 { coroutine.Yield[int, any](Buzz) } else { - coroutine.Yield[int, any](_f0.X1) } } @@ -346,8 +345,7 @@ func FizzBuzzSwitchGenerator(n int) { default: switch { case _f.IP < 3: - _f0.X2 = _f0.X1% - 3 == 0 && _f0.X1%5 == 0 + _f0.X2 = _f0.X1%3 == 0 && _f0.X1%5 == 0 _f.IP = 3 fallthrough case _f.IP < 9: @@ -356,8 +354,7 @@ func FizzBuzzSwitchGenerator(n int) { } else { switch { case _f.IP < 5: - _f0.X3 = _f0.X1% - 3 == 0 + _f0.X3 = _f0.X1%3 == 0 _f.IP = 5 fallthrough case _f.IP < 9: @@ -366,15 +363,13 @@ func FizzBuzzSwitchGenerator(n int) { } else { switch { case _f.IP < 7: - _f0.X4 = _f0.X1% - 5 == 0 + _f0.X4 = _f0.X1%5 == 0 _f.IP = 7 fallthrough case _f.IP < 9: if _f0.X4 { coroutine.Yield[int, any](Buzz) } else { - coroutine.Yield[int, any](_f0.X1) } } @@ -390,20 +385,13 @@ func FizzBuzzSwitchGenerator(n int) { func Shadowing(_ int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() - const _o0 = 11 - const _o1 = 12 - type _o2 uint16 - type _o3 uint32 - const _o4 = 1 type _o5 [_o4]uint8 - type _o6 [_o4]uint8 - const _o7 = unsafe.Sizeof(_o6{}) * 2 type _o8 [_o7]uint8 var _f0 *struct { @@ -515,7 +503,6 @@ func Shadowing(_ int) { _f.IP = 5 fallthrough case _f.IP < 6: - coroutine.Yield[int, any](_f0.X0) _f.IP = 6 fallthrough @@ -533,7 +520,6 @@ func Shadowing(_ int) { _f.IP = 8 fallthrough case _f.IP < 9: - coroutine.Yield[int, any](_f0.X0) _f.IP = 9 fallthrough @@ -552,8 +538,7 @@ func Shadowing(_ int) { default: switch { case _f.IP < 12: - _f0.X5 = _f0.X4 == - 1 + _f0.X5 = _f0.X4 == 1 _f.IP = 12 fallthrough case _f.IP < 16: @@ -572,14 +557,12 @@ func Shadowing(_ int) { case _f.IP < 15: switch { default: - coroutine.Yield[int, any](_f0.X6) } } _f.IP = 15 fallthrough case _f.IP < 16: - coroutine.Yield[int, any](_f0.X3) } } @@ -589,7 +572,6 @@ func Shadowing(_ int) { _f.IP = 16 fallthrough case _f.IP < 17: - coroutine.Yield[int, any](_f0.X0) _f.IP = 17 fallthrough @@ -611,13 +593,11 @@ func Shadowing(_ int) { _f.IP = 20 fallthrough case _f.IP < 21: - coroutine.Yield[int, any](_f0.X8) } _f.IP = 21 fallthrough case _f.IP < 22: - coroutine.Yield[int, any](_f0.X0) _f.IP = 22 fallthrough @@ -637,7 +617,6 @@ func Shadowing(_ int) { _f.IP = 25 fallthrough case _f.IP < 26: - coroutine.Yield[int, any](_f0.X10) _f.IP = 26 fallthrough @@ -655,13 +634,11 @@ func Shadowing(_ int) { _f.IP = 28 fallthrough case _f.IP < 29: - coroutine.Yield[int, any](_o1) } _f.IP = 29 fallthrough case _f.IP < 30: - coroutine.Yield[int, any](_o0) _f.IP = 30 fallthrough @@ -984,14 +961,12 @@ func LoopBreakAndContinue(_ int) { _f.IP = 4 fallthrough case _f.IP < 5: - if _f0.X0 > - 5 { + if _f0.X0 > 5 { break _l0 } _f.IP = 5 fallthrough case _f.IP < 6: - coroutine.Yield[int, any](_f0.X0) } } @@ -1026,30 +1001,22 @@ func LoopBreakAndContinue(_ int) { switch { default: { - _f0.X5 = _f0.X4 == - - 0 + _f0.X5 = _f0.X4 == 0 if _f0.X5 { continue _l2 } else { - _f0.X6 = _f0.X4 == - - 1 + _f0.X6 = _f0.X4 == 1 if _f0.X6 { { _f0.X7 = _f0.X2 switch { default: { - _f0.X8 = _f0.X7 == - - 0 + _f0.X8 = _f0.X7 == 0 if _f0.X8 { continue _l1 } else { - _f0.X9 = _f0.X7 == - - 1 + _f0.X9 = _f0.X7 == 1 if _f0.X9 { break _l1 } @@ -1212,7 +1179,6 @@ func RangeOverMaps(n int) { fallthrough case _f.IP < 9: for ; _f0.X3 < len(_f0.X2); _f0.X3, _f.IP = _f0.X3+1, 8 { - coroutine.Yield[int, any](0) } } @@ -1261,7 +1227,6 @@ func RangeOverMaps(n int) { fallthrough case _f.IP < 17: if _f0.X9 { - coroutine.Yield[int, any](_f0.X8) } } @@ -1316,7 +1281,6 @@ func RangeOverMaps(n int) { if _f0.X16 { switch { case _f.IP < 25: - coroutine.Yield[int, any](_f0.X14) _f.IP = 25 fallthrough @@ -1333,12 +1297,14 @@ func RangeOverMaps(n int) { _f.IP = 26 fallthrough case _f.IP < 27: - _f0.X17 = make(map[int]struct{}, _f0.X0) + _f0.X17 = make(map[int]struct { + }, _f0.X0) _f.IP = 27 fallthrough case _f.IP < 28: for _f0.X18 = 0; _f0.X18 < _f0.X0; _f0.X18++ { - _f0.X17[_f0.X18] = struct{}{} + _f0.X17[_f0.X18] = struct { + }{} } _f.IP = 28 fallthrough @@ -1390,7 +1356,6 @@ func RangeOverMaps(n int) { if _f0.X24 { switch { case _f.IP < 37: - delete(_f0.X17, _f0.X23) _f.IP = 37 fallthrough @@ -1448,11 +1413,14 @@ func Range(n int, do func(int)) { } } //go:noinline -func Double(n int) { coroutine.Yield[int, any](2 * n) } +func Double(n int) { + coroutine.Yield[int, any](2 * n) +} //go:noinline func RangeTriple(n int) { - Range(n, func(i int) { coroutine.Yield[int, any](3 * i) }, - ) + Range(n, func(i int) { + coroutine.Yield[int, any](3 * i) + }) } //go:noinline func RangeTripleFuncValue(n int) { @@ -1483,11 +1451,12 @@ func RangeTripleFuncValue(n int) { }() switch { case _f.IP < 2: - _f0.X1 = func(i int) { coroutine.Yield[int, any](3 * i) } + _f0.X1 = func(i int) { + coroutine.Yield[int, any](3 * i) + } _f.IP = 2 fallthrough case _f.IP < 3: - Range(_f0.X0, _f0.X1) } } @@ -1527,7 +1496,9 @@ func RangeReverseClosureCaptureByValue(n int) { _f.IP = 2 fallthrough case _f.IP < 3: - _f0.X2 = func() { coroutine.Yield[int, any](_f0.X0 - (_f0.X1 + 1)) } + _f0.X2 = func() { + coroutine.Yield[int, any](_f0.X0 - (_f0.X1 + 1)) + } _f.IP = 3 fallthrough case _f.IP < 5: @@ -1628,7 +1599,6 @@ func Range10ClosureCapturingValues() { _f.IP = 4 fallthrough case _f.IP < 5: - return false } return @@ -1750,7 +1720,6 @@ func Range10ClosureCapturingPointers() { _f.IP = 4 fallthrough case _f.IP < 5: - return false } return @@ -1852,7 +1821,9 @@ func Range10ClosureHeterogenousCapture() { _f1.X6 = 6 _f1.X7 = 7 _f1.X8 = 8 - _f1.X9 = func() int { return int(_f1.X8) + 1 } + _f1.X9 = func() int { + return int(_f1.X8) + 1 + } } _f.IP = 11 fallthrough @@ -1930,36 +1901,25 @@ func Range10ClosureHeterogenousCapture() { case _f.IP < 13: switch { default: - if _f0.X2 = _f0.X1 == - - 0; _f0.X2 { + if _f0.X2 = _f0.X1 == 0; _f0.X2 { _f0.X0 = int(_f1.X0) - } else if _f0.X3 = _f0.X1 == - 1; _f0.X3 { + } else if _f0.X3 = _f0.X1 == 1; _f0.X3 { _f0.X0 = int(_f1.X1) - } else if _f0.X4 = _f0.X1 == - 2; _f0.X4 { + } else if _f0.X4 = _f0.X1 == 2; _f0.X4 { _f0.X0 = int(_f1.X2) - } else if _f0.X5 = _f0.X1 == - 3; _f0.X5 { + } else if _f0.X5 = _f0.X1 == 3; _f0.X5 { _f0.X0 = int(_f1.X3) - } else if _f0.X6 = _f0.X1 == - 4; _f0.X6 { + } else if _f0.X6 = _f0.X1 == 4; _f0.X6 { _f0.X0 = int(_f1.X4) - } else if _f0.X7 = _f0.X1 == - 5; _f0.X7 { + } else if _f0.X7 = _f0.X1 == 5; _f0.X7 { _f0.X0 = int(_f1.X5) - } else if _f0.X8 = _f0.X1 == - 6; _f0.X8 { + } else if _f0.X8 = _f0.X1 == 6; _f0.X8 { _f0.X0 = int(_f1.X6) - } else if _f0.X9 = _f0.X1 == - 7; _f0.X9 { + } else if _f0.X9 = _f0.X1 == 7; _f0.X9 { _f0.X0 = int(_f1.X7) - } else if _f0.X10 = _f0.X1 == - 8; _f0.X10 { + } else if _f0.X10 = _f0.X1 == 8; _f0.X10 { _f0.X0 = int(_f1.X8) - } else if _f0.X11 = _f0.X1 == - 9; _f0.X11 { + } else if _f0.X11 = _f0.X1 == 9; _f0.X11 { _f0.X0 = _f1.X9() } } @@ -1967,7 +1927,6 @@ func Range10ClosureHeterogenousCapture() { _f.IP = 13 fallthrough case _f.IP < 14: - coroutine.Yield[int, any](_f0.X0) _f.IP = 14 fallthrough @@ -2084,7 +2043,6 @@ func Range10Heterogenous() { _f.IP = 12 fallthrough case _f.IP < 22: - switch _f0.X9 { case 0: _f0.X10 = int(_f0.X0) @@ -2227,7 +2185,6 @@ func Select(n int) { fallthrough case _f.IP < 6: if _f0.X3 { - coroutine.Yield[int, any](-1) } } @@ -2287,18 +2244,15 @@ func Select(n int) { if _f0.X9 { switch { case _f.IP < 15: - if _f0.X4 >= - 5 { + if _f0.X4 >= 5 { break _l2 } _f.IP = 15 fallthrough case _f.IP < 16: - coroutine.Yield[int, any](_f0.X4) } } else if _f0.X10 = _f0.X8 == 2; _f0.X10 { - panic("unreachable") } } @@ -2343,14 +2297,12 @@ func Select(n int) { if _f0.X14 { switch { case _f.IP < 23: - if _f0.X4 >= - 6 { + if _f0.X4 >= 6 { break _l3 } _f.IP = 23 fallthrough case _f.IP < 24: - coroutine.Yield[int, any](_f0.X4 * 10) } } @@ -2984,8 +2936,7 @@ func YieldingDurations() { }() switch { case _f.IP < 2: - _f0.X0 = _f1.X0. - Nanoseconds() + _f0.X0 = _f1.X0.Nanoseconds() _f.IP = 2 fallthrough case _f.IP < 3: diff --git a/compiler/testdata/http/main_durable.go b/compiler/testdata/http/main_durable.go new file mode 100644 index 0000000..2a52f59 --- /dev/null +++ b/compiler/testdata/http/main_durable.go @@ -0,0 +1,105 @@ +// Code generated by coroc. DO NOT EDIT + +//go:build durable + +package main + +import ( + http "net/http" + coroutine "github.com/stealthrocket/coroutine" + fmt "fmt" +) +import _types "github.com/stealthrocket/coroutine/types" +//go:noinline +func RoundTrip(req *http.Request) (_ *http.Response, _ error) { + _c := coroutine.LoadContext[*http.Request, *http.Response]() + _f, _fp := _c.Push() + var _f0 *struct { + X0 *http.Request + X1 *http.Response + } + if _f.IP == 0 { + _f0 = &struct { + X0 *http.Request + X1 *http.Response + }{X0: req} + } else { + _f0 = _f.Get(0).(*struct { + X0 *http.Request + X1 *http.Response + }) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _f0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f0.X1 = coroutine.Yield[*http.Request, *http.Response](_f0.X0) + _f.IP = 2 + fallthrough + case _f.IP < 3: + return _f0.X1, nil + } + return +} +//go:noinline +func work() { + _c := coroutine.LoadContext[*http.Request, *http.Response]() + _f, _fp := _c.Push() + var _f0 *struct { + X0 *http.Response + X1 error + } + if _f.IP == 0 { + _f0 = &struct { + X0 *http.Response + X1 error + }{} + } else { + _f0 = _f.Get(0).(*struct { + X0 *http.Response + X1 error + }) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _f0) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _f0.X0, _f0.X1 = http.Get("http://example.com") + _f.IP = 2 + fallthrough + case _f.IP < 3: + if _f0.X1 != nil { + panic(_f0.X1) + } + _f.IP = 3 + fallthrough + case _f.IP < 4: + fmt.Println(_f0.X0.StatusCode) + } +} +func main() { + http.DefaultTransport = &yieldingRoundTripper{} + c := coroutine.New[*http.Request, *http.Response](work) + for c.Next() { + req := c.Recv() + fmt.Println("Requesting", req.URL.String()) + c.Send(&http.Response{StatusCode: 200}) + } +} +func init() { + _types.RegisterFunc[func(req *http.Request) (_ *http.Response, _ error)]("github.com/stealthrocket/coroutine/compiler/testdata/http.RoundTrip") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata/http.main") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata/http.work") +}