Skip to content

Commit

Permalink
Pass over caller identity information through the sansshell proxy
Browse files Browse the repository at this point in the history
For multi party authentication support, we plan on storing state in-memory in the sanshell server so the server needs to know the original caller's identity. Right now that gets lost when the caller goes through the proxy.

This PR fixes that by including a json blob of marshalled rpcauth.PrincipalAuthInput information in the gRPC context sent from the proxy to the server. We use json and not base64-encoded proto because the size is unlikely to be significant and there's no existing common proto with the fields we want. I'm unconditionally sending the information when principal is populated so that we don't expose an excessive number of behavioral knobs. I've tweaked rpcauth.PeerInputFromContext so that it'll preserve information about the peer that gets added by rpcauth hooks.

The API design intentionally tries to avoid exposing easy ways to use the gRPC metadata without checking its validity. We assume that in some cases sansshell server will have both proxied clients and direct clients. Most direct clients should be unable to pass in arbitrary proxied identity information.

The interceptor for getting proxied identity is only implemented for unary RPCs because we don't yet have a use case for adding it to streaming RPCs.

Part of #346
  • Loading branch information
stvnrhodes committed Oct 20, 2023
1 parent d7d6050 commit 10cd2fb
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 0 deletions.
81 changes: 81 additions & 0 deletions auth/opa/proxiedidentity/proxiedidentity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Package proxiedidentity provides a way to pass the identity of an end user
// through the SansShell proxy
package proxiedidentity

import (
"context"
"encoding/json"
"errors"

"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

const reqProxiedIdentityKey = "sansshell-proxied-identity"

// ServerProxiedIdentityUnaryInterceptor adds information about a proxied caller to the RPC context
// if the provided function returns true. Allow functions will typically pull out information on the
// caller's identity from the context with https://godoc.org/google.golang.org/grpc/peer to decide
// if the addition is allowed.
func ServerProxiedIdentityUnaryInterceptor(allow func(context.Context) bool) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
identity, ok := fromMetadataInContext(ctx)
if !ok {
// No need to do anything more if there's no proxied identity
return handler(ctx, req)
}
if !allow(ctx) {
return nil, errors.New("peer not allowed to proxy identities")
}
ctx = newContext(ctx, identity)
return handler(ctx, req)
}
}

type proxiedIdentityKey struct{}

// newContext creates a new context with the identity attached.
func newContext(ctx context.Context, p *rpcauth.PrincipalAuthInput) context.Context {
return context.WithValue(ctx, proxiedIdentityKey{}, p)
}

// FromContext returns the identity in ctx if it exists. It will typically
// only exist if ServerProxiedIdentityUnaryInterceptor was used.
func FromContext(ctx context.Context) (p *rpcauth.PrincipalAuthInput, ok bool) {
p, ok = ctx.Value(proxiedIdentityKey{}).(*rpcauth.PrincipalAuthInput)
return
}

// AppendToMetadataInOutgoingContext includes the identity in the grpc metadata
// used in outgoing calls with the context.
func AppendToMetadataInOutgoingContext(ctx context.Context, p *rpcauth.PrincipalAuthInput) context.Context {
b, err := json.Marshal(p)
if err != nil {
// There shouldn't be any possible value of PrincipalAuthInput that fails to marshal, so let's
// return the original context so that the caller doesn't need to consider failures.
return ctx
}
return metadata.AppendToOutgoingContext(ctx, reqProxiedIdentityKey, string(b))
}

// fromMetadataInContext fetches the identity from the grpc metadata
// embedded within the context if it exists. If using this, ensure
// that the metadata comes from a trusted source.
func fromMetadataInContext(ctx context.Context) (p *rpcauth.PrincipalAuthInput, ok bool) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, false
}
identity := md.Get(reqProxiedIdentityKey)
if len(identity) != 1 {
// No need to do anything more if there's no proxied identity
return nil, false
}

parsed := new(rpcauth.PrincipalAuthInput)
if err := json.Unmarshal([]byte(identity[0]), parsed); err != nil {
return nil, false
}
return parsed, true
}
119 changes: 119 additions & 0 deletions auth/opa/proxiedidentity/proxiedidentity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package proxiedidentity

import (
"context"
"net"
"reflect"
"testing"

"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
healthcheckpb "github.com/Snowflake-Labs/sansshell/services/healthcheck"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/types/known/emptypb"
)

type fakeHealthCheck struct {
callback func(context.Context)
}

func (h *fakeHealthCheck) Ok(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) {
h.callback(ctx)
return &emptypb.Empty{}, nil
}

func TestProxyingIdentityOverRPC(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
desc string
srvInterceptors []grpc.UnaryServerInterceptor
identityProxied bool
rpcError bool
}{
{
desc: "interceptor missing",
identityProxied: false,
},
{
desc: "passed",
srvInterceptors: []grpc.UnaryServerInterceptor{ServerProxiedIdentityUnaryInterceptor(func(context.Context) bool { return true })},
identityProxied: true,
},
{
desc: "interceptor says no",
srvInterceptors: []grpc.UnaryServerInterceptor{ServerProxiedIdentityUnaryInterceptor(func(context.Context) bool { return false })},
rpcError: true,
identityProxied: false,
},
} {
ctx := ctx
t.Run(tc.desc, func(t *testing.T) {

buffer := 1024
lis := bufconn.Listen(buffer)
bufdial := func(context.Context, string) (net.Conn, error) { return lis.Dial() }
srv := grpc.NewServer(grpc.ChainUnaryInterceptor(tc.srvInterceptors...))
healthcheck := &fakeHealthCheck{}
healthcheckpb.RegisterHealthCheckServer(srv, healthcheck)
go func() {
if err := srv.Serve(lis); err != nil {
panic(err)
}
}()

conn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(bufdial), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatal(err)
}
client := healthcheckpb.NewHealthCheckClient(conn)

identity := &rpcauth.PrincipalAuthInput{
ID: "foobar",
Groups: []string{"baz"},
}

ctx = AppendToMetadataInOutgoingContext(ctx, identity)
var gotIdentity *rpcauth.PrincipalAuthInput
var idOk bool
var gotMetadata []string
healthcheck.callback = func(ctx context.Context) {
gotIdentity, idOk = FromContext(ctx)
md, _ := metadata.FromIncomingContext(ctx)
gotMetadata = md.Get(reqProxiedIdentityKey)
}
if _, err := client.Ok(ctx, &emptypb.Empty{}); err != nil {
if tc.rpcError {
return
}
t.Fatal(err)
}
if tc.rpcError {
t.Error("rpc error was missing")
}

if tc.identityProxied {
if !reflect.DeepEqual(gotIdentity, identity) {
t.Errorf("got %+v, want %+v", gotIdentity, identity)
}
if !idOk {
t.Error("FromContext was unexpectedly not ok")
}
} else {
if gotIdentity != nil {
t.Errorf("identity unexpectedly not nil: %+v", gotIdentity)
}
if idOk {
t.Error("FromContext was unexpectedly ok")
}
}
if len(gotMetadata) != 1 {
t.Errorf("expected exactly one metadata val, got %v", gotMetadata)
} else if gotMetadata[0] != `{"id":"foobar","groups":["baz"]}` {
t.Errorf("metadata did not match expectation, got %v", gotMetadata[0])
}
})
}
}
17 changes: 17 additions & 0 deletions auth/opa/rpcauth/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,26 @@ func NewRPCAuthInput(ctx context.Context, method string, req proto.Message) (*RP
return out, nil
}

type peerInfoKey struct{}

func addPeerToContext(ctx context.Context, p *PeerAuthInput) context.Context {
if p == nil {
return ctx
}
return context.WithValue(ctx, peerInfoKey{}, p)
}

// PeerInputFromContext populates peer information from the supplied
// context, if available.
func PeerInputFromContext(ctx context.Context) *PeerAuthInput {
// If this runs after rpcauth hooks, we can return richer data that includes
// information added by the hooks.
cached, ok := ctx.Value(peerInfoKey{}).(*PeerAuthInput)
if ok {
return cached
}

// If it runs before our rpcauth hooks, let's return the data as best we can.
out := &PeerAuthInput{}
p, ok := peer.FromContext(ctx)
if !ok {
Expand Down
29 changes: 29 additions & 0 deletions auth/opa/rpcauth/rpcauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"context"
"fmt"
"strings"
"sync"

"github.com/go-logr/logr"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -171,6 +172,7 @@ func (g *Authorizer) Authorize(ctx context.Context, req interface{}, info *grpc.
if err := g.Eval(ctx, authInput); err != nil {
return nil, err
}
ctx = addPeerToContext(ctx, authInput.Peer)
return handler(ctx, req)
}

Expand All @@ -187,6 +189,7 @@ func (g *Authorizer) AuthorizeClient(ctx context.Context, method string, req, re
if err := g.Eval(ctx, authInput); err != nil {
return err
}
ctx = addPeerToContext(ctx, authInput.Peer)
return invoker(ctx, method, req, reply, cc, opts...)
}

Expand All @@ -209,6 +212,16 @@ type wrappedClientStream struct {
grpc.ClientStream
method string
authz *Authorizer

peerMu sync.Mutex
lastPeerAuthInput *PeerAuthInput
}

func (e *wrappedClientStream) Context() context.Context {
e.peerMu.Lock()
ctx := addPeerToContext(e.ClientStream.Context(), e.lastPeerAuthInput)
e.peerMu.Unlock()
return ctx
}

// see: grpc.ClientStream.SendMsg
Expand All @@ -225,6 +238,9 @@ func (e *wrappedClientStream) SendMsg(req interface{}) error {
if err := e.authz.Eval(ctx, authInput); err != nil {
return err
}
e.peerMu.Lock()
e.lastPeerAuthInput = authInput.Peer
e.peerMu.Unlock()
return e.ClientStream.SendMsg(req)
}

Expand All @@ -243,6 +259,16 @@ type wrappedStream struct {
grpc.ServerStream
info *grpc.StreamServerInfo
authz *Authorizer

peerMu sync.Mutex
lastPeerAuthInput *PeerAuthInput
}

func (e *wrappedStream) Context() context.Context {
e.peerMu.Lock()
ctx := addPeerToContext(e.ServerStream.Context(), e.lastPeerAuthInput)
e.peerMu.Unlock()
return ctx
}

// see: grpc.ServerStream.RecvMsg
Expand All @@ -266,5 +292,8 @@ func (e *wrappedStream) RecvMsg(req interface{}) error {
if err := e.authz.Eval(ctx, authInput); err != nil {
return err
}
e.peerMu.Lock()
e.lastPeerAuthInput = authInput.Peer
e.peerMu.Unlock()
return nil
}
7 changes: 7 additions & 0 deletions proxy/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

"github.com/Snowflake-Labs/sansshell/auth/opa/proxiedidentity"
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth"
pb "github.com/Snowflake-Labs/sansshell/proxy"
)
Expand Down Expand Up @@ -189,6 +190,12 @@ func (s *TargetStream) Send(req proto.Message) error {
func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
group, ctx := errgroup.WithContext(s.ctx)

peer := rpcauth.PeerInputFromContext(ctx)
if peer != nil && peer.Principal != nil {
// Unconditionally add information on the original caller to outgoing RPCs
ctx = proxiedidentity.AppendToMetadataInOutgoingContext(ctx, peer.Principal)
}

group.Go(func() error {
dialCtx, cancel := context.WithCancel(ctx)
var opts []grpc.DialOption
Expand Down

0 comments on commit 10cd2fb

Please sign in to comment.