diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d373ff3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test +*.wasm + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Emacs +*~ diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index 914900e..a7a2e37 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -119,6 +119,12 @@ func TestCoroutineYield(t *testing.T) { yields: []int{0, 3, 6, 9}, }, + { + name: "range over closure", + coro: func() { Range10Closure() }, + yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + { name: "select", coro: func() { Select(8) }, diff --git a/compiler/dispatch.go b/compiler/dispatch.go index 04fa507..58d4241 100644 --- a/compiler/dispatch.go +++ b/compiler/dispatch.go @@ -28,6 +28,9 @@ func trackDispatchSpans0(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan, for _, child := range s.List { nextID = trackDispatchSpans0(child, dispatchSpans, nextID) } + if len(s.List) == 0 { + nextID++ + } case *ast.IfStmt: nextID = trackDispatchSpans0(s.Body, dispatchSpans, nextID) if s.Else != nil { diff --git a/compiler/function.go b/compiler/function.go index 497c54c..75104c8 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -55,17 +55,51 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt return // ignore non-instantiated generic functions } - init.List = append(init.List, &ast.ExprStmt{ - X: &ast.CallExpr{ - Fun: &ast.IndexListExpr{ - X: &ast.SelectorExpr{ - X: ast.NewIdent("_types"), - Sel: ast.NewIdent("RegisterFunc"), - }, - Indices: []ast.Expr{ - newFuncType(fn.Signature), + var register ast.Expr + if len(fn.FreeVars) == 0 { + register = &ast.IndexListExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent("_types"), + Sel: ast.NewIdent("RegisterFunc"), + }, + Indices: []ast.Expr{ + newFuncType(fn.Signature), + }, + } + } else { + fields := make([]*ast.Field, 1+len(fn.FreeVars)) + // first field is the function address (uintptr) + fields[0] = &ast.Field{ + Names: []*ast.Ident{ast.NewIdent("_")}, + Type: ast.NewIdent("uintptr"), + } + + for i, freeVar := range fn.FreeVars { + fields[i+1] = &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(freeVar.Name())}, + Type: typeExpr(freeVar.Type()), + } + } + + register = &ast.IndexListExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent("_types"), + Sel: ast.NewIdent("RegisterClosure"), + }, + Indices: []ast.Expr{ + newFuncType(fn.Signature), + &ast.StructType{ + Fields: &ast.FieldList{ + List: fields, + }, }, }, + } + } + + init.List = append(init.List, &ast.ExprStmt{ + X: &ast.CallExpr{ + Fun: register, Args: []ast.Expr{ &ast.BasicLit{ Kind: token.STRING, diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 71817cb..5ade359 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -279,6 +279,22 @@ func RangeTripleFuncValue(n int) { Range(n, f) } +func Range10Closure() { + i := 0 + n := 10 + f := func() bool { + if i < n { + coroutine.Yield[int, any](i) + i++ + return true + } + return false + } + + for f() { + } +} + func Select(n int) { select { default: diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index e23dfb3..22c1101 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -1394,6 +1394,53 @@ func RangeTripleFuncValue(n int) { } } +func Range10Closure() { + _c := coroutine.LoadContext[int, any]() + _f, _fp := _c.Push() + var _o0 int + var _o1 int + var _o2 func() bool + if _f.IP > 0 { + _o0 = _f.Get(0).(int) + _o1 = _f.Get(1).(int) + _o2 = _f.Get(2).(func() bool) + } + defer func() { + if _c.Unwinding() { + _f.Set(0, _o0) + _f.Set(1, _o1) + _f.Set(2, _o2) + _c.Store(_fp, _f) + } else { + _c.Pop() + } + }() + switch { + case _f.IP < 2: + _o0 = 0 + _f.IP = 2 + fallthrough + case _f.IP < 3: + _o1 = 10 + _f.IP = 3 + fallthrough + case _f.IP < 4: + _o2 = func() bool { + if _o0 < _o1 { + coroutine.Yield[int, any](_o0) + _o0++ + return true + } + return false + } + _f.IP = 4 + fallthrough + case _f.IP < 5: + for ; _o2(); _f.IP = 4 { + } + } +} + func Select(n int) { _c := coroutine.LoadContext[int, any]() _f, _fp := _c.Push() diff --git a/compiler/testdata/coroutine_functypes.go b/compiler/testdata/coroutine_functypes.go index 439a281..fb76211 100644 --- a/compiler/testdata/coroutine_functypes.go +++ b/compiler/testdata/coroutine_functypes.go @@ -15,6 +15,12 @@ func init() { _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(int, func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") + _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10Closure") + _types.RegisterClosure[func() bool, struct { + _ uintptr + i *int + n *int + }]("github.com/stealthrocket/coroutine/compiler/testdata.Range10Closure.func2") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeArrayIndexValueGenerator") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeOverMaps") _types.RegisterFunc[func(int)]("github.com/stealthrocket/coroutine/compiler/testdata.RangeSliceIndexGenerator") diff --git a/coroutine_durable.go b/coroutine_durable.go index 07b6347..1f6cf37 100644 --- a/coroutine_durable.go +++ b/coroutine_durable.go @@ -116,6 +116,8 @@ func (c Coroutine[R, S]) Next() (hasNext bool) { case nil: case unwind{}: default: + // TODO: can we figure out a way to know when we are unwinding the + // stack and only recover then so we don't alter the panic stack? panic(err) } diff --git a/internal/serde/reflect.go b/internal/serde/reflect.go index 846868b..39ec956 100644 --- a/internal/serde/reflect.go +++ b/internal/serde/reflect.go @@ -404,18 +404,24 @@ func deserializePointer(d *Deserializer, t reflect.Type, p unsafe.Pointer) { } func serializeStruct(s *Serializer, t reflect.Type, p unsafe.Pointer) { - n := t.NumField() + serializeStructFields(s, p, t.NumField(), t.Field) +} + +func deserializeStruct(d *Deserializer, t reflect.Type, p unsafe.Pointer) { + deserializeStructFields(d, p, t.NumField(), t.Field) +} + +func serializeStructFields(s *Serializer, p unsafe.Pointer, n int, field func(int) reflect.StructField) { for i := 0; i < n; i++ { - ft := t.Field(i) + ft := field(i) fp := unsafe.Add(p, ft.Offset) SerializeAny(s, ft.Type, fp) } } -func deserializeStruct(d *Deserializer, t reflect.Type, p unsafe.Pointer) { - n := t.NumField() +func deserializeStructFields(d *Deserializer, p unsafe.Pointer, n int, field func(int) reflect.StructField) { for i := 0; i < n; i++ { - ft := t.Field(i) + ft := field(i) fp := unsafe.Add(p, ft.Offset) DeserializeAny(d, ft.Type, fp) } @@ -433,8 +439,11 @@ func serializeFunc(s *Serializer, t reflect.Type, p unsafe.Pointer) { fn := types.FuncByAddr(*(*uintptr)(p)) SerializeString(s, &fn.Name) - if fn.Closure != nil { // TODO - panic("cannot serialize closures yet") + if fn.Closure != nil { + t := fn.Closure + serializeStructFields(s, p, t.NumField()-1, func(i int) reflect.StructField { + return t.Field(i + 1) + }) } } @@ -454,7 +463,17 @@ func deserializeFunc(d *Deserializer, t reflect.Type, p unsafe.Pointer) { } if fn.Closure != nil { - panic(name + ": cannot deserialize closures yet") + t := fn.Closure + v := reflect.New(t) + + closure := v.UnsafePointer() + *(*uintptr)(closure) = fn.Addr + + deserializeStructFields(d, closure, t.NumField()-1, func(i int) reflect.StructField { + return t.Field(i + 1) + }) + + *(*unsafe.Pointer)(p) = closure } else { *(**types.Func)(p) = fn } diff --git a/types/func.go b/types/func.go index d1441a4..2d08929 100644 --- a/types/func.go +++ b/types/func.go @@ -59,6 +59,16 @@ func RegisterFunc[Type any](name string) { } } +// RegisterClosure is like RegisterFunc but the caller can specify the closure +// type (see types.Func for details). +func RegisterClosure[Type, Closure any](name string) { + if f := FuncByName(name); f != nil { + var signature Type + var closure Closure + f.Type, f.Closure = reflect.TypeOf(signature), reflect.TypeOf(closure) + } +} + // Go function values are pointers to an object starting with the function // address, whether they are referencing top-level functions or closures. //