Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add identity implementation #8

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module github.com/restatedev/sdk-go
go 1.22.0

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/mr-tron/base58 v1.2.0
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0
github.com/stretchr/testify v1.9.0
github.com/vmihailenco/msgpack/v5 v5.4.1
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/h2conn v0.0.0-20231204025407-3997deeca0f0 h1:zZg03nifrj6ayWNaDO8tNj57tqrOIKDwiBaLkhtK7Kk=
Expand Down
30 changes: 30 additions & 0 deletions internal/identity/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package identity

import "fmt"

const SIGNATURE_SCHEME_HEADER = "X-Restate-Signature-Scheme"

type SignatureScheme string

var (
SchemeUnsigned SignatureScheme = "unsigned"
errMissingIdentity = fmt.Errorf("request has no identity")
)

func ValidateRequestIdentity(keySet KeySetV1, path string, headers map[string][]string) error {
switch len(headers[SIGNATURE_SCHEME_HEADER]) {
case 0:
return errMissingIdentity
case 1:
switch SignatureScheme(headers[SIGNATURE_SCHEME_HEADER][0]) {
case SchemeV1:
return validateV1(keySet, path, headers)
case SchemeUnsigned:
return errMissingIdentity
default:
return fmt.Errorf("unexpected signature scheme %v, allowed values are [%s %s]", headers[SIGNATURE_SCHEME_HEADER][0], SchemeUnsigned, SchemeV1)
}
default:
return fmt.Errorf("unexpected multi-value signature scheme header: %v", headers[SIGNATURE_SCHEME_HEADER])
}
}
83 changes: 83 additions & 0 deletions internal/identity/v1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package identity

import (
"crypto/ed25519"
"fmt"
"strings"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/mr-tron/base58"
)

const (
JWT_HEADER = "X-Restate-Jwt-V1"
SchemeV1 SignatureScheme = "v1"
)

type KeySetV1 = map[string]ed25519.PublicKey

func validateV1(keySet KeySetV1, path string, headers map[string][]string) error {
switch len(headers[JWT_HEADER]) {
case 0:
return fmt.Errorf("v1 signature scheme expects the following headers: [%s]", JWT_HEADER)
case 1:
default:
return fmt.Errorf("unexpected multi-value JWT header: %v", headers[JWT_HEADER])
}

token, err := jwt.Parse(headers[JWT_HEADER][0], func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}

kid, ok := token.Header["kid"]
if !ok {
return nil, fmt.Errorf("Token missing 'kid' header field")
}

kidS, ok := kid.(string)
if !ok {
return nil, fmt.Errorf("Token 'kid' header field was not a string: %v", kid)
}

key, ok := keySet[kidS]
if !ok {
return nil, fmt.Errorf("Key ID %s is not present in key set", kid)
}

return key, nil
}, jwt.WithValidMethods([]string{"EdDSA"}), jwt.WithAudience(path), jwt.WithExpirationRequired())
if err != nil {
return fmt.Errorf("failed to validate v1 request identity jwt: %w", err)
}

nbf, _ := token.Claims.GetNotBefore()
if nbf == nil {
// jwt library only validates nbf if its present, so we should check it was present
return fmt.Errorf("'nbf' claim is missing in v1 request identity jwt")
}

return nil
}

func ParseKeySetV1(keys []string) (KeySetV1, error) {
out := make(KeySetV1, len(keys))
for _, key := range keys {
if !strings.HasPrefix(key, "publickeyv1_") {
return nil, fmt.Errorf("v1 public key must start with 'publickeyv1_'")
}

pubBytes, err := base58.Decode(key[len("publickeyv1_"):])
if err != nil {
return nil, fmt.Errorf("v1 public key must be valid base58: %w", err)
}

if len(pubBytes) != ed25519.PublicKeySize {
return nil, fmt.Errorf("v1 public key must have exactly %d bytes, found %d", ed25519.PublicKeySize, len(pubBytes))
}

out[key] = ed25519.PublicKey(pubBytes)
}

return out, nil
}
2 changes: 1 addition & 1 deletion internal/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (t stringerValue[T]) LogValue() slog.Value {
}

func Stringer[T fmt.Stringer](key string, value T) slog.Attr {
return slog.Any(key, slog.AnyValue(stringerValue[T]{value}))
return slog.Any(key, stringerValue[T]{value})
}

func Error(err error) slog.Attr {
Expand Down
42 changes: 35 additions & 7 deletions server/restate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/restatedev/sdk-go/generated/proto/discovery"
"github.com/restatedev/sdk-go/generated/proto/protocol"
"github.com/restatedev/sdk-go/internal"
"github.com/restatedev/sdk-go/internal/identity"
"github.com/restatedev/sdk-go/internal/log"
"github.com/restatedev/sdk-go/internal/state"
"golang.org/x/net/http2"
Expand Down Expand Up @@ -45,6 +46,8 @@ type Restate struct {
dropReplayLogs bool
systemLog *slog.Logger
routers map[string]restate.Router
keyIDs []string
keySet identity.KeySetV1
}

// NewRestate creates a new instance of Restate server
Expand All @@ -69,6 +72,11 @@ func (r *Restate) WithLogger(h slog.Handler, dropReplayLogs bool) *Restate {
return r
}

func (r *Restate) WithIdentityV1(keys ...string) *Restate {
r.keyIDs = append(r.keyIDs, keys...)
return r
}

func (r *Restate) Bind(router restate.Router) *Restate {
if _, ok := r.routers[router.Name()]; ok {
// panic because this is a programming error
Expand Down Expand Up @@ -120,38 +128,37 @@ func (r *Restate) discoverHandler(writer http.ResponseWriter, req *http.Request)

acceptVersionsString := req.Header.Get("accept")
if acceptVersionsString == "" {
writer.Write([]byte("missing accept header"))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte("missing accept header"))

return
}

serviceDiscoveryProtocolVersion := selectSupportedServiceDiscoveryProtocolVersion(acceptVersionsString)

if serviceDiscoveryProtocolVersion == discovery.ServiceDiscoveryProtocolVersion_SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED {
writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString)))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte(fmt.Sprintf("Unsupported service discovery protocol version '%s'", acceptVersionsString)))
return
}

response, err := r.discover()
if err != nil {
writer.Write([]byte(err.Error()))
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))

return
}

bytes, err := json.Marshal(response)
if err != nil {
writer.Write([]byte(err.Error()))
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))

return
}

writer.Header().Add("Content-Type", serviceDiscoveryProtocolVersionToHeaderValue(serviceDiscoveryProtocolVersion))
writer.WriteHeader(200)
if _, err := writer.Write(bytes); err != nil {
r.systemLog.LogAttrs(req.Context(), slog.LevelError, "Failed to write discovery information", log.Error(err))
}
Expand Down Expand Up @@ -252,6 +259,17 @@ func (r *Restate) callHandler(serviceProtocolVersion protocol.ServiceProtocolVer
}

func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if r.keySet != nil {
if err := identity.ValidateRequestIdentity(r.keySet, request.RequestURI, request.Header); err != nil {
r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Rejecting request as its JWT did not validate", log.Error(err))

writer.WriteHeader(http.StatusUnauthorized)
writer.Write([]byte("Unauthorized"))

return
}
}

if request.RequestURI == "/discover" {
r.discoverHandler(writer, request)
return
Expand All @@ -261,8 +279,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if serviceProtocolVersionString == "" {
r.systemLog.ErrorContext(request.Context(), "Missing content-type header")

writer.Write([]byte("missing content-type header"))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte("missing content-type header"))

return
}
Expand All @@ -272,8 +290,8 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
if !isServiceProtocolVersionSupported(serviceProtocolVersion) {
r.systemLog.LogAttrs(request.Context(), slog.LevelError, "Unsupported service protocol version", slog.String("version", serviceProtocolVersionString))

writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString)))
writer.WriteHeader(http.StatusUnsupportedMediaType)
writer.Write([]byte(fmt.Sprintf("Unsupported service protocol version '%s'", serviceProtocolVersionString)))

return
}
Expand All @@ -297,6 +315,16 @@ func (r *Restate) handler(writer http.ResponseWriter, request *http.Request) {
}

func (r *Restate) Start(ctx context.Context, address string) error {
if r.keyIDs == nil {
r.systemLog.WarnContext(ctx, "Accepting requests without validating request signatures; handler access must be restricted")
} else {
ks, err := identity.ParseKeySetV1(r.keyIDs)
if err != nil {
return fmt.Errorf("invalid request identity keys: %w", err)
}
r.keySet = ks
r.systemLog.LogAttrs(ctx, slog.LevelInfo, "Validating requests using signing keys", slog.Any("keys", r.keyIDs))
}

listener, err := net.Listen("tcp", address)
if err != nil {
Expand Down
Loading