From 6a47b280274e1a6d39f2fa5e14d326d285f32e81 Mon Sep 17 00:00:00 2001 From: Jeroen Rinzema Date: Sun, 11 Jun 2023 11:08:43 +0200 Subject: [PATCH] feat: include the ability to set context values during auth --- auth.go | 26 +++++++++++++------------- auth_test.go | 28 +++++++++++----------------- wire.go | 2 +- 3 files changed, 25 insertions(+), 31 deletions(-) diff --git a/auth.go b/auth.go index 315d917..f1cf5b8 100644 --- a/auth.go +++ b/auth.go @@ -23,19 +23,19 @@ const ( ) // AuthStrategy represents a authentication strategy used to authenticate a user -type AuthStrategy func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (err error) +type AuthStrategy func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (_ context.Context, err error) // handleAuth handles the client authentication for the given connection. // This methods validates the incoming credentials and writes to the client whether // the provided credentials are correct. When the provided credentials are invalid // or any unexpected error occures is an error returned and should the connection be closed. -func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error { +func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) (context.Context, error) { srv.logger.Debug("authenticating client connection") if srv.Auth == nil { // No authentication strategy configured. // Announcing to the client that the connection is authenticated - return writeAuthType(writer, authOK) + return ctx, writeAuthType(writer, authOK) } return srv.Auth(ctx, writer, reader) @@ -45,38 +45,38 @@ func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer // clear text password and validates if the provided username and password (received // inside the client parameters) are valid. If the provided credentials are invalid // or any unexpected error occures is an error returned and should the connection be closed. -func ClearTextPassword(validate func(username, password string) (bool, error)) AuthStrategy { - return func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (err error) { +func ClearTextPassword(validate func(ctx context.Context, username, password string) (context.Context, bool, error)) AuthStrategy { + return func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (_ context.Context, err error) { err = writeAuthType(writer, authClearTextPassword) if err != nil { - return err + return ctx, err } params := ClientParameters(ctx) t, _, err := reader.ReadTypedMsg() if err != nil { - return err + return ctx, err } if t != types.ClientPassword { - return errors.New("unexpected password message") + return ctx, errors.New("unexpected password message") } password, err := reader.GetString() if err != nil { - return err + return ctx, err } - valid, err := validate(params[ParamUsername], password) + ctx, valid, err := validate(ctx, params[ParamUsername], password) if err != nil { - return err + return ctx, err } if !valid { - return ErrorCode(writer, pgerror.WithCode(errors.New("invalid username/password"), codes.InvalidPassword)) + return ctx, ErrorCode(writer, pgerror.WithCode(errors.New("invalid username/password"), codes.InvalidPassword)) } - return writeAuthType(writer, authOK) + return ctx, writeAuthType(writer, authOK) } } diff --git a/auth_test.go b/auth_test.go index 3cc487d..3830434 100644 --- a/auth_test.go +++ b/auth_test.go @@ -9,6 +9,7 @@ import ( "github.com/jeroenrinzema/psql-wire/internal/buffer" "github.com/jeroenrinzema/psql-wire/internal/types" + "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -21,16 +22,12 @@ func TestDefaultHandleAuth(t *testing.T) { writer := buffer.NewWriter(zap.NewNop(), sink) server := &Server{logger: zap.NewNop()} - err := server.handleAuth(ctx, reader, writer) - if err != nil { - t.Fatal(err) - } + ctx, err := server.handleAuth(ctx, reader, writer) + require.NoError(t, err) result := buffer.NewReader(zap.NewNop(), sink, buffer.DefaultBufferSize) ty, ln, err := result.ReadTypedMsg() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if ln == 0 { t.Error("unexpected length, expected typed message length to be greater then 0") @@ -41,9 +38,7 @@ func TestDefaultHandleAuth(t *testing.T) { } status, err := result.GetUint32() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if authType(status) != authOK { t.Errorf("unexpected auth status %d, expected OK", status) @@ -62,12 +57,12 @@ func TestClearTextPassword(t *testing.T) { incoming.AddNullTerminate() incoming.End() //nolint:errcheck - validate := func(username, password string) (bool, error) { + validate := func(ctx context.Context, username, password string) (context.Context, bool, error) { if password != expected { - return false, fmt.Errorf("unexpected password: %s", password) + return ctx, false, fmt.Errorf("unexpected password: %s", password) } - return true, nil + return ctx, true, nil } sink := bytes.NewBuffer([]byte{}) @@ -77,8 +72,7 @@ func TestClearTextPassword(t *testing.T) { writer := buffer.NewWriter(zap.NewNop(), sink) server := &Server{logger: zap.NewNop(), Auth: ClearTextPassword(validate)} - err := server.handleAuth(ctx, reader, writer) - if err != nil { - t.Error("unexpected error:", err) - } + out, err := server.handleAuth(ctx, reader, writer) + require.NoError(t, err) + require.Equal(t, ctx, out) } diff --git a/wire.go b/wire.go index 678797e..784a8b6 100644 --- a/wire.go +++ b/wire.go @@ -145,7 +145,7 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error { return err } - err = srv.handleAuth(ctx, reader, writer) + ctx, err = srv.handleAuth(ctx, reader, writer) if err != nil { return err }