Skip to content

Commit

Permalink
proto clients
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Nov 7, 2024
1 parent 1615efa commit 07b19db
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 16 deletions.
26 changes: 26 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
"github.com/restatedev/sdk-go/internal/options"
)

// re-export for use in generated code
type IngressClientOption = options.IngressClientOption

type ingressContextKey struct{}

func Connect(ctx context.Context, ingressURL string, opts ...options.ConnectOption) (context.Context, error) {
Expand Down Expand Up @@ -506,3 +509,26 @@ func RejectAwakeable(ctx context.Context, id string, reason error) error {
}
return nil
}

type withRequestType[I any, O any] struct {
inner IngressClient[any, O]
}

func (w withRequestType[I, O]) Request(input I, options ...options.RequestOption) (O, error) {
return w.inner.Request(input, options...)
}

func (w withRequestType[I, O]) RequestFuture(input I, options ...options.RequestOption) IngressResponseFuture[O] {
return w.inner.RequestFuture(input, options...)
}

func (w withRequestType[I, O]) Send(input I, options ...options.SendOption) (Send, error) {
return w.inner.Send(input, options...)
}

// WithRequestType is primarily intended to be called from generated code, to provide
// type safety of input types. In other contexts it's generally less cumbersome to use [Object] and [Service],
// as the output type can be inferred.
func WithRequestType[I any, O any](inner IngressClient[any, O]) IngressClient[I, O] {
return withRequestType[I, O]{inner}
}
87 changes: 87 additions & 0 deletions examples/codegen/proto/helloworld_restate.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

129 changes: 113 additions & 16 deletions protoc-gen-go-restate/restate.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
)

const (
fmtPackage = protogen.GoImportPath("fmt")
sdkPackage = protogen.GoImportPath("github.com/restatedev/sdk-go")
fmtPackage = protogen.GoImportPath("fmt")
contextPackage = protogen.GoImportPath("context")
sdkPackage = protogen.GoImportPath("github.com/restatedev/sdk-go")
clientPackage = protogen.GoImportPath("github.com/restatedev/sdk-go/client")
)

type serviceGenerateHelper struct{}
Expand All @@ -32,6 +34,17 @@ func generateClientStruct(g *protogen.GeneratedFile, service *protogen.Service,
g.P("}")
}

func generateIngressClientStruct(g *protogen.GeneratedFile, service *protogen.Service, clientName string) {
g.P("type ", unexport(clientName), " struct {")
g.P("ctx ", contextPackage.Ident("Context"))
serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType)
if serviceType == sdk.ServiceType_VIRTUAL_OBJECT {
g.P("key string")
}
g.P("options []", clientPackage.Ident("IngressClientOption"))
g.P("}")
}

func generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) {
g.P("cOpts := append([]", sdkPackage.Ident("ClientOption"), "{", sdkPackage.Ident("WithProtoJSON"), "}, opts...)")
g.P("return &", unexport(clientName), "{")
Expand All @@ -47,6 +60,18 @@ func generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.S
g.P("}")
}

func generateNewIngressClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) {
g.P("cOpts := append([]", clientPackage.Ident("IngressClientOption"), "{", sdkPackage.Ident("WithProtoJSON"), "}, opts...)")
g.P("return &", unexport(clientName), "{")
g.P("ctx,")
serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType)
if serviceType == sdk.ServiceType_VIRTUAL_OBJECT {
g.P("key,")
}
g.P("cOpts,")
g.P("}")
}

func generateUnimplementedServerType(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protogen.Service) {
serverType := service.GoName + "Server"
mustOrShould := "must"
Expand Down Expand Up @@ -162,7 +187,7 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
clientSignature(g, method))
clientSignature(g, method, false))
}
g.P("}")
g.P()
Expand All @@ -188,17 +213,66 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog
generateNewClientDefinitions(g, service, clientName)
g.P("}")

var methodIndex int
// Client method implementations.
for _, method := range service.Methods {
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
genClientMethod(gen, g, method)
methodIndex++
genClientMethod(gen, g, method, false)
} else {
gen.Error(fmt.Errorf("streaming methods are not currently supported in Restate."))
}
}

// Ingress client interface.
ingressClientName := service.GoName + "IngressClient"

g.P("// ", ingressClientName, " is the ingress client API for ", service.GoName, " service.")
g.P("//")

// Copy comments from proto file.
genServiceComments(g, service)

if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P("//")
g.P(deprecationComment)
}
g.AnnotateSymbol(ingressClientName, protogen.Annotation{Location: service.Location})
g.P("type ", ingressClientName, " interface {")
for _, method := range service.Methods {
g.AnnotateSymbol(ingressClientName+"."+method.GoName, protogen.Annotation{Location: method.Location})
if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P(method.Comments.Leading,
clientSignature(g, method, true))
}
g.P("}")
g.P()

// Ingress client structure.
generateIngressClientStruct(g, service, ingressClientName)

// NewIngressClient factory.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("// New", ingressClientName, " must be called with a ctx returned from github.com/restatedev/sdk-go/client.Connect")
newIngressClientSignature := "New" + ingressClientName + " (ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
if serviceType == sdk.ServiceType_VIRTUAL_OBJECT {
newIngressClientSignature += ", key string"
}
newIngressClientSignature += ", opts..." + g.QualifiedGoIdent(clientPackage.Ident("IngressClientOption")) + ") " + ingressClientName

g.P("func ", newIngressClientSignature, " {")
generateNewIngressClientDefinitions(g, service, ingressClientName)
g.P("}")

// Ingress method implementations.
for _, method := range service.Methods {
if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
genClientMethod(gen, g, method, true)
}
}

mustOrShould := "must"
if !*requireUnimplemented {
mustOrShould = "should"
Expand Down Expand Up @@ -268,41 +342,64 @@ func genService(gen *protogen.Plugin, g *protogen.GeneratedFile, service *protog
g.P()
}

func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
func clientSignature(g *protogen.GeneratedFile, method *protogen.Method, ingress bool) string {
var optionName protogen.GoIdent
var clientName protogen.GoIdent
if ingress {
optionName = clientPackage.Ident("IngressClientOption")
clientName = clientPackage.Ident("IngressClient")
} else {
optionName = sdkPackage.Ident("ClientOption")
clientName = sdkPackage.Ident("Client")
}

s := method.GoName + "("
s += "opts ..." + g.QualifiedGoIdent(sdkPackage.Ident("ClientOption")) + ") ("
s += g.QualifiedGoIdent(sdkPackage.Ident("Client")) + "[" + "*" + g.QualifiedGoIdent(method.Input.GoIdent) + ", *" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
s += "opts ..." + g.QualifiedGoIdent(optionName) + ") ("
s += g.QualifiedGoIdent(clientName) + "[" + "*" + g.QualifiedGoIdent(method.Input.GoIdent) + ", *" + g.QualifiedGoIdent(method.Output.GoIdent) + "]"
s += ")"
return s
}

func genClientMethod(gen *protogen.Plugin, g *protogen.GeneratedFile, method *protogen.Method) {
func genClientMethod(gen *protogen.Plugin, g *protogen.GeneratedFile, method *protogen.Method, ingress bool) {
var pack protogen.GoImportPath
var clientSuffix string
var optionName protogen.GoIdent
if ingress {
pack = clientPackage
clientSuffix = "IngressClient"
optionName = clientPackage.Ident("IngressClientOption")
} else {
pack = sdkPackage
clientSuffix = "Client"
optionName = sdkPackage.Ident("ClientOption")
}

service := method.Parent
serviceType := proto.GetExtension(service.Desc.Options().(*descriptorpb.ServiceOptions), sdk.E_ServiceType).(sdk.ServiceType)

if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
g.P(deprecationComment)
}
g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
g.P("func (c *", unexport(service.GoName), clientSuffix, ") ", clientSignature(g, method, ingress), "{")

g.P("cOpts := c.options")
g.P("if len(opts) > 0 {")
g.P("cOpts = append(append([]sdk_go.ClientOption{}, cOpts...), opts...)")
g.P("cOpts = append(append([]", optionName, "{}, cOpts...), opts...)")
g.P("}")
var getClient string
switch serviceType {
case sdk.ServiceType_SERVICE:
getClient = g.QualifiedGoIdent(sdkPackage.Ident("Service")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `",`
getClient = g.QualifiedGoIdent(pack.Ident("Service")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `",`
case sdk.ServiceType_VIRTUAL_OBJECT:
getClient = g.QualifiedGoIdent(sdkPackage.Ident("Object")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.key,`
getClient = g.QualifiedGoIdent(pack.Ident("Object")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.key,`
case sdk.ServiceType_WORKFLOW:
getClient = g.QualifiedGoIdent(sdkPackage.Ident("Workflow")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.workflowID,`
getClient = g.QualifiedGoIdent(pack.Ident("Workflow")) + `[*` + g.QualifiedGoIdent(method.Output.GoIdent) + `]` + `(c.ctx, "` + service.GoName + `", c.workflowID,`
default:
gen.Error(fmt.Errorf("Unexpected service type: %s", serviceType.String()))
return
}
getClient += `"` + method.GoName + `", cOpts...)`
g.P("return ", sdkPackage.Ident("WithRequestType"), "[*", method.Input.GoIdent, "]", `(`, getClient, `)`)
g.P("return ", pack.Ident("WithRequestType"), "[*", method.Input.GoIdent, "]", `(`, getClient, `)`)
g.P("}")
g.P()
return
Expand Down

0 comments on commit 07b19db

Please sign in to comment.