From 913895a0766d51c073403b59b5bda34ce453aefb Mon Sep 17 00:00:00 2001 From: Chris O'Hara Date: Wed, 20 Sep 2023 10:06:05 +1000 Subject: [PATCH] Start on desugar tests --- compiler/desugar.go | 22 ++--- compiler/desugar_test.go | 169 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 compiler/desugar_test.go diff --git a/compiler/desugar.go b/compiler/desugar.go index 33bff45..9da86ed 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -129,13 +129,12 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I if userLabel != nil { d.addUserLabel(userLabel, forLabel) } - body := &ast.BlockStmt{ - List: append([]ast.Stmt{ - &ast.IfStmt{ - Cond: &ast.UnaryExpr{Op: token.NOT, X: s.Cond}, - Body: &ast.BlockStmt{List: []ast.Stmt{&ast.BranchStmt{Tok: token.BREAK}}}, - }, - }, s.Body.List...), + body := &ast.BlockStmt{List: s.Body.List} + if s.Cond != nil { + body.List = append([]ast.Stmt{&ast.IfStmt{ + Cond: &ast.UnaryExpr{Op: token.NOT, X: s.Cond}, + Body: &ast.BlockStmt{List: []ast.Stmt{&ast.BranchStmt{Tok: token.BREAK}}}, + }}, body.List...) } stmt = &ast.LabeledStmt{ Label: forLabel, @@ -182,6 +181,7 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I } case *ast.IncDecStmt: + // TODO: desugar expressions case *ast.LabeledStmt: // Remove the user's label, but notify the next step so that generated @@ -365,7 +365,9 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I // TODO: desugar expressions case *ast.SwitchStmt: - // Rewrite `switch init; tag {}` to `init; switch tag {}` + // Rewrite switch statements: + // - `switch init; tag { case A: ... case B: ... }` => `{ init; if _tag := tag; _tag == A { ... } else if _tag == B { ... }` + // - `switch { case A: ... case B: ... default: ... } => `if A { ... } else if B { ... } else { ... }` init := d.desugar(s.Init, nil, nil, nil) switchLabel := d.newLabel() if userLabel != nil { @@ -384,7 +386,9 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I } case *ast.TypeSwitchStmt: - // Rewrite `switch init; assign {}` to `init; switch assign {}` + // Rewrite type switch statements: + // - `switch init; x.(type) { ... }` to `{ init; _x := x; switch _x.(type) { ... } }` + // - `switch init; x := y.(type) { ... }` to `{ init; _t := y; switch x := _y.(type) { ... } }` init := d.desugar(s.Init, nil, nil, nil) switchLabel := d.newLabel() if userLabel != nil { diff --git a/compiler/desugar_test.go b/compiler/desugar_test.go new file mode 100644 index 0000000..02744b5 --- /dev/null +++ b/compiler/desugar_test.go @@ -0,0 +1,169 @@ +package compiler + +import ( + "bytes" + "go/ast" + "go/format" + "go/parser" + "go/token" + "go/types" + "strings" + "testing" +) + +func TestDesugar(t *testing.T) { + for _, test := range []struct { + name string + body string + expect string + uses map[string]types.Object + defs map[string]types.Object + types map[string]types.TypeAndValue + }{ + { + name: "if cond", + body: "if bar == 1 {}", + expect: ` +{ + _v0 := bar == 1 + if _v0 { + } +}`, + }, + { + name: "if init + cond", + body: "if foo := bar; bar == 1 {}", + expect: ` +{ + foo := bar + _v0 := bar == 1 + if _v0 { + } +}`, + }, + { + name: "for init + cond + post", + body: "for i := 0; i < 10; i++ { result += i }", + expect: ` +{ + i := 0 +_l0: + for ; ; i++ { + { + _v0 := !(i < 10) + if _v0 { + break _l0 + } + } + result += i + } +}`, + }, + { + name: "labeled for", + body: "outer: for i := 0; i < 10; i++ { for j := 0; j < 10; j++ { break outer } }", + expect: ` +{ + i := 0 +_l0: + for ; ; i++ { + { + _v0 := !(i < 10) + if _v0 { + break _l0 + } + } + { + j := 0 + _l1: + for ; ; j++ { + { + _v1 := !(j < 10) + if _v1 { + break _l1 + } + } + break _l0 + } + } + } +}`, + }, + { + name: "labeled for break and continue handling", + body: ` +outer: + for { + switch { + case true: + break + case false: + continue + default: + break outer + } + }`, + expect: ` +_l0: + for { + _l1: + switch { + case true: + break _l1 + case false: + continue _l0 + default: + break _l0 + } + }`, + }, + } { + t.Run(test.name, func(t *testing.T) { + expr, err := parser.ParseExpr("func() {\n" + test.body + "\n}()") + if err != nil { + t.Fatal(err) + } + body := expr.(*ast.CallExpr).Fun.(*ast.FuncLit).Body + + info := &types.Info{ + Defs: map[*ast.Ident]types.Object{}, + Uses: map[*ast.Ident]types.Object{}, + Types: map[ast.Expr]types.TypeAndValue{}, + } + ast.Inspect(body, func(node ast.Node) bool { + if ident, ok := node.(*ast.Ident); ok { + if obj, ok := test.defs[ident.Name]; ok { + info.Defs[ident] = obj + } + if obj, ok := test.uses[ident.Name]; ok { + info.Uses[ident] = obj + } + if t, ok := test.types[ident.Name]; ok { + info.Types[ident] = t + } + } + return true + }) + desugared := desugar(body, info) + desugared = unnestBlocks(desugared) + + expect := strings.TrimSpace(test.expect) + actual := formatNode(desugared) + if actual != expect { + t.Errorf("unexpected desugared result") + t.Logf("expect:\n%s", test.expect) + t.Logf("actual:\n%s", actual) + } + }) + } +} + +func formatNode(node ast.Node) string { + fset := token.NewFileSet() + //ast.Print(fset, node) + var b bytes.Buffer + if err := format.Node(&b, fset, node); err != nil { + panic(err) + } + return b.String() +}