diff --git a/compiler/desugar.go b/compiler/desugar.go index e956d85..0ec4ae9 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -571,44 +571,65 @@ func (d *desugarer) desugarList(stmts []ast.Stmt, breakTo, continueTo *ast.Ident return desugared } -func (d *desugarer) flatMap(stmt ast.Stmt) []ast.Stmt { - queue := []ast.Stmt{stmt} - for i := 0; i < len(queue); i++ { - switch s := queue[i].(type) { - case *ast.AssignStmt: - if s.Tok != token.DEFINE { - for j, expr := range s.Lhs { - s.Lhs[j] = d.decomposeExpression(expr, &queue, true) - } +func (d *desugarer) flatMap(stmt ast.Stmt) (result []ast.Stmt) { + var prereqs []ast.Stmt + switch s := stmt.(type) { + case *ast.AssignStmt: + var flags exprFlags + if s.Tok == token.DEFINE { + // LHS is just ast.Ident in this case; no need to decompose. + if len(s.Rhs) > 1 { + flags |= multiExprStmt } - for j, expr := range s.Rhs { - s.Rhs[j] = d.decomposeExpression(expr, &queue, false) + } else { + flags |= multiExprStmt + for j, expr := range s.Lhs { + s.Lhs[j], prereqs = d.decomposeExpression(expr, flags) + result = append(result, prereqs...) } - case *ast.DeclStmt: - g := s.Decl.(*ast.GenDecl) - if g.Tok == token.VAR { - for _, spec := range g.Specs { - v := spec.(*ast.ValueSpec) - for j, expr := range v.Values { - v.Values[j] = d.decomposeExpression(expr, &queue, false) - } + } + for j, expr := range s.Rhs { + s.Rhs[j], prereqs = d.decomposeExpression(expr, flags) + result = append(result, prereqs...) + } + case *ast.DeclStmt: + g := s.Decl.(*ast.GenDecl) + if g.Tok == token.VAR { + for _, spec := range g.Specs { + v := spec.(*ast.ValueSpec) + var flags exprFlags + if len(v.Values) > 1 { + flags |= multiExprStmt + } + for j, expr := range v.Values { + v.Values[j], prereqs = d.decomposeExpression(expr, flags) + result = append(result, prereqs...) } } - case *ast.ExprStmt: - s.X = d.decomposeExpression(s.X, &queue, false) - case *ast.SendStmt: - s.Chan = d.decomposeExpression(s.Chan, &queue, true) - s.Value = d.decomposeExpression(s.Value, &queue, false) - case *ast.ReturnStmt: - for j, expr := range s.Results { - s.Results[j] = d.decomposeExpression(expr, &queue, false) - } - case *ast.IncDecStmt: - s.X = d.decomposeExpression(s.X, &queue, false) } + case *ast.ExprStmt: + s.X, prereqs = d.decomposeExpression(s.X, exprFlags(0)) + result = append(result, prereqs...) + case *ast.SendStmt: + s.Chan, prereqs = d.decomposeExpression(s.Chan, multiExprStmt) + result = append(result, prereqs...) + s.Value, prereqs = d.decomposeExpression(s.Value, multiExprStmt) + result = append(result, prereqs...) + case *ast.ReturnStmt: + var flags exprFlags + if len(s.Results) > 1 { + flags |= multiExprStmt + } + for j, expr := range s.Results { + s.Results[j], prereqs = d.decomposeExpression(expr, flags) + result = append(result, prereqs...) + } + case *ast.IncDecStmt: + s.X, prereqs = d.decomposeExpression(s.X, exprFlags(0)) + result = append(result, prereqs...) } - reverse(queue) - return queue + result = append(result, stmt) + return } func (d *desugarer) mayYield(n ast.Node) (mayYield bool) { @@ -641,82 +662,109 @@ func (d *desugarer) mayYield(n ast.Node) (mayYield bool) { return } -func (d *desugarer) decomposeExpression(expr ast.Expr, queue *[]ast.Stmt, lhs bool) ast.Expr { +type exprFlags int + +const ( + // multiExprStmt is set if the expression is part of a statement + // that has more than one nested expression of type ast.Expr. + multiExprStmt exprFlags = 1 << iota +) + +func (d *desugarer) decomposeExpression(expr ast.Expr, flags exprFlags) (ast.Expr, []ast.Stmt) { if !d.mayYield(expr) { - return expr + return expr, nil } + queue := []ast.Expr{expr} + var tmps []*ast.Ident + decompose := func(e ast.Expr) ast.Expr { if !d.mayYield(e) { return e } tmp := d.newVar(d.info.TypeOf(e)) - *queue = append(*queue, &ast.AssignStmt{ - Lhs: []ast.Expr{tmp}, - Tok: token.DEFINE, - Rhs: []ast.Expr{e}, - }) + tmps = append(tmps, tmp) + queue = append(queue, e) return tmp } - switch e := expr.(type) { - case *ast.BadExpr: - panic("bad expr") - - case *ast.BinaryExpr: - e.X = decompose(e.X) - e.Y = decompose(e.Y) - case *ast.CallExpr: - e.Fun = decompose(e.Fun) - for i, arg := range e.Args { - e.Args[i] = decompose(arg) - } - case *ast.CompositeLit: - for i, elt := range e.Elts { - e.Elts[i] = decompose(elt) - } - // skip e.Type (type expression) + for i := 0; i < len(queue); i++ { + switch e := queue[i].(type) { + case *ast.BadExpr: + panic("bad expr") + + case *ast.BinaryExpr: + e.X = decompose(e.X) + e.Y = decompose(e.Y) + + case *ast.CallExpr: + if i == 0 && (flags&multiExprStmt) != 0 { + // Need to hoist the CallExpr out into a temporary variable in + // this case, so that the relative order of calls (and their + // prerequisites) is preserved. + queue[i] = decompose(e) + } else { + e.Fun = decompose(e.Fun) + for i, arg := range e.Args { + e.Args[i] = decompose(arg) + } + } + case *ast.CompositeLit: + for i, elt := range e.Elts { + e.Elts[i] = decompose(elt) + } + // skip e.Type (type expression) - case *ast.Ellipsis: - e.Elt = decompose(e.Elt) + case *ast.Ellipsis: + e.Elt = decompose(e.Elt) - case *ast.IndexExpr: - e.X = decompose(e.X) - e.Index = decompose(e.Index) + case *ast.IndexExpr: + e.X = decompose(e.X) + e.Index = decompose(e.Index) - case *ast.IndexListExpr: - e.X = decompose(e.X) - // skip e.Indices (type expressions) + case *ast.IndexListExpr: + e.X = decompose(e.X) + // skip e.Indices (type expressions) - case *ast.KeyValueExpr: - e.Key = decompose(e.Key) - e.Value = decompose(e.Value) + case *ast.KeyValueExpr: + e.Key = decompose(e.Key) + e.Value = decompose(e.Value) - case *ast.ParenExpr: - e.X = decompose(e.X) + case *ast.ParenExpr: + e.X = decompose(e.X) - case *ast.SelectorExpr: - e.X = decompose(e.X) + case *ast.SelectorExpr: + e.X = decompose(e.X) - case *ast.SliceExpr: - e.X = decompose(e.X) - e.Low = decompose(e.Low) - e.Max = decompose(e.Max) - e.High = decompose(e.High) + case *ast.SliceExpr: + e.X = decompose(e.X) + e.Low = decompose(e.Low) + e.Max = decompose(e.Max) + e.High = decompose(e.High) - case *ast.StarExpr: - e.X = decompose(e.X) + case *ast.StarExpr: + e.X = decompose(e.X) - case *ast.TypeAssertExpr: - e.X = decompose(e.X) - // skip e.Type (type expression) + case *ast.TypeAssertExpr: + e.X = decompose(e.X) + // skip e.Type (type expression) - case *ast.UnaryExpr: - e.X = decompose(e.X) + case *ast.UnaryExpr: + e.X = decompose(e.X) - default: - panic(fmt.Sprintf("unsupported ast.Expr: %T", expr)) + default: + panic(fmt.Sprintf("unsupported ast.Expr: %T", queue[i])) + } + } + prereqs := make([]ast.Stmt, len(tmps)) + for i := range tmps { + prereqs[i] = &ast.AssignStmt{ + Lhs: []ast.Expr{tmps[i]}, + Tok: token.DEFINE, + Rhs: []ast.Expr{queue[i+1]}, + } } - return expr + reverse(prereqs) + return queue[0], prereqs } func reverse(stmts []ast.Stmt) { diff --git a/compiler/desugar_test.go b/compiler/desugar_test.go index d7114e1..5d76fd1 100644 --- a/compiler/desugar_test.go +++ b/compiler/desugar_test.go @@ -1082,12 +1082,33 @@ _l0: }, { name: "decompose expressions in decls", - body: "var a, b int = c(d()), e(f())", + body: "var _, _ int = a(b(0)), c(d(1))", + // See https://go.dev/play/p/PkwoJbDLgQV for order of evaluation. expect: ` { - _v1 := f() - _v0 := d() - var a, b int = c(_v0), e(_v1) + _v1 := b(0) + _v0 := a(_v1) + _v3 := d(1) + _v2 := c(_v3) + var _, _ int = _v0, _v2 +} +`, + }, + { + name: "decompose expressions in assignments", + body: "ints[a(b(0))], ints[c(d(1))] = e(f(10)), g(h(11))", + // See https://go.dev/play/p/WvrxhauFbsA for order of evaluation + expect: ` +{ + _v1 := b(0) + _v0 := a(_v1) + _v3 := d(1) + _v2 := c(_v3) + _v5 := f(10) + _v4 := e(_v5) + _v7 := h(11) + _v6 := g(_v7) + ints[_v0], ints[_v2] = _v4, _v6 } `, }, diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index c91e794..b938e5c 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -2522,81 +2522,81 @@ func Range10ClosureCapturingPointers() { func Range10ClosureHeterogenousCapture() { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() - var _o12 int8 - var _o13 int16 - var _o14 int32 - var _o15 int64 - var _o16 uint8 - var _o17 uint16 - var _o18 uint32 - var _o19 uint64 - var _o20 uintptr - var _o21 func() int - var _o22 int - var _o23 func() bool - var _o24 bool + var _o13 int8 + var _o14 int16 + var _o15 int32 + var _o16 int64 + var _o17 uint8 + var _o18 uint16 + var _o19 uint32 + var _o20 uint64 + var _o21 uintptr + var _o22 func() int + var _o23 int + var _o24 func() bool var _o25 bool + var _o26 bool if _f.IP > 0 { if _v := _f.Get(0); _v != nil { - _o12 = _v.(int8) + _o13 = _v.(int8) } if _v := _f.Get(1); _v != nil { - _o13 = _v.(int16) + _o14 = _v.(int16) } if _v := _f.Get(2); _v != nil { - _o14 = _v.(int32) + _o15 = _v.(int32) } if _v := _f.Get(3); _v != nil { - _o15 = _v.(int64) + _o16 = _v.(int64) } if _v := _f.Get(4); _v != nil { - _o16 = _v.(uint8) + _o17 = _v.(uint8) } if _v := _f.Get(5); _v != nil { - _o17 = _v.(uint16) + _o18 = _v.(uint16) } if _v := _f.Get(6); _v != nil { - _o18 = _v.(uint32) + _o19 = _v.(uint32) } if _v := _f.Get(7); _v != nil { - _o19 = _v.(uint64) + _o20 = _v.(uint64) } if _v := _f.Get(8); _v != nil { - _o20 = _v.(uintptr) + _o21 = _v.(uintptr) } if _v := _f.Get(9); _v != nil { - _o21 = _v.(func() int) + _o22 = _v.(func() int) } if _v := _f.Get(10); _v != nil { - _o22 = _v.(int) + _o23 = _v.(int) } if _v := _f.Get(11); _v != nil { - _o23 = _v.(func() bool) + _o24 = _v.(func() bool) } if _v := _f.Get(12); _v != nil { - _o24 = _v.(bool) + _o25 = _v.(bool) } if _v := _f.Get(13); _v != nil { - _o25 = _v.(bool) + _o26 = _v.(bool) } } defer func() { if _c.Unwinding() { - _f.Set(0, _o12) - _f.Set(1, _o13) - _f.Set(2, _o14) - _f.Set(3, _o15) - _f.Set(4, _o16) - _f.Set(5, _o17) - _f.Set(6, _o18) - _f.Set(7, _o19) - _f.Set(8, _o20) - _f.Set(9, _o21) - _f.Set(10, _o22) - _f.Set(11, _o23) - _f.Set(12, _o24) - _f.Set(13, _o25) + _f.Set(0, _o13) + _f.Set(1, _o14) + _f.Set(2, _o15) + _f.Set(3, _o16) + _f.Set(4, _o17) + _f.Set(5, _o18) + _f.Set(6, _o19) + _f.Set(7, _o20) + _f.Set(8, _o21) + _f.Set(9, _o22) + _f.Set(10, _o23) + _f.Set(11, _o24) + _f.Set(12, _o25) + _f.Set(13, _o26) _c.Store(_fp, _f) } else { _c.Pop() @@ -2606,53 +2606,53 @@ func Range10ClosureHeterogenousCapture() { case _f.IP < 11: switch { case _f.IP < 2: - _o12 = 0 + _o13 = 0 _f.IP = 2 fallthrough case _f.IP < 3: - _o13 = 1 + _o14 = 1 _f.IP = 3 fallthrough case _f.IP < 4: - _o14 = 2 + _o15 = 2 _f.IP = 4 fallthrough case _f.IP < 5: - _o15 = 3 + _o16 = 3 _f.IP = 5 fallthrough case _f.IP < 6: - _o16 = 4 + _o17 = 4 _f.IP = 6 fallthrough case _f.IP < 7: - _o17 = 5 + _o18 = 5 _f.IP = 7 fallthrough case _f.IP < 8: - _o18 = 6 + _o19 = 6 _f.IP = 8 fallthrough case _f.IP < 9: - _o19 = 7 + _o20 = 7 _f.IP = 9 fallthrough case _f.IP < 10: - _o20 = 8 + _o21 = 8 _f.IP = 10 fallthrough case _f.IP < 11: - _o21 = func() int { return int(_o20) + 1 } + _o22 = func() int { return int(_o21) + 1 } } _f.IP = 11 fallthrough case _f.IP < 12: - _o22 = 0 + _o23 = 0 _f.IP = 12 fallthrough case _f.IP < 13: - _o23 = func() (_ bool) { + _o24 = func() (_ bool) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() var _o0 int @@ -2667,6 +2667,7 @@ func Range10ClosureHeterogenousCapture() { var _o9 bool var _o10 bool var _o11 bool + var _o12 int if _f.IP > 0 { if _v := _f.Get(0); _v != nil { _o0 = _v.(int) @@ -2704,6 +2705,9 @@ func Range10ClosureHeterogenousCapture() { if _v := _f.Get(11); _v != nil { _o11 = _v.(bool) } + if _v := _f.Get(12); _v != nil { + _o12 = _v.(int) + } } defer func() { if _c.Unwinding() { @@ -2719,6 +2723,7 @@ func Range10ClosureHeterogenousCapture() { _f.Set(9, _o9) _f.Set(10, _o10) _f.Set(11, _o11) + _f.Set(12, _o12) _c.Store(_fp, _f) } else { _c.Pop() @@ -2728,13 +2733,13 @@ func Range10ClosureHeterogenousCapture() { case _f.IP < 2: _f.IP = 2 fallthrough - case _f.IP < 23: + case _f.IP < 24: switch { case _f.IP < 3: - _o1 = _o22 + _o1 = _o23 _f.IP = 3 fallthrough - case _f.IP < 23: + case _f.IP < 24: switch { default: switch { @@ -2743,54 +2748,55 @@ func Range10ClosureHeterogenousCapture() { 0 _f.IP = 4 fallthrough - case _f.IP < 23: + case _f.IP < 24: if _o2 { - _o0 = int(_o12) + _o0 = int(_o13) } else { _o3 = _o1 == 1 if _o3 { - _o0 = int(_o13) + _o0 = int(_o14) } else { _o4 = _o1 == 2 if _o4 { - _o0 = int(_o14) + _o0 = int(_o15) } else { _o5 = _o1 == 3 if _o5 { - _o0 = int(_o15) + _o0 = int(_o16) } else { _o6 = _o1 == 4 if _o6 { - _o0 = int(_o16) + _o0 = int(_o17) } else { _o7 = _o1 == 5 if _o7 { - _o0 = int(_o17) + _o0 = int(_o18) } else { _o8 = _o1 == 6 if _o8 { - _o0 = int(_o18) + _o0 = int(_o19) } else { _o9 = _o1 == 7 if _o9 { - _o0 = int(_o19) + _o0 = int(_o20) } else { _o10 = _o1 == 8 if _o10 { - _o0 = int(_o20) + _o0 = int(_o21) } else { _o11 = _o1 == 9 if _o11 { - _o0 = _o21() + _o12 = _o22() + _o0 = _o12 } } } @@ -2804,19 +2810,19 @@ func Range10ClosureHeterogenousCapture() { } } } - _f.IP = 23 - fallthrough - case _f.IP < 24: - - coroutine.Yield[int, any](_o0) _f.IP = 24 fallthrough case _f.IP < 25: - _o22++ + + coroutine.Yield[int, any](_o0) _f.IP = 25 fallthrough case _f.IP < 26: - return _o22 < 10 + _o23++ + _f.IP = 26 + fallthrough + case _f.IP < 27: + return _o23 < 10 } return } @@ -2827,15 +2833,15 @@ func Range10ClosureHeterogenousCapture() { for ; ; _f.IP = 13 { switch { case _f.IP < 14: - _o24 = _o23() + _o25 = _o24() _f.IP = 14 fallthrough case _f.IP < 15: - _o25 = !_o24 + _o26 = !_o25 _f.IP = 15 fallthrough case _f.IP < 16: - if _o25 { + if _o26 { break _l0 } }