Skip to content

Commit

Permalink
implement serializable closures (#43)
Browse files Browse the repository at this point in the history
This PR adds support for serializing closures.
  • Loading branch information
achille-roussel authored Sep 19, 2023
2 parents 01a0687 + ae25105 commit b75d687
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 17 deletions.
19 changes: 19 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, built with `go test -c`
*.test
*.wasm

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# Dependency directories (remove the comment below to include it)
# vendor/

# Emacs
*~
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ func TestCoroutineYield(t *testing.T) {
yields: []int{0, 3, 6, 9},
},

{
name: "range over closure",
coro: func() { Range10Closure() },
yields: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
},

{
name: "select",
coro: func() { Select(8) },
Expand Down
3 changes: 3 additions & 0 deletions compiler/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ func trackDispatchSpans0(stmt ast.Stmt, dispatchSpans map[ast.Stmt]dispatchSpan,
for _, child := range s.List {
nextID = trackDispatchSpans0(child, dispatchSpans, nextID)
}
if len(s.List) == 0 {
nextID++
}
case *ast.IfStmt:
nextID = trackDispatchSpans0(s.Body, dispatchSpans, nextID)
if s.Else != nil {
Expand Down
52 changes: 43 additions & 9 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,51 @@ func generateFunctypesInit(pkg *ssa.Package, fn *ssa.Function, init *ast.BlockSt
return // ignore non-instantiated generic functions
}

init.List = append(init.List, &ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.IndexListExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("_types"),
Sel: ast.NewIdent("RegisterFunc"),
},
Indices: []ast.Expr{
newFuncType(fn.Signature),
var register ast.Expr
if len(fn.FreeVars) == 0 {
register = &ast.IndexListExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("_types"),
Sel: ast.NewIdent("RegisterFunc"),
},
Indices: []ast.Expr{
newFuncType(fn.Signature),
},
}
} else {
fields := make([]*ast.Field, 1+len(fn.FreeVars))
// first field is the function address (uintptr)
fields[0] = &ast.Field{
Names: []*ast.Ident{ast.NewIdent("_")},
Type: ast.NewIdent("uintptr"),
}

for i, freeVar := range fn.FreeVars {
fields[i+1] = &ast.Field{
Names: []*ast.Ident{ast.NewIdent(freeVar.Name())},
Type: typeExpr(freeVar.Type()),
}
}

register = &ast.IndexListExpr{
X: &ast.SelectorExpr{
X: ast.NewIdent("_types"),
Sel: ast.NewIdent("RegisterClosure"),
},
Indices: []ast.Expr{
newFuncType(fn.Signature),
&ast.StructType{
Fields: &ast.FieldList{
List: fields,
},
},
},
}
}

init.List = append(init.List, &ast.ExprStmt{
X: &ast.CallExpr{
Fun: register,
Args: []ast.Expr{
&ast.BasicLit{
Kind: token.STRING,
Expand Down
16 changes: 16 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,22 @@ func RangeTripleFuncValue(n int) {
Range(n, f)
}

func Range10Closure() {
i := 0
n := 10
f := func() bool {
if i < n {
coroutine.Yield[int, any](i)
i++
return true
}
return false
}

for f() {
}
}

func Select(n int) {
select {
default:
Expand Down
47 changes: 47 additions & 0 deletions compiler/testdata/coroutine_durable.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions compiler/testdata/coroutine_functypes.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ func (c Coroutine[R, S]) Next() (hasNext bool) {
case nil:
case unwind{}:
default:
// TODO: can we figure out a way to know when we are unwinding the
// stack and only recover then so we don't alter the panic stack?
panic(err)
}

Expand Down
35 changes: 27 additions & 8 deletions internal/serde/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,24 @@ func deserializePointer(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
}

func serializeStruct(s *Serializer, t reflect.Type, p unsafe.Pointer) {
n := t.NumField()
serializeStructFields(s, p, t.NumField(), t.Field)
}

func deserializeStruct(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
deserializeStructFields(d, p, t.NumField(), t.Field)
}

func serializeStructFields(s *Serializer, p unsafe.Pointer, n int, field func(int) reflect.StructField) {
for i := 0; i < n; i++ {
ft := t.Field(i)
ft := field(i)
fp := unsafe.Add(p, ft.Offset)
SerializeAny(s, ft.Type, fp)
}
}

func deserializeStruct(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
n := t.NumField()
func deserializeStructFields(d *Deserializer, p unsafe.Pointer, n int, field func(int) reflect.StructField) {
for i := 0; i < n; i++ {
ft := t.Field(i)
ft := field(i)
fp := unsafe.Add(p, ft.Offset)
DeserializeAny(d, ft.Type, fp)
}
Expand All @@ -433,8 +439,11 @@ func serializeFunc(s *Serializer, t reflect.Type, p unsafe.Pointer) {
fn := types.FuncByAddr(*(*uintptr)(p))
SerializeString(s, &fn.Name)

if fn.Closure != nil { // TODO
panic("cannot serialize closures yet")
if fn.Closure != nil {
t := fn.Closure
serializeStructFields(s, p, t.NumField()-1, func(i int) reflect.StructField {
return t.Field(i + 1)
})
}
}

Expand All @@ -454,7 +463,17 @@ func deserializeFunc(d *Deserializer, t reflect.Type, p unsafe.Pointer) {
}

if fn.Closure != nil {
panic(name + ": cannot deserialize closures yet")
t := fn.Closure
v := reflect.New(t)

closure := v.UnsafePointer()
*(*uintptr)(closure) = fn.Addr

deserializeStructFields(d, closure, t.NumField()-1, func(i int) reflect.StructField {
return t.Field(i + 1)
})

*(*unsafe.Pointer)(p) = closure
} else {
*(**types.Func)(p) = fn
}
Expand Down
10 changes: 10 additions & 0 deletions types/func.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ func RegisterFunc[Type any](name string) {
}
}

// RegisterClosure is like RegisterFunc but the caller can specify the closure
// type (see types.Func for details).
func RegisterClosure[Type, Closure any](name string) {
if f := FuncByName(name); f != nil {
var signature Type
var closure Closure
f.Type, f.Closure = reflect.TypeOf(signature), reflect.TypeOf(closure)
}
}

// Go function values are pointers to an object starting with the function
// address, whether they are referencing top-level functions or closures.
//
Expand Down

0 comments on commit b75d687

Please sign in to comment.