Skip to content

Commit

Permalink
feat: include the ability to set context values during auth
Browse files Browse the repository at this point in the history
  • Loading branch information
jeroenrinzema committed Jun 11, 2023
1 parent e7647ad commit 6a47b28
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 31 deletions.
26 changes: 13 additions & 13 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ const (
)

// AuthStrategy represents a authentication strategy used to authenticate a user
type AuthStrategy func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (err error)
type AuthStrategy func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (_ context.Context, err error)

// handleAuth handles the client authentication for the given connection.
// This methods validates the incoming credentials and writes to the client whether
// the provided credentials are correct. When the provided credentials are invalid
// or any unexpected error occures is an error returned and should the connection be closed.
func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) error {
func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer) (context.Context, error) {
srv.logger.Debug("authenticating client connection")

if srv.Auth == nil {
// No authentication strategy configured.
// Announcing to the client that the connection is authenticated
return writeAuthType(writer, authOK)
return ctx, writeAuthType(writer, authOK)
}

return srv.Auth(ctx, writer, reader)
Expand All @@ -45,38 +45,38 @@ func (srv *Server) handleAuth(ctx context.Context, reader *buffer.Reader, writer
// clear text password and validates if the provided username and password (received
// inside the client parameters) are valid. If the provided credentials are invalid
// or any unexpected error occures is an error returned and should the connection be closed.
func ClearTextPassword(validate func(username, password string) (bool, error)) AuthStrategy {
return func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (err error) {
func ClearTextPassword(validate func(ctx context.Context, username, password string) (context.Context, bool, error)) AuthStrategy {
return func(ctx context.Context, writer *buffer.Writer, reader *buffer.Reader) (_ context.Context, err error) {
err = writeAuthType(writer, authClearTextPassword)
if err != nil {
return err
return ctx, err
}

params := ClientParameters(ctx)
t, _, err := reader.ReadTypedMsg()
if err != nil {
return err
return ctx, err
}

if t != types.ClientPassword {
return errors.New("unexpected password message")
return ctx, errors.New("unexpected password message")
}

password, err := reader.GetString()
if err != nil {
return err
return ctx, err
}

valid, err := validate(params[ParamUsername], password)
ctx, valid, err := validate(ctx, params[ParamUsername], password)
if err != nil {
return err
return ctx, err
}

if !valid {
return ErrorCode(writer, pgerror.WithCode(errors.New("invalid username/password"), codes.InvalidPassword))
return ctx, ErrorCode(writer, pgerror.WithCode(errors.New("invalid username/password"), codes.InvalidPassword))
}

return writeAuthType(writer, authOK)
return ctx, writeAuthType(writer, authOK)
}
}

Expand Down
28 changes: 11 additions & 17 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/jeroenrinzema/psql-wire/internal/buffer"
"github.com/jeroenrinzema/psql-wire/internal/types"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

Expand All @@ -21,16 +22,12 @@ func TestDefaultHandleAuth(t *testing.T) {
writer := buffer.NewWriter(zap.NewNop(), sink)

server := &Server{logger: zap.NewNop()}
err := server.handleAuth(ctx, reader, writer)
if err != nil {
t.Fatal(err)
}
ctx, err := server.handleAuth(ctx, reader, writer)
require.NoError(t, err)

result := buffer.NewReader(zap.NewNop(), sink, buffer.DefaultBufferSize)
ty, ln, err := result.ReadTypedMsg()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

if ln == 0 {
t.Error("unexpected length, expected typed message length to be greater then 0")
Expand All @@ -41,9 +38,7 @@ func TestDefaultHandleAuth(t *testing.T) {
}

status, err := result.GetUint32()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

if authType(status) != authOK {
t.Errorf("unexpected auth status %d, expected OK", status)
Expand All @@ -62,12 +57,12 @@ func TestClearTextPassword(t *testing.T) {
incoming.AddNullTerminate()
incoming.End() //nolint:errcheck

validate := func(username, password string) (bool, error) {
validate := func(ctx context.Context, username, password string) (context.Context, bool, error) {
if password != expected {
return false, fmt.Errorf("unexpected password: %s", password)
return ctx, false, fmt.Errorf("unexpected password: %s", password)
}

return true, nil
return ctx, true, nil
}

sink := bytes.NewBuffer([]byte{})
Expand All @@ -77,8 +72,7 @@ func TestClearTextPassword(t *testing.T) {
writer := buffer.NewWriter(zap.NewNop(), sink)

server := &Server{logger: zap.NewNop(), Auth: ClearTextPassword(validate)}
err := server.handleAuth(ctx, reader, writer)
if err != nil {
t.Error("unexpected error:", err)
}
out, err := server.handleAuth(ctx, reader, writer)
require.NoError(t, err)
require.Equal(t, ctx, out)
}
2 changes: 1 addition & 1 deletion wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (srv *Server) serve(ctx context.Context, conn net.Conn) error {
return err
}

err = srv.handleAuth(ctx, reader, writer)
ctx, err = srv.handleAuth(ctx, reader, writer)
if err != nil {
return err
}
Expand Down

0 comments on commit 6a47b28

Please sign in to comment.