Skip to content

Commit

Permalink
Select support (#36)
Browse files Browse the repository at this point in the history
This PR adds support for `select` statements.
  • Loading branch information
chriso authored Sep 19, 2023
2 parents c0c0475 + 2727cb8 commit 5ae0fba
Show file tree
Hide file tree
Showing 6 changed files with 395 additions and 7 deletions.
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { RangeOverMaps(5) },
yields: []int{0, 5, 5, 50, 5, 4, 3, 2, 1, 0},
},

{
name: "select",
coro: func() { Select(8) },
yields: []int{-1, 0, 0, 1, 10, 2, 20, 3, 30, 4, 40, 50, 0, 1, 2},
},
}

// TODO: remove me
Expand Down
46 changes: 44 additions & 2 deletions compiler/desugar.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
// labels so that the desugaring pass (and other compilation passes) are able
// to both decompose and introduce control flow.
//
// Nondeterministic control flow and iteration (e.g. select, for..range
// over maps) is split into two parts so that yield points within can resume
// from the same place.
//
// The desugaring pass works at the statement level (ast.Stmt) and does not
// consider expressions (ast.Expr). This means that the pass does not
// recurse into expressions that may contain statements. At this time, only
Expand Down Expand Up @@ -273,6 +277,12 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
Body: d.desugarList(s.Body, breakTo, continueTo),
}

case *ast.CommClause:
stmt = &ast.CommClause{
Comm: d.desugar(s.Comm, nil, nil, nil),
Body: d.desugarList(s.Body, breakTo, continueTo),
}

case *ast.BranchStmt:
if s.Label != nil {
label := d.getUserLabel(s.Label)
Expand All @@ -299,8 +309,40 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I
// labels can be mapped.
stmt = d.desugar(s.Stmt, breakTo, continueTo, s.Label)

case *ast.SelectStmt, *ast.CommClause:
panic("not implemented")
case *ast.SelectStmt:
// Rewrite select statements into a select+switch statement. The
// select cases exist only to record the selection; the select
// case bodies are moved into the switch statement over that
// selection. This allows coroutines to jump back to the right
// case when resuming.
selection := d.newVar(types.Typ[types.Int])
switchBody := &ast.BlockStmt{List: make([]ast.Stmt, len(s.Body.List))}
for i, c := range s.Body.List {
cc := c.(*ast.CommClause)
id := &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(i + 1)}
switchBody.List[i] = &ast.CaseClause{
List: []ast.Expr{id},
Body: cc.Body,
}
cc.Body = []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{selection},
Tok: token.ASSIGN,
Rhs: []ast.Expr{id},
},
}
}
stmt = &ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{selection},
Tok: token.DEFINE,
Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}},
},
s,
d.desugar(&ast.SwitchStmt{Tag: selection, Body: switchBody}, breakTo, continueTo, userLabel),
},
}

case *ast.AssignStmt, *ast.DeclStmt, *ast.DeferStmt, *ast.EmptyStmt,
*ast.ExprStmt, *ast.GoStmt, *ast.IncDecStmt, *ast.ReturnStmt, *ast.SendStmt:
Expand Down
18 changes: 18 additions & 0 deletions compiler/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ func trackDispatchSpans0(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan,
for _, child := range s.Body {
nextID = trackDispatchSpans0(child, dispatchSpans, nextID)
}
case *ast.SelectStmt:
nextID = trackDispatchSpans0(s.Body, dispatchSpans, nextID)
case *ast.CommClause:
for _, child := range s.Body {
nextID = trackDispatchSpans0(child, dispatchSpans, nextID)
}
case *ast.LabeledStmt:
nextID = trackDispatchSpans0(s.Stmt, dispatchSpans, nextID)
default:
Expand Down Expand Up @@ -105,6 +111,10 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan) ast
for i, child := range s.Body.List {
s.Body.List[i] = compileDispatch(child, dispatchSpans)
}
case *ast.SelectStmt:
for i, child := range s.Body.List {
s.Body.List[i] = compileDispatch(child, dispatchSpans)
}
case *ast.CaseClause:
switch {
case len(s.Body) == 1:
Expand All @@ -113,6 +123,14 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan) ast
case len(s.Body) > 1:
s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans)}
}
case *ast.CommClause:
switch {
case len(s.Body) == 1:
child := compileDispatch(s.Body[0], dispatchSpans)
s.Body[0] = unnestBlocks(child)
case len(s.Body) > 1:
s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans)}
}
case *ast.LabeledStmt:
s.Stmt = compileDispatch(s.Stmt, dispatchSpans)
}
Expand Down
36 changes: 36 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package testdata

import (
"time"
"unsafe"

"github.com/stealthrocket/coroutine"
Expand Down Expand Up @@ -254,3 +255,38 @@ func RangeOverMaps(n int) {
coroutine.Yield[int, any](len(m2))
}
}

func Select(n int) {
select {
default:
coroutine.Yield[int, any](-1)
}

for i := 0; i < n; i++ {
select {
case <-time.After(0):
if i >= 5 {
break
}
coroutine.Yield[int, any](i)
case <-time.After(1 * time.Second):
panic("unreachable")
}

foo:
select {
case <-time.After(0):
if i >= 6 {
break foo
}
coroutine.Yield[int, any](i * 10)
}
}

select {
case <-time.After(0):
for j := 0; j < 3; j++ {
coroutine.Yield[int, any](j)
}
}
}
Loading

0 comments on commit 5ae0fba

Please sign in to comment.