Skip to content

Commit

Permalink
Fix order of evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Sep 21, 2023
1 parent 89e57db commit c5643b9
Show file tree
Hide file tree
Showing 3 changed files with 242 additions and 167 deletions.
220 changes: 134 additions & 86 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,44 +571,65 @@ func (d *desugarer) desugarList(stmts []ast.Stmt, breakTo, continueTo *ast.Ident
return desugared
}

func (d *desugarer) flatMap(stmt ast.Stmt) []ast.Stmt {
queue := []ast.Stmt{stmt}
for i := 0; i < len(queue); i++ {
switch s := queue[i].(type) {
case *ast.AssignStmt:
if s.Tok != token.DEFINE {
for j, expr := range s.Lhs {
s.Lhs[j] = d.decomposeExpression(expr, &queue, true)
}
func (d *desugarer) flatMap(stmt ast.Stmt) (result []ast.Stmt) {
var prereqs []ast.Stmt
switch s := stmt.(type) {
case *ast.AssignStmt:
var flags exprFlags
if s.Tok == token.DEFINE {
// LHS is just ast.Ident in this case; no need to decompose.
if len(s.Rhs) > 1 {
flags |= multiExprStmt
}
for j, expr := range s.Rhs {
s.Rhs[j] = d.decomposeExpression(expr, &queue, false)
} else {
flags |= multiExprStmt
for j, expr := range s.Lhs {
s.Lhs[j], prereqs = d.decomposeExpression(expr, flags)
result = append(result, prereqs...)
}
case *ast.DeclStmt:
g := s.Decl.(*ast.GenDecl)
if g.Tok == token.VAR {
for _, spec := range g.Specs {
v := spec.(*ast.ValueSpec)
for j, expr := range v.Values {
v.Values[j] = d.decomposeExpression(expr, &queue, false)
}
}
for j, expr := range s.Rhs {
s.Rhs[j], prereqs = d.decomposeExpression(expr, flags)
result = append(result, prereqs...)
}
case *ast.DeclStmt:
g := s.Decl.(*ast.GenDecl)
if g.Tok == token.VAR {
for _, spec := range g.Specs {
v := spec.(*ast.ValueSpec)
var flags exprFlags
if len(v.Values) > 1 {
flags |= multiExprStmt
}
for j, expr := range v.Values {
v.Values[j], prereqs = d.decomposeExpression(expr, flags)
result = append(result, prereqs...)
}
}
case *ast.ExprStmt:
s.X = d.decomposeExpression(s.X, &queue, false)
case *ast.SendStmt:
s.Chan = d.decomposeExpression(s.Chan, &queue, true)
s.Value = d.decomposeExpression(s.Value, &queue, false)
case *ast.ReturnStmt:
for j, expr := range s.Results {
s.Results[j] = d.decomposeExpression(expr, &queue, false)
}
case *ast.IncDecStmt:
s.X = d.decomposeExpression(s.X, &queue, false)
}
case *ast.ExprStmt:
s.X, prereqs = d.decomposeExpression(s.X, exprFlags(0))
result = append(result, prereqs...)
case *ast.SendStmt:
s.Chan, prereqs = d.decomposeExpression(s.Chan, multiExprStmt)
result = append(result, prereqs...)
s.Value, prereqs = d.decomposeExpression(s.Value, multiExprStmt)
result = append(result, prereqs...)
case *ast.ReturnStmt:
var flags exprFlags
if len(s.Results) > 1 {
flags |= multiExprStmt
}
for j, expr := range s.Results {
s.Results[j], prereqs = d.decomposeExpression(expr, flags)
result = append(result, prereqs...)
}
case *ast.IncDecStmt:
s.X, prereqs = d.decomposeExpression(s.X, exprFlags(0))
result = append(result, prereqs...)
}
reverse(queue)
return queue
result = append(result, stmt)
return
}

func (d *desugarer) mayYield(n ast.Node) (mayYield bool) {
Expand Down Expand Up @@ -641,82 +662,109 @@ func (d *desugarer) mayYield(n ast.Node) (mayYield bool) {
return
}

func (d *desugarer) decomposeExpression(expr ast.Expr, queue *[]ast.Stmt, lhs bool) ast.Expr {
type exprFlags int

const (
// multiExprStmt is set if the expression is part of a statement
// that has more than one nested expression of type ast.Expr.
multiExprStmt exprFlags = 1 << iota
)

func (d *desugarer) decomposeExpression(expr ast.Expr, flags exprFlags) (ast.Expr, []ast.Stmt) {
if !d.mayYield(expr) {
return expr
return expr, nil
}
queue := []ast.Expr{expr}
var tmps []*ast.Ident

decompose := func(e ast.Expr) ast.Expr {
if !d.mayYield(e) {
return e
}
tmp := d.newVar(d.info.TypeOf(e))
*queue = append(*queue, &ast.AssignStmt{
Lhs: []ast.Expr{tmp},
Tok: token.DEFINE,
Rhs: []ast.Expr{e},
})
tmps = append(tmps, tmp)
queue = append(queue, e)
return tmp
}
switch e := expr.(type) {
case *ast.BadExpr:
panic("bad expr")

case *ast.BinaryExpr:
e.X = decompose(e.X)
e.Y = decompose(e.Y)

case *ast.CallExpr:
e.Fun = decompose(e.Fun)
for i, arg := range e.Args {
e.Args[i] = decompose(arg)
}
case *ast.CompositeLit:
for i, elt := range e.Elts {
e.Elts[i] = decompose(elt)
}
// skip e.Type (type expression)
for i := 0; i < len(queue); i++ {
switch e := queue[i].(type) {
case *ast.BadExpr:
panic("bad expr")

case *ast.BinaryExpr:
e.X = decompose(e.X)
e.Y = decompose(e.Y)

case *ast.CallExpr:
if i == 0 && (flags&multiExprStmt) != 0 {
// Need to hoist the CallExpr out into a temporary variable in
// this case, so that the relative order of calls (and their
// prerequisites) is preserved.
queue[i] = decompose(e)
} else {
e.Fun = decompose(e.Fun)
for i, arg := range e.Args {
e.Args[i] = decompose(arg)
}
}
case *ast.CompositeLit:
for i, elt := range e.Elts {
e.Elts[i] = decompose(elt)
}
// skip e.Type (type expression)

case *ast.Ellipsis:
e.Elt = decompose(e.Elt)
case *ast.Ellipsis:
e.Elt = decompose(e.Elt)

case *ast.IndexExpr:
e.X = decompose(e.X)
e.Index = decompose(e.Index)
case *ast.IndexExpr:
e.X = decompose(e.X)
e.Index = decompose(e.Index)

case *ast.IndexListExpr:
e.X = decompose(e.X)
// skip e.Indices (type expressions)
case *ast.IndexListExpr:
e.X = decompose(e.X)
// skip e.Indices (type expressions)

case *ast.KeyValueExpr:
e.Key = decompose(e.Key)
e.Value = decompose(e.Value)
case *ast.KeyValueExpr:
e.Key = decompose(e.Key)
e.Value = decompose(e.Value)

case *ast.ParenExpr:
e.X = decompose(e.X)
case *ast.ParenExpr:
e.X = decompose(e.X)

case *ast.SelectorExpr:
e.X = decompose(e.X)
case *ast.SelectorExpr:
e.X = decompose(e.X)

case *ast.SliceExpr:
e.X = decompose(e.X)
e.Low = decompose(e.Low)
e.Max = decompose(e.Max)
e.High = decompose(e.High)
case *ast.SliceExpr:
e.X = decompose(e.X)
e.Low = decompose(e.Low)
e.Max = decompose(e.Max)
e.High = decompose(e.High)

case *ast.StarExpr:
e.X = decompose(e.X)
case *ast.StarExpr:
e.X = decompose(e.X)

case *ast.TypeAssertExpr:
e.X = decompose(e.X)
// skip e.Type (type expression)
case *ast.TypeAssertExpr:
e.X = decompose(e.X)
// skip e.Type (type expression)

case *ast.UnaryExpr:
e.X = decompose(e.X)
case *ast.UnaryExpr:
e.X = decompose(e.X)

default:
panic(fmt.Sprintf("unsupported ast.Expr: %T", expr))
default:
panic(fmt.Sprintf("unsupported ast.Expr: %T", queue[i]))
}
}
prereqs := make([]ast.Stmt, len(tmps))
for i := range tmps {
prereqs[i] = &ast.AssignStmt{
Lhs: []ast.Expr{tmps[i]},
Tok: token.DEFINE,
Rhs: []ast.Expr{queue[i+1]},
}
}
return expr
reverse(prereqs)
return queue[0], prereqs
}

func reverse(stmts []ast.Stmt) {
Expand Down
29 changes: 25 additions & 4 deletions compiler/desugar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,12 +1082,33 @@ _l0:
},
{
name: "decompose expressions in decls",
body: "var a, b int = c(d()), e(f())",
body: "var _, _ int = a(b(0)), c(d(1))",
// See https://go.dev/play/p/PkwoJbDLgQV for order of evaluation.
expect: `
{
_v1 := f()
_v0 := d()
var a, b int = c(_v0), e(_v1)
_v1 := b(0)
_v0 := a(_v1)
_v3 := d(1)
_v2 := c(_v3)
var _, _ int = _v0, _v2
}
`,
},
{
name: "decompose expressions in assignments",
body: "ints[a(b(0))], ints[c(d(1))] = e(f(10)), g(h(11))",
// See https://go.dev/play/p/WvrxhauFbsA for order of evaluation
expect: `
{
_v1 := b(0)
_v0 := a(_v1)
_v3 := d(1)
_v2 := c(_v3)
_v5 := f(10)
_v4 := e(_v5)
_v7 := h(11)
_v6 := g(_v7)
ints[_v0], ints[_v2] = _v4, _v6
}
`,
},
Expand Down
Loading

0 comments on commit c5643b9

Please sign in to comment.