-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pass over caller identity information through the sansshell proxy
For multi party authentication support, we plan on storing state in-memory in the sanshell server so the server needs to know the original caller's identity. Right now that gets lost when the caller goes through the proxy. This PR fixes that by including a json blob of marshalled rpcauth.PrincipalAuthInput information in the gRPC context sent from the proxy to the server. We use json and not base64-encoded proto because the size is unlikely to be significant and there's no existing common proto with the fields we want. I'm unconditionally sending the information when principal is populated so that we don't expose an excessive number of behavioral knobs. I've tweaked rpcauth.PeerInputFromContext so that it'll preserve information about the peer that gets added by rpcauth hooks. The API design intentionally tries to avoid exposing easy ways to use the gRPC metadata without checking its validity. We assume that in some cases sansshell server will have both proxied clients and direct clients. Most direct clients should be unable to pass in arbitrary proxied identity information. The interceptor for getting proxied identity is only implemented for unary RPCs because we don't yet have a use case for adding it to streaming RPCs. Part of #346
- Loading branch information
1 parent
d7d6050
commit 10cd2fb
Showing
5 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
// Package proxiedidentity provides a way to pass the identity of an end user | ||
// through the SansShell proxy | ||
package proxiedidentity | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
|
||
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/metadata" | ||
) | ||
|
||
const reqProxiedIdentityKey = "sansshell-proxied-identity" | ||
|
||
// ServerProxiedIdentityUnaryInterceptor adds information about a proxied caller to the RPC context | ||
// if the provided function returns true. Allow functions will typically pull out information on the | ||
// caller's identity from the context with https://godoc.org/google.golang.org/grpc/peer to decide | ||
// if the addition is allowed. | ||
func ServerProxiedIdentityUnaryInterceptor(allow func(context.Context) bool) grpc.UnaryServerInterceptor { | ||
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { | ||
identity, ok := fromMetadataInContext(ctx) | ||
if !ok { | ||
// No need to do anything more if there's no proxied identity | ||
return handler(ctx, req) | ||
} | ||
if !allow(ctx) { | ||
return nil, errors.New("peer not allowed to proxy identities") | ||
} | ||
ctx = newContext(ctx, identity) | ||
return handler(ctx, req) | ||
} | ||
} | ||
|
||
type proxiedIdentityKey struct{} | ||
|
||
// newContext creates a new context with the identity attached. | ||
func newContext(ctx context.Context, p *rpcauth.PrincipalAuthInput) context.Context { | ||
return context.WithValue(ctx, proxiedIdentityKey{}, p) | ||
} | ||
|
||
// FromContext returns the identity in ctx if it exists. It will typically | ||
// only exist if ServerProxiedIdentityUnaryInterceptor was used. | ||
func FromContext(ctx context.Context) (p *rpcauth.PrincipalAuthInput, ok bool) { | ||
p, ok = ctx.Value(proxiedIdentityKey{}).(*rpcauth.PrincipalAuthInput) | ||
return | ||
} | ||
|
||
// AppendToMetadataInOutgoingContext includes the identity in the grpc metadata | ||
// used in outgoing calls with the context. | ||
func AppendToMetadataInOutgoingContext(ctx context.Context, p *rpcauth.PrincipalAuthInput) context.Context { | ||
b, err := json.Marshal(p) | ||
if err != nil { | ||
// There shouldn't be any possible value of PrincipalAuthInput that fails to marshal, so let's | ||
// return the original context so that the caller doesn't need to consider failures. | ||
return ctx | ||
} | ||
return metadata.AppendToOutgoingContext(ctx, reqProxiedIdentityKey, string(b)) | ||
} | ||
|
||
// fromMetadataInContext fetches the identity from the grpc metadata | ||
// embedded within the context if it exists. If using this, ensure | ||
// that the metadata comes from a trusted source. | ||
func fromMetadataInContext(ctx context.Context) (p *rpcauth.PrincipalAuthInput, ok bool) { | ||
md, ok := metadata.FromIncomingContext(ctx) | ||
if !ok { | ||
return nil, false | ||
} | ||
identity := md.Get(reqProxiedIdentityKey) | ||
if len(identity) != 1 { | ||
// No need to do anything more if there's no proxied identity | ||
return nil, false | ||
} | ||
|
||
parsed := new(rpcauth.PrincipalAuthInput) | ||
if err := json.Unmarshal([]byte(identity[0]), parsed); err != nil { | ||
return nil, false | ||
} | ||
return parsed, true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package proxiedidentity | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/Snowflake-Labs/sansshell/auth/opa/rpcauth" | ||
healthcheckpb "github.com/Snowflake-Labs/sansshell/services/healthcheck" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/credentials/insecure" | ||
"google.golang.org/grpc/metadata" | ||
"google.golang.org/grpc/test/bufconn" | ||
"google.golang.org/protobuf/types/known/emptypb" | ||
) | ||
|
||
type fakeHealthCheck struct { | ||
callback func(context.Context) | ||
} | ||
|
||
func (h *fakeHealthCheck) Ok(ctx context.Context, req *emptypb.Empty) (*emptypb.Empty, error) { | ||
h.callback(ctx) | ||
return &emptypb.Empty{}, nil | ||
} | ||
|
||
func TestProxyingIdentityOverRPC(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
for _, tc := range []struct { | ||
desc string | ||
srvInterceptors []grpc.UnaryServerInterceptor | ||
identityProxied bool | ||
rpcError bool | ||
}{ | ||
{ | ||
desc: "interceptor missing", | ||
identityProxied: false, | ||
}, | ||
{ | ||
desc: "passed", | ||
srvInterceptors: []grpc.UnaryServerInterceptor{ServerProxiedIdentityUnaryInterceptor(func(context.Context) bool { return true })}, | ||
identityProxied: true, | ||
}, | ||
{ | ||
desc: "interceptor says no", | ||
srvInterceptors: []grpc.UnaryServerInterceptor{ServerProxiedIdentityUnaryInterceptor(func(context.Context) bool { return false })}, | ||
rpcError: true, | ||
identityProxied: false, | ||
}, | ||
} { | ||
ctx := ctx | ||
t.Run(tc.desc, func(t *testing.T) { | ||
|
||
buffer := 1024 | ||
lis := bufconn.Listen(buffer) | ||
bufdial := func(context.Context, string) (net.Conn, error) { return lis.Dial() } | ||
srv := grpc.NewServer(grpc.ChainUnaryInterceptor(tc.srvInterceptors...)) | ||
healthcheck := &fakeHealthCheck{} | ||
healthcheckpb.RegisterHealthCheckServer(srv, healthcheck) | ||
go func() { | ||
if err := srv.Serve(lis); err != nil { | ||
panic(err) | ||
} | ||
}() | ||
|
||
conn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(bufdial), grpc.WithTransportCredentials(insecure.NewCredentials())) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
client := healthcheckpb.NewHealthCheckClient(conn) | ||
|
||
identity := &rpcauth.PrincipalAuthInput{ | ||
ID: "foobar", | ||
Groups: []string{"baz"}, | ||
} | ||
|
||
ctx = AppendToMetadataInOutgoingContext(ctx, identity) | ||
var gotIdentity *rpcauth.PrincipalAuthInput | ||
var idOk bool | ||
var gotMetadata []string | ||
healthcheck.callback = func(ctx context.Context) { | ||
gotIdentity, idOk = FromContext(ctx) | ||
md, _ := metadata.FromIncomingContext(ctx) | ||
gotMetadata = md.Get(reqProxiedIdentityKey) | ||
} | ||
if _, err := client.Ok(ctx, &emptypb.Empty{}); err != nil { | ||
if tc.rpcError { | ||
return | ||
} | ||
t.Fatal(err) | ||
} | ||
if tc.rpcError { | ||
t.Error("rpc error was missing") | ||
} | ||
|
||
if tc.identityProxied { | ||
if !reflect.DeepEqual(gotIdentity, identity) { | ||
t.Errorf("got %+v, want %+v", gotIdentity, identity) | ||
} | ||
if !idOk { | ||
t.Error("FromContext was unexpectedly not ok") | ||
} | ||
} else { | ||
if gotIdentity != nil { | ||
t.Errorf("identity unexpectedly not nil: %+v", gotIdentity) | ||
} | ||
if idOk { | ||
t.Error("FromContext was unexpectedly ok") | ||
} | ||
} | ||
if len(gotMetadata) != 1 { | ||
t.Errorf("expected exactly one metadata val, got %v", gotMetadata) | ||
} else if gotMetadata[0] != `{"id":"foobar","groups":["baz"]}` { | ||
t.Errorf("metadata did not match expectation, got %v", gotMetadata[0]) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters