diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index 136f5bc..ee5c3d6 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -99,6 +99,12 @@ func TestCoroutineYield(t *testing.T) { coro: func() { RangeOverMaps(5) }, yields: []int{0, 5, 5, 50, 5, 4, 3, 2, 1, 0}, }, + + { + name: "select", + coro: func() { Select(8) }, + yields: []int{-1, 0, 0, 1, 10, 2, 20, 3, 30, 4, 40, 50, 0, 1, 2}, + }, } // TODO: remove me diff --git a/compiler/desugar.go b/compiler/desugar.go index 4363a33..37b415b 100644 --- a/compiler/desugar.go +++ b/compiler/desugar.go @@ -22,6 +22,10 @@ import ( // labels so that the desugaring pass (and other compilation passes) are able // to both decompose and introduce control flow. // +// Nondeterministic control flow and iteration (e.g. select, for..range +// over maps) is split into two parts so that yield points within can resume +// from the same place. +// // The desugaring pass works at the statement level (ast.Stmt) and does not // consider expressions (ast.Expr). This means that the pass does not // recurse into expressions that may contain statements. At this time, only @@ -273,6 +277,12 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I Body: d.desugarList(s.Body, breakTo, continueTo), } + case *ast.CommClause: + stmt = &ast.CommClause{ + Comm: d.desugar(s.Comm, nil, nil, nil), + Body: d.desugarList(s.Body, breakTo, continueTo), + } + case *ast.BranchStmt: if s.Label != nil { label := d.getUserLabel(s.Label) @@ -299,8 +309,40 @@ func (d *desugarer) desugar(stmt ast.Stmt, breakTo, continueTo, userLabel *ast.I // labels can be mapped. stmt = d.desugar(s.Stmt, breakTo, continueTo, s.Label) - case *ast.SelectStmt, *ast.CommClause: - panic("not implemented") + case *ast.SelectStmt: + // Rewrite select statements into a select+switch statement. The + // select cases exist only to record the selection; the select + // case bodies are moved into the switch statement over that + // selection. This allows coroutines to jump back to the right + // case when resuming. + selection := d.newVar(types.Typ[types.Int]) + switchBody := &ast.BlockStmt{List: make([]ast.Stmt, len(s.Body.List))} + for i, c := range s.Body.List { + cc := c.(*ast.CommClause) + id := &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(i + 1)} + switchBody.List[i] = &ast.CaseClause{ + List: []ast.Expr{id}, + Body: cc.Body, + } + cc.Body = []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{selection}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{id}, + }, + } + } + stmt = &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{selection}, + Tok: token.DEFINE, + Rhs: []ast.Expr{&ast.BasicLit{Kind: token.INT, Value: "0"}}, + }, + s, + d.desugar(&ast.SwitchStmt{Tag: selection, Body: switchBody}, breakTo, continueTo, userLabel), + }, + } case *ast.AssignStmt, *ast.DeclStmt, *ast.DeferStmt, *ast.EmptyStmt, *ast.ExprStmt, *ast.GoStmt, *ast.IncDecStmt, *ast.ReturnStmt, *ast.SendStmt: diff --git a/compiler/dispatch.go b/compiler/dispatch.go index de8b4c7..c108ff2 100644 --- a/compiler/dispatch.go +++ b/compiler/dispatch.go @@ -43,6 +43,12 @@ func trackDispatchSpans0(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, for _, child := range s.Body { nextID = trackDispatchSpans0(child, dispatchSpans, nextID) } + case *ast.SelectStmt: + nextID = trackDispatchSpans0(s.Body, dispatchSpans, nextID) + case *ast.CommClause: + for _, child := range s.Body { + nextID = trackDispatchSpans0(child, dispatchSpans, nextID) + } case *ast.LabeledStmt: nextID = trackDispatchSpans0(s.Stmt, dispatchSpans, nextID) default: @@ -105,6 +111,10 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan) ast for i, child := range s.Body.List { s.Body.List[i] = compileDispatch(child, dispatchSpans) } + case *ast.SelectStmt: + for i, child := range s.Body.List { + s.Body.List[i] = compileDispatch(child, dispatchSpans) + } case *ast.CaseClause: switch { case len(s.Body) == 1: @@ -113,6 +123,14 @@ func compileDispatch(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan) ast case len(s.Body) > 1: s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans)} } + case *ast.CommClause: + switch { + case len(s.Body) == 1: + child := compileDispatch(s.Body[0], dispatchSpans) + s.Body[0] = unnestBlocks(child) + case len(s.Body) > 1: + s.Body = []ast.Stmt{compileDispatch0(s.Body, dispatchSpans)} + } case *ast.LabeledStmt: s.Stmt = compileDispatch(s.Stmt, dispatchSpans) } diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index ed46bc6..053a40f 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -3,6 +3,7 @@ package testdata import ( + "time" "unsafe" "github.com/stealthrocket/coroutine" @@ -254,3 +255,38 @@ func RangeOverMaps(n int) { coroutine.Yield[int, any](len(m2)) } } + +func Select(n int) { + select { + default: + coroutine.Yield[int, any](-1) + } + + for i := 0; i < n; i++ { + select { + case <-time.After(0): + if i >= 5 { + break + } + coroutine.Yield[int, any](i) + case <-time.After(1 * time.Second): + panic("unreachable") + } + + foo: + select { + case <-time.After(0): + if i >= 6 { + break foo + } + coroutine.Yield[int, any](i * 10) + } + } + + select { + case <-time.After(0): + for j := 0; j < 3; j++ { + coroutine.Yield[int, any](j) + } + } +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 031bb4d..bc0f703 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -10,6 +10,8 @@ import ( runtime "runtime" sync "sync" atomic "sync/atomic" + syscall "syscall" + time "time" unsafe "unsafe" ) @@ -1314,6 +1316,182 @@ func RangeOverMaps(n int) { } } } + +func Select(n int) { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 int + var _o1 int + var _o2 int + var _o3 int + var _o4 int + var _o5 int + if _f.IP > 0 { + n = _f.Get(0).(int) + _o0 = _f.Get(1).(int) + + _o1 = _f.Get(2).(int) + _o2 = _f.Get(3).(int) + _o3 = _f.Get(4).(int) + _o4 = _f.Get(5).(int) + + _o5 = _f.Get(6).(int) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, n) + _f.Set(1, _o0) + _f.Set(2, _o1) + _f.Set(3, _o2) + _f.Set(4, _o3) + _f.Set(5, _o4) + _f.Set(6, _o5) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 4: + switch { + case _f.IP < 2: + _o0 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + select { + default: + _o0 = 1 + + } + _f.IP = 3 + fallthrough + case _f.IP < 4: + switch _o0 { + case 1: + coroutine.Yield[int, any](-1) + } + } + _f.IP = 4 + fallthrough + case _f.IP < 15: + switch { + case _f.IP < 5: + + _o1 = 0 + _f.IP = 5 + fallthrough + case _f.IP < 15: + for ; _o1 < n; _o1, _f.IP = _o1+1, 5 { + switch { + case _f.IP < 11: + switch { + case _f.IP < 6: + _o2 = 0 + _f.IP = 6 + fallthrough + case _f.IP < 8: + select { + case <-time.After(0): + _o2 = 1 + + case <-time.After(1 * time.Second): + _o2 = 2 + + } + _f.IP = 8 + fallthrough + case _f.IP < 11: + _l2: + switch _o2 { + case 1: + switch { + case _f.IP < 9: + if _o1 >= 5 { + break _l2 + } + _f.IP = 9 + fallthrough + case _f.IP < 10: + + coroutine.Yield[int, any](_o1) + } + case 2: + + panic("unreachable") + } + } + _f.IP = 11 + fallthrough + case _f.IP < 15: + switch { + case _f.IP < 12: + _o3 = 0 + _f.IP = 12 + fallthrough + case _f.IP < 13: + + select { + case <-time.After(0): + _o3 = 1 + + } + _f.IP = 13 + fallthrough + case _f.IP < 15: + _l3: + switch _o3 { + case 1: + switch { + case _f.IP < 14: + if _o1 >= 6 { + break _l3 + } + _f.IP = 14 + fallthrough + case _f.IP < 15: + + coroutine.Yield[int, any](_o1 * 10) + } + } + } + } + } + } + _f.IP = 15 + fallthrough + case _f.IP < 19: + switch { + case _f.IP < 16: + _o4 = 0 + _f.IP = 16 + fallthrough + case _f.IP < 17: + + select { + case <-time.After(0): + _o4 = 1 + + } + _f.IP = 17 + fallthrough + case _f.IP < 19: + switch _o4 { + case 1: + switch { + case _f.IP < 18: + _o5 = 0 + _f.IP = 18 + fallthrough + case _f.IP < 19: + for ; _o5 < 3; _o5, _f.IP = _o5+1, 18 { + coroutine.Yield[int, any](_o5) + } + } + } + } + } +} func init() { serde.RegisterType[**byte]() serde.RegisterType[*[100000]uintptr]() @@ -1325,8 +1503,10 @@ func init() { serde.RegisterType[*[1]uintptr]() serde.RegisterType[*[268435456]uintptr]() serde.RegisterType[*[281474976710655]uint32]() + serde.RegisterType[*[2]byte]() serde.RegisterType[*[2]float32]() serde.RegisterType[*[2]float64]() + serde.RegisterType[*[2]int32]() serde.RegisterType[*[2]uint32]() serde.RegisterType[*[2]uintptr]() serde.RegisterType[*[32]rune]() @@ -1337,6 +1517,7 @@ func init() { serde.RegisterType[*[70368744177663]uint16]() serde.RegisterType[*[8]byte]() serde.RegisterType[*[8]uint8]() + serde.RegisterType[*[]byte]() serde.RegisterType[*[]uint64]() serde.RegisterType[*bool]() serde.RegisterType[*byte]() @@ -1349,6 +1530,8 @@ func init() { serde.RegisterType[*uint64]() serde.RegisterType[*uint8]() serde.RegisterType[*uintptr]() + serde.RegisterType[[0]byte]() + serde.RegisterType[[0]uint8]() serde.RegisterType[[0]uintptr]() serde.RegisterType[[1000]uintptr]() serde.RegisterType[[100]byte]() @@ -1358,6 +1541,7 @@ func init() { serde.RegisterType[[1048576]uint8]() serde.RegisterType[[104]byte]() serde.RegisterType[[108]byte]() + serde.RegisterType[[108]int8]() serde.RegisterType[[10]byte]() serde.RegisterType[[10]string]() serde.RegisterType[[128]byte]() @@ -1365,11 +1549,15 @@ func init() { serde.RegisterType[[128]uintptr]() serde.RegisterType[[129]uint8]() serde.RegisterType[[131072]uintptr]() + serde.RegisterType[[133]string]() + serde.RegisterType[[13]int32]() serde.RegisterType[[14]byte]() + serde.RegisterType[[14]int8]() serde.RegisterType[[15]uint64]() serde.RegisterType[[16384]byte]() serde.RegisterType[[16384]uint8]() serde.RegisterType[[16]byte]() + serde.RegisterType[[16]int64]() serde.RegisterType[[16]uint64]() serde.RegisterType[[17]string]() serde.RegisterType[[1]byte]() @@ -1384,22 +1572,30 @@ func init() { serde.RegisterType[[24]uint32]() serde.RegisterType[[252]uintptr]() serde.RegisterType[[253]uintptr]() + serde.RegisterType[[256]int8]() serde.RegisterType[[256]uint64]() + serde.RegisterType[[2]byte]() + serde.RegisterType[[2]int]() + serde.RegisterType[[2]int32]() serde.RegisterType[[2]uint64]() serde.RegisterType[[2]uintptr]() serde.RegisterType[[32]byte]() serde.RegisterType[[32]string]() + serde.RegisterType[[32]uint8]() serde.RegisterType[[32]uintptr]() serde.RegisterType[[33]float64]() serde.RegisterType[[3]byte]() serde.RegisterType[[3]int]() + serde.RegisterType[[3]int64]() serde.RegisterType[[3]uint16]() serde.RegisterType[[3]uint32]() + serde.RegisterType[[3]uint64]() serde.RegisterType[[4096]byte]() serde.RegisterType[[40]byte]() serde.RegisterType[[44]byte]() serde.RegisterType[[4]byte]() serde.RegisterType[[4]float64]() + serde.RegisterType[[4]int64]() serde.RegisterType[[4]string]() serde.RegisterType[[4]uint16]() serde.RegisterType[[4]uint32]() @@ -1419,6 +1615,7 @@ func init() { serde.RegisterType[[64]byte]() serde.RegisterType[[64]uintptr]() serde.RegisterType[[65528]byte]() + serde.RegisterType[[65]int8]() serde.RegisterType[[65]uint32]() serde.RegisterType[[65]uintptr]() serde.RegisterType[[68]struct { @@ -1430,20 +1627,29 @@ func init() { serde.RegisterType[[68]uint32]() serde.RegisterType[[68]uint64]() serde.RegisterType[[68]uint8]() + serde.RegisterType[[6]byte]() + serde.RegisterType[[6]int]() + serde.RegisterType[[6]int8]() serde.RegisterType[[6]uintptr]() serde.RegisterType[[8192]byte]() serde.RegisterType[[8]byte]() serde.RegisterType[[8]string]() + serde.RegisterType[[8]uint32]() serde.RegisterType[[8]uint64]() serde.RegisterType[[8]uint8]() serde.RegisterType[[96]byte]() + serde.RegisterType[[96]int8]() serde.RegisterType[[9]string]() serde.RegisterType[[9]uintptr]() + serde.RegisterType[[]*byte]() serde.RegisterType[[][]int32]() serde.RegisterType[[]byte]() serde.RegisterType[[]float64]() serde.RegisterType[[]int]() + serde.RegisterType[[]int16]() serde.RegisterType[[]int32]() + serde.RegisterType[[]int64]() + serde.RegisterType[[]int8]() serde.RegisterType[[]rune]() serde.RegisterType[[]string]() serde.RegisterType[[]uint16]() @@ -1468,9 +1674,11 @@ func init() { serde.RegisterType[int32]() serde.RegisterType[int64]() serde.RegisterType[int8]() + serde.RegisterType[map[*byte][]byte]() serde.RegisterType[map[int]int]() serde.RegisterType[map[int]struct{}]() serde.RegisterType[map[string]bool]() + serde.RegisterType[map[string]int]() serde.RegisterType[map[string]uint64]() serde.RegisterType[rune]() serde.RegisterType[runtime.BlockProfileRecord]() @@ -1484,6 +1692,10 @@ func init() { serde.RegisterType[runtime.StackRecord]() serde.RegisterType[runtime.TypeAssertionError]() serde.RegisterType[string]() + serde.RegisterType[struct { + b bool + x any + }]() serde.RegisterType[struct { base uintptr end uintptr @@ -1510,6 +1722,82 @@ func init() { serde.RegisterType[sync.Pool]() serde.RegisterType[sync.RWMutex]() serde.RegisterType[sync.WaitGroup]() + serde.RegisterType[syscall.Cmsghdr]() + serde.RegisterType[syscall.Credential]() + serde.RegisterType[syscall.Dirent]() + serde.RegisterType[syscall.EpollEvent]() + serde.RegisterType[syscall.Errno]() + serde.RegisterType[syscall.FdSet]() + serde.RegisterType[syscall.Flock_t]() + serde.RegisterType[syscall.Fsid]() + serde.RegisterType[syscall.ICMPv6Filter]() + serde.RegisterType[syscall.IPMreq]() + serde.RegisterType[syscall.IPMreqn]() + serde.RegisterType[syscall.IPv6MTUInfo]() + serde.RegisterType[syscall.IPv6Mreq]() + serde.RegisterType[syscall.IfAddrmsg]() + serde.RegisterType[syscall.IfInfomsg]() + serde.RegisterType[syscall.Inet4Pktinfo]() + serde.RegisterType[syscall.Inet6Pktinfo]() + serde.RegisterType[syscall.InotifyEvent]() + serde.RegisterType[syscall.Iovec]() + serde.RegisterType[syscall.Linger]() + serde.RegisterType[syscall.Msghdr]() + serde.RegisterType[syscall.NetlinkMessage]() + serde.RegisterType[syscall.NetlinkRouteAttr]() + serde.RegisterType[syscall.NetlinkRouteRequest]() + serde.RegisterType[syscall.NlAttr]() + serde.RegisterType[syscall.NlMsgerr]() + serde.RegisterType[syscall.NlMsghdr]() + serde.RegisterType[syscall.ProcAttr]() + serde.RegisterType[syscall.PtraceRegs]() + serde.RegisterType[syscall.RawSockaddr]() + serde.RegisterType[syscall.RawSockaddrAny]() + serde.RegisterType[syscall.RawSockaddrInet4]() + serde.RegisterType[syscall.RawSockaddrInet6]() + serde.RegisterType[syscall.RawSockaddrLinklayer]() + serde.RegisterType[syscall.RawSockaddrNetlink]() + serde.RegisterType[syscall.RawSockaddrUnix]() + serde.RegisterType[syscall.Rlimit]() + serde.RegisterType[syscall.RtAttr]() + serde.RegisterType[syscall.RtGenmsg]() + serde.RegisterType[syscall.RtMsg]() + serde.RegisterType[syscall.RtNexthop]() + serde.RegisterType[syscall.Rusage]() + serde.RegisterType[syscall.Signal]() + serde.RegisterType[syscall.SockFilter]() + serde.RegisterType[syscall.SockFprog]() + serde.RegisterType[syscall.SockaddrInet4]() + serde.RegisterType[syscall.SockaddrInet6]() + serde.RegisterType[syscall.SockaddrLinklayer]() + serde.RegisterType[syscall.SockaddrNetlink]() + serde.RegisterType[syscall.SockaddrUnix]() + serde.RegisterType[syscall.SocketControlMessage]() + serde.RegisterType[syscall.Stat_t]() + serde.RegisterType[syscall.Statfs_t]() + serde.RegisterType[syscall.SysProcAttr]() + serde.RegisterType[syscall.SysProcIDMap]() + serde.RegisterType[syscall.Sysinfo_t]() + serde.RegisterType[syscall.TCPInfo]() + serde.RegisterType[syscall.Termios]() + serde.RegisterType[syscall.Time_t]() + serde.RegisterType[syscall.Timespec]() + serde.RegisterType[syscall.Timeval]() + serde.RegisterType[syscall.Timex]() + serde.RegisterType[syscall.Tms]() + serde.RegisterType[syscall.Ucred]() + serde.RegisterType[syscall.Ustat_t]() + serde.RegisterType[syscall.Utimbuf]() + serde.RegisterType[syscall.Utsname]() + serde.RegisterType[syscall.WaitStatus]() + serde.RegisterType[time.Duration]() + serde.RegisterType[time.Location]() + serde.RegisterType[time.Month]() + serde.RegisterType[time.ParseError]() + serde.RegisterType[time.Ticker]() + serde.RegisterType[time.Time]() + serde.RegisterType[time.Timer]() + serde.RegisterType[time.Weekday]() serde.RegisterType[uint]() serde.RegisterType[uint16]() serde.RegisterType[uint32]() diff --git a/compiler/unsupported.go b/compiler/unsupported.go index fb3d25a..739c25a 100644 --- a/compiler/unsupported.go +++ b/compiler/unsupported.go @@ -27,10 +27,6 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { err = fmt.Errorf("not implemented: defer") case *ast.GoStmt: err = fmt.Errorf("not implemented: go") - case *ast.SelectStmt: - err = fmt.Errorf("not implemented: select") - case *ast.CommClause: - err = fmt.Errorf("not implemented: select case") // Partially supported: case *ast.BranchStmt: @@ -41,7 +37,7 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { err = fmt.Errorf("not implemented: fallthrough") } case *ast.LabeledStmt: - // labeled for/switch/select statements are supported, + // Labeled for/switch/select statements are supported, // arbitrary labels are not. switch n.Stmt.(type) { case *ast.ForStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt: @@ -65,6 +61,7 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { case *ast.AssignStmt: case *ast.BlockStmt: case *ast.CaseClause: + case *ast.CommClause: case *ast.DeclStmt: case *ast.EmptyStmt: case *ast.ExprStmt: @@ -72,6 +69,7 @@ func unsupported(decl *ast.FuncDecl, info *types.Info) (err error) { case *ast.IncDecStmt: case *ast.RangeStmt: case *ast.ReturnStmt: + case *ast.SelectStmt: case *ast.SendStmt: case *ast.SwitchStmt: case *ast.TypeSwitchStmt: