Skip to content

Commit

Permalink
App 2076: Creating public methods for a server (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
RoxyFarhad authored Jul 20, 2023
1 parent 4f73972 commit ad2767f
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 12 deletions.
6 changes: 5 additions & 1 deletion rpc/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ func ContextWithAuthEntity(ctx context.Context, authEntity EntityInfo) context.C

// ContextAuthEntity returns the entity (e.g. a user) associated with this authentication context.
func ContextAuthEntity(ctx context.Context) (EntityInfo, bool) {
authEntity, ok := ctx.Value(ctxKeyAuthEntity).(EntityInfo)
authEntityValue := ctx.Value(ctxKeyAuthEntity)
if authEntityValue == nil {
return EntityInfo{}, false
}
authEntity, ok := authEntityValue.(EntityInfo)
if !ok || authEntity.Entity == "" {
return EntityInfo{}, false
}
Expand Down
18 changes: 13 additions & 5 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,14 @@ type simpleServer struct {
signalingCallQueue WebRTCCallQueue
signalingServer *WebRTCSignalingServer
mdnsServers []*zeroconf.Server
exemptMethods map[string]bool
tlsConfig *tls.Config
firstSeenTLSCertLeaf *x509.Certificate
stopped bool
logger golog.Logger
// exempt methods do not perform any auth
exemptMethods map[string]bool
// public methods attempt, but do not require, authentication
publicMethods map[string]bool
tlsConfig *tls.Config
firstSeenTLSCertLeaf *x509.Certificate
stopped bool
logger golog.Logger

// auth

Expand Down Expand Up @@ -260,6 +263,7 @@ func NewServer(logger golog.Logger, opts ...ServerOption) (Server, error) {
authAudience: sOpts.authAudience,
authIssuer: sOpts.authIssuer,
exemptMethods: make(map[string]bool),
publicMethods: make(map[string]bool),
tlsConfig: sOpts.tlsConfig,
firstSeenTLSCertLeaf: firstSeenTLSCertLeaf,
logger: logger,
Expand Down Expand Up @@ -375,6 +379,10 @@ func NewServer(logger golog.Logger, opts ...ServerOption) (Server, error) {
server.exemptMethods[healthWatchMethod] = true
}

for _, method := range sOpts.publicMethods {
server.publicMethods[method] = true
}

if sOpts.authToHandler != nil {
if err := server.RegisterServiceServer(
context.Background(),
Expand Down
57 changes: 51 additions & 6 deletions rpc/server_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,39 @@ func (ss *simpleServer) signAccessTokenForEntity(
return tokenString, nil
}

func (ss *simpleServer) isPublicMethod(
fullMethod string,
) bool {
return ss.publicMethods[fullMethod]
}

func (ss *simpleServer) authUnaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if !ss.exemptMethods[info.FullMethod] {
nextCtx, err := ss.ensureAuthed(ctx)
// no auth
if ss.exemptMethods[info.FullMethod] {
return handler(ctx, req)
}

// optional auth
if ss.isPublicMethod(info.FullMethod) {
nextCtx, err := ss.tryAuth(ctx)
if err != nil {
return nil, err
}
ctx = nextCtx
return handler(nextCtx, req)
}
return handler(ctx, req)

// private auth
nextCtx, err := ss.ensureAuthed(ctx)
if err != nil {
return nil, err
}

return handler(nextCtx, req)
}

func (ss *simpleServer) authStreamInterceptor(
Expand All @@ -184,13 +203,27 @@ func (ss *simpleServer) authStreamInterceptor(
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
if !ss.exemptMethods[info.FullMethod] {
nextCtx, err := ss.ensureAuthed(serverStream.Context())
if ss.exemptMethods[info.FullMethod] {
return handler(srv, serverStream)
}

// optional auth
if ss.isPublicMethod(info.FullMethod) {
nextCtx, err := ss.tryAuth(serverStream.Context())
if err != nil {
return err
}
serverStream = ctxWrappedServerStream{serverStream, nextCtx}
return handler(nextCtx, serverStream)
}

// private auth
nextCtx, err := ss.ensureAuthed(serverStream.Context())
if err != nil {
return err
}

serverStream = ctxWrappedServerStream{serverStream, nextCtx}
return handler(srv, serverStream)
}

Expand Down Expand Up @@ -236,6 +269,18 @@ var validSigningMethods = []string{
"RS256",
}

// tryAuth is called for public methods where auth is not required but preferable.
func (ss *simpleServer) tryAuth(ctx context.Context) (context.Context, error) {
nextCtx, err := ss.ensureAuthed(ctx)
if err != nil {
if status, _ := status.FromError(err); status.Code() != codes.Unauthenticated {
return nil, err
}
return ctx, nil
}
return nextCtx, nil
}

func (ss *simpleServer) ensureAuthed(ctx context.Context) (context.Context, error) {
tokenString, err := tokenFromContext(ctx)
if err != nil {
Expand Down
129 changes: 129 additions & 0 deletions rpc/server_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,135 @@ func TestServerAuthJWTAudienceAndID(t *testing.T) {
}
}

func TestServerPublicMethods(t *testing.T) {
logger := golog.NewTestLogger(t)

t.Run("NoAuthSet", func(t *testing.T) {
// this is an authenticated server - using the default auth service on server
rpcServer, err := NewServer(logger,
WithPublicMethods([]string{
"/proto.rpc.examples.echo.v1.EchoService/Echo",
"/proto.rpc.examples.echo.v1.EchoService/EchoMultiple",
}),
)
defer rpcServer.Stop()
test.That(t, err, test.ShouldBeNil)
es := echoserver.Server{}
err = rpcServer.RegisterServiceServer(
context.Background(),
&pb.EchoService_ServiceDesc,
&es,
pb.RegisterEchoServiceHandlerFromEndpoint,
)
test.That(t, err, test.ShouldBeNil)

listener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)
grpcOpts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
}

errChan := make(chan error)
go func() {
errChan <- rpcServer.Serve(listener)
}()

conn, err := grpc.DialContext(context.Background(), listener.Addr().String(), grpcOpts...)
test.That(t, err, test.ShouldBeNil)
defer func() {
test.That(t, conn.Close(), test.ShouldBeNil)
}()
client := pb.NewEchoServiceClient(conn)
echoResp, err := client.Echo(context.Background(), &pb.EchoRequest{Message: "hello"})
test.That(t, err, test.ShouldBeNil)
test.That(t, echoResp, test.ShouldNotBeNil)
test.That(t, echoResp.Message, test.ShouldEqual, "hello")

// test the stream service
_, err = client.EchoMultiple(context.Background(), &pb.EchoMultipleRequest{Message: "hello"})
test.That(t, err, test.ShouldBeNil)
err = <-errChan
test.That(t, err, test.ShouldBeNil)
})

t.Run("Given an authenticated client, they can still access the public API", func(t *testing.T) {
testPrivKey, err := rsa.GenerateKey(rand.Reader, 512)
test.That(t, err, test.ShouldBeNil)

rpcServer, err := NewServer(logger,
// this is the main echo method
WithPublicMethods([]string{
"/proto.rpc.examples.echo.v1.EchoService/Echo",
"/proto.rpc.examples.echo.v1.EchoService/EchoMultiple",
}),
WithAuthHandler("fake", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
return map[string]string{}, nil
})),
WithAuthRSAPrivateKey(testPrivKey),
)

defer rpcServer.Stop()
test.That(t, err, test.ShouldBeNil)
es := echoserver.Server{}
err = rpcServer.RegisterServiceServer(
context.Background(),
&pb.EchoService_ServiceDesc,
&es,
pb.RegisterEchoServiceHandlerFromEndpoint,
)
test.That(t, err, test.ShouldBeNil)

listener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)
grpcOpts := []grpc.DialOption{
grpc.WithBlock(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
}

errChan := make(chan error)
go func() {
errChan <- rpcServer.Serve(listener)
}()

conn, err := grpc.DialContext(context.Background(), listener.Addr().String(), grpcOpts...)
test.That(t, err, test.ShouldBeNil)
defer func() {
test.That(t, conn.Close(), test.ShouldBeNil)
}()

// setup for auth stuff
authClient := rpcpb.NewAuthServiceClient(conn)
authResp, err := authClient.Authenticate(
context.Background(), &rpcpb.AuthenticateRequest{Entity: "foo", Credentials: &rpcpb.Credentials{
Type: "fake",
Payload: "something",
}})
test.That(t, err, test.ShouldBeNil)
_, err = jwt.Parse(authResp.AccessToken, func(token *jwt.Token) (interface{}, error) {
return &testPrivKey.PublicKey, nil
})
test.That(t, err, test.ShouldBeNil)

md := make(metadata.MD)
bearer := fmt.Sprintf("Bearer %s", authResp.AccessToken)
md.Set("authorization", bearer)
ctx := metadata.NewOutgoingContext(context.Background(), md)

client := pb.NewEchoServiceClient(conn)
echoResp, err := client.Echo(ctx, &pb.EchoRequest{Message: "hello"})
test.That(t, err, test.ShouldBeNil)
test.That(t, echoResp, test.ShouldNotBeNil)
test.That(t, echoResp.Message, test.ShouldEqual, "hello")

// test the stream service
_, err = client.EchoMultiple(context.Background(), &pb.EchoMultipleRequest{Message: "hello"})
test.That(t, err, test.ShouldBeNil)
err = <-errChan
test.That(t, err, test.ShouldBeNil)
})
}

func TestServerAuthKeyFunc(t *testing.T) {
testutils.SkipUnlessInternet(t)
logger := golog.NewTestLogger(t)
Expand Down
11 changes: 11 additions & 0 deletions rpc/server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type serverOptions struct {
// allowUnauthenticatedHealthCheck allows the server to have an unauthenticated healthcheck endpoint
allowUnauthenticatedHealthCheck bool

// publicMethods are api routes that attempt, but do not require, authentication
publicMethods []string

// authRSAPrivateKey is used to sign JWTs for authentication
authRSAPrivateKey *rsa.PrivateKey

Expand Down Expand Up @@ -434,3 +437,11 @@ func WithAllowUnauthenticatedHealthCheck() ServerOption {
return nil
})
}

// WithPublicMethods returns a server option with grpc methods that can bypass auth validation.
func WithPublicMethods(fullMethods []string) ServerOption {
return newFuncServerOption(func(o *serverOptions) error {
o.publicMethods = fullMethods
return nil
})
}

0 comments on commit ad2767f

Please sign in to comment.