Skip to content

Commit

Permalink
add separate query for getting user password only and another for get…
Browse files Browse the repository at this point in the history
…ting user by username or email
  • Loading branch information
blesswinsamuel committed Sep 9, 2023
1 parent 1fb0de3 commit cfe037c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 41 deletions.
66 changes: 36 additions & 30 deletions internal/ldapserver/ldapserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,14 @@ func (s *LdapServer) handleBind(w ldap.ResponseWriter, m *ldap.Message) {
switch organizationUnit {
case "people":
uid := dn["uid"][0]
user, err := s.provider.FindByUID(ctx, uid)
userPasswordHashed, err := s.provider.FindUserPasswordByUsername(ctx, uid)
if err != nil {
errorResponse(ctx, w, ldap.NewBindResponse(ldap.LDAPResultInvalidCredentials), err, "unable to find user: %s", uid)
return
}
logger.Info().Interface("user", userPasswordHashed).Msg("found user during bind")
// fmt.Println(password, user["password"])
err = bcrypt.CompareHashAndPassword([]byte(user["password"].(string)), []byte(password))
err = bcrypt.CompareHashAndPassword(userPasswordHashed, []byte(password))
if err != nil {
errorResponse(ctx, w, ldap.NewBindResponse(ldap.LDAPResultInvalidCredentials), err, "invalid password for user: %s", uid)
return
Expand All @@ -147,35 +148,48 @@ func (s *LdapServer) handleBind(w ldap.ResponseWriter, m *ldap.Message) {
}
}

func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) {
r := m.GetSearchRequest()
logger := log.With().Str("method", "handleSearchUsers").
Int("id", m.MessageID().Int()).
Str("base_dn", string(r.BaseObject())).Str("filter", r.FilterString()).Int("scope", r.Scope().Int()).
Logger()
ctx := logger.WithContext(context.Background())
logger.Debug().Msg("search users request")

func parseSearchFilter(filter message.Filter) map[string]string {
condition := map[string]string{}
switch filter := r.Filter().(type) {
switch filter := filter.(type) {
case message.FilterAnd:
for _, f := range filter {
switch f := f.(type) {
case message.FilterEqualityMatch:
condition[string(f.AttributeDesc())] = string(f.AssertionValue())
case message.FilterOr:
for _, f := range f {
switch f := f.(type) {
case message.FilterEqualityMatch:
condition[string(f.AttributeDesc())] = string(f.AssertionValue())
}
}
}
}
}
return condition
}

logger.Info().Interface("condition", condition).Msg("condition")
func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) {
r := m.GetSearchRequest()
logger := log.With().Str("method", "handleSearchUsers").
Int("id", m.MessageID().Int()).
Str("base_dn", string(r.BaseObject())).Str("filter", r.FilterString()).Int("scope", r.Scope().Int()).
Logger()
ctx := logger.WithContext(context.Background())
logger.Debug().Msg("search users request")

condition := parseSearchFilter(r.Filter())

logger.Info().Interface("condition", condition).Msg("search users condition")

uid := condition["uid"]
email := condition["email"]

if uid == "" {
errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), nil, "uid is empty")
if uid == "" && email == "" {
errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), nil, "uid and email is empty")
return
}
user, err := s.provider.FindByUID(ctx, uid)
user, err := s.provider.FindUserByUsernameOrEmail(ctx, uid, email)
if err != nil {
if err.Error() == "user not found" {
logger.Warn().Msg("user not found")
Expand All @@ -185,7 +199,8 @@ func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) {
errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), err, "unable to find user by uid")
return
}
entry := ldap.NewSearchResultEntry(fmt.Sprintf("uid=%s,ou=%s,%s", uid, "people", s.config.BaseDN))
logger.Info().Interface("user", user).Msg("found user during search user")
entry := ldap.NewSearchResultEntry(fmt.Sprintf("uid=%s,ou=%s,%s", user["uid"], "people", s.config.BaseDN))
for k, v := range user {
if k == "password" {
continue
Expand Down Expand Up @@ -216,18 +231,9 @@ func (s *LdapServer) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message)
ctx := logger.WithContext(context.Background())
logger.Debug().Msg("search groups request")

condition := map[string]string{}
switch filter := r.Filter().(type) {
case message.FilterAnd:
for _, f := range filter {
switch f := f.(type) {
case message.FilterEqualityMatch:
condition[string(f.AttributeDesc())] = string(f.AssertionValue())
}
}
}
condition := parseSearchFilter(r.Filter())

logger.Info().Interface("condition", condition).Msg("condition")
logger.Info().Interface("condition", condition).Msg("search groups condition")

memberDN := condition["member"]

Expand All @@ -236,12 +242,12 @@ func (s *LdapServer) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message)
return
}
memberDNParsed := s.parseDN(memberDN)
groups, err := s.provider.FindGroups(ctx, memberDNParsed["uid"][0])
groups, err := s.provider.FindUserGroups(ctx, memberDNParsed["uid"][0])
if err != nil {
errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), err, "unable to find group by uid")
return
}
log.Info().Interface("groups", groups).Msg("found groups")
log.Info().Interface("groups", groups).Msg("found user groups during search groups")
for _, group := range groups {
groupName := group["name"].(string)
entry := ldap.NewSearchResultEntry(fmt.Sprintf("cn=%s,ou=%s,%s", groupName, "groups", s.config.BaseDN))
Expand Down
7 changes: 4 additions & 3 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import "context"

// Provider is used to authenticate users
type Provider interface {
FindByUID(ctx context.Context, uid string) (User, error)
FindGroups(ctx context.Context, uid string) ([]Group, error)
UpdateUserPassword(ctx context.Context, uid string, password string) error
FindUserPasswordByUsername(ctx context.Context, username string) ([]byte, error)
FindUserByUsernameOrEmail(ctx context.Context, username string, email string) (User, error)
FindUserGroups(ctx context.Context, username string) ([]Group, error)
UpdateUserPassword(ctx context.Context, username string, password string) error
}

// User is the authenticated user
Expand Down
43 changes: 35 additions & 8 deletions internal/provider/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ import (
_ "github.com/lib/pq"
)

var (
ErrUserNotFound = fmt.Errorf("user not found")
)

type SQLProviderConfig struct {
DatabaseURL string `long:"database-url" env:"DATABASE_URL"`

SQLGetUserQuery string `long:"sql-get-user-query" env:"SQL_GET_USER_QUERY" default:""`
SQLGetUserGroupsQuery string `long:"sql-get-user-groups-query" env:"SQL_GET_USER_GROUPS_QUERY" default:""`
SQLUpdatePasswordQuery string `long:"sql-update-password-query" env:"SQL_UPDATE_PASSWORD_QUERY" default:""`
SQLGetUserPasswordByUsernameQuery string `long:"sql-get-user-password-by-username-query" env:"SQL_GET_USER_PASSWORD_BY_USERNAME_QUERY" default:""`
SQLGetUserByUsernameOrEmailQuery string `long:"sql-get-user-by-username-or-email-query" env:"SQL_GET_USER_BY_USERNAME_OR_EMAIL_QUERY" default:""`
SQLGetUserGroupsQuery string `long:"sql-get-user-groups-query" env:"SQL_GET_USER_GROUPS_QUERY" default:""`
SQLUpdatePasswordQuery string `long:"sql-update-password-query" env:"SQL_UPDATE_PASSWORD_QUERY" default:""`
}

type SQLProvider struct {
Expand All @@ -29,11 +34,33 @@ func NewSQLProvider(config SQLProviderConfig) (*SQLProvider, error) {
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("unable to ping database: %w", err)
}
if config.SQLGetUserPasswordByUsernameQuery == "" || config.SQLGetUserByUsernameOrEmailQuery == "" || config.SQLGetUserGroupsQuery == "" || config.SQLUpdatePasswordQuery == "" {
return nil, fmt.Errorf("sql queries not provided")
}
return &SQLProvider{db: db, config: config}, nil
}

func (p *SQLProvider) FindByUID(ctx context.Context, uid string) (User, error) {
rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserQuery, map[string]any{"uid": uid})
func (p *SQLProvider) FindUserPasswordByUsername(ctx context.Context, uid string) ([]byte, error) {
rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserPasswordByUsernameQuery, map[string]any{"uid": uid})
// password
if err != nil {
return nil, fmt.Errorf("unable to get user: %w", err)
}
defer rows.Close()
users, err := rowsToMap(rows)
if err != nil {
return nil, fmt.Errorf("unable to get columns: %w", err)
}
if len(users) != 1 {
return nil, ErrUserNotFound
}
user := users[0]
password := user["password"].(string)
return []byte(password), nil
}

func (p *SQLProvider) FindUserByUsernameOrEmail(ctx context.Context, uid string, email string) (User, error) {
rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserByUsernameOrEmailQuery, map[string]any{"uid": uid, "email": email})
// givenname, sn, displayname, mail, uid, password
if err != nil {
return nil, fmt.Errorf("unable to get user: %w", err)
Expand All @@ -43,13 +70,13 @@ func (p *SQLProvider) FindByUID(ctx context.Context, uid string) (User, error) {
if err != nil {
return nil, fmt.Errorf("unable to get columns: %w", err)
}
if len(users) == 0 {
return nil, fmt.Errorf("user not found")
if len(users) != 1 {
return nil, ErrUserNotFound
}
return users[0], nil
}

func (p *SQLProvider) FindGroups(ctx context.Context, uid string) ([]Group, error) {
func (p *SQLProvider) FindUserGroups(ctx context.Context, uid string) ([]Group, error) {
rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserGroupsQuery, map[string]any{"uid": uid})
if err != nil {
return nil, fmt.Errorf("unable to get groups: %w", err)
Expand Down

0 comments on commit cfe037c

Please sign in to comment.