From cfe037cb44c87e1acaa6af546abd5671feeb827a Mon Sep 17 00:00:00 2001 From: Blesswin Samuel Date: Sat, 9 Sep 2023 14:11:29 +0530 Subject: [PATCH] add separate query for getting user password only and another for getting user by username or email --- internal/ldapserver/ldapserver.go | 66 +++++++++++++++++-------------- internal/provider/provider.go | 7 ++-- internal/provider/sql.go | 43 ++++++++++++++++---- 3 files changed, 75 insertions(+), 41 deletions(-) diff --git a/internal/ldapserver/ldapserver.go b/internal/ldapserver/ldapserver.go index 19fb5c8..0b50adf 100644 --- a/internal/ldapserver/ldapserver.go +++ b/internal/ldapserver/ldapserver.go @@ -120,13 +120,14 @@ func (s *LdapServer) handleBind(w ldap.ResponseWriter, m *ldap.Message) { switch organizationUnit { case "people": uid := dn["uid"][0] - user, err := s.provider.FindByUID(ctx, uid) + userPasswordHashed, err := s.provider.FindUserPasswordByUsername(ctx, uid) if err != nil { errorResponse(ctx, w, ldap.NewBindResponse(ldap.LDAPResultInvalidCredentials), err, "unable to find user: %s", uid) return } + logger.Info().Interface("user", userPasswordHashed).Msg("found user during bind") // fmt.Println(password, user["password"]) - err = bcrypt.CompareHashAndPassword([]byte(user["password"].(string)), []byte(password)) + err = bcrypt.CompareHashAndPassword(userPasswordHashed, []byte(password)) if err != nil { errorResponse(ctx, w, ldap.NewBindResponse(ldap.LDAPResultInvalidCredentials), err, "invalid password for user: %s", uid) return @@ -147,35 +148,48 @@ func (s *LdapServer) handleBind(w ldap.ResponseWriter, m *ldap.Message) { } } -func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) { - r := m.GetSearchRequest() - logger := log.With().Str("method", "handleSearchUsers"). - Int("id", m.MessageID().Int()). - Str("base_dn", string(r.BaseObject())).Str("filter", r.FilterString()).Int("scope", r.Scope().Int()). - Logger() - ctx := logger.WithContext(context.Background()) - logger.Debug().Msg("search users request") - +func parseSearchFilter(filter message.Filter) map[string]string { condition := map[string]string{} - switch filter := r.Filter().(type) { + switch filter := filter.(type) { case message.FilterAnd: for _, f := range filter { switch f := f.(type) { case message.FilterEqualityMatch: condition[string(f.AttributeDesc())] = string(f.AssertionValue()) + case message.FilterOr: + for _, f := range f { + switch f := f.(type) { + case message.FilterEqualityMatch: + condition[string(f.AttributeDesc())] = string(f.AssertionValue()) + } + } } } } + return condition +} - logger.Info().Interface("condition", condition).Msg("condition") +func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) { + r := m.GetSearchRequest() + logger := log.With().Str("method", "handleSearchUsers"). + Int("id", m.MessageID().Int()). + Str("base_dn", string(r.BaseObject())).Str("filter", r.FilterString()).Int("scope", r.Scope().Int()). + Logger() + ctx := logger.WithContext(context.Background()) + logger.Debug().Msg("search users request") + + condition := parseSearchFilter(r.Filter()) + + logger.Info().Interface("condition", condition).Msg("search users condition") uid := condition["uid"] + email := condition["email"] - if uid == "" { - errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), nil, "uid is empty") + if uid == "" && email == "" { + errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), nil, "uid and email is empty") return } - user, err := s.provider.FindByUID(ctx, uid) + user, err := s.provider.FindUserByUsernameOrEmail(ctx, uid, email) if err != nil { if err.Error() == "user not found" { logger.Warn().Msg("user not found") @@ -185,7 +199,8 @@ func (s *LdapServer) handleSearchUsers(w ldap.ResponseWriter, m *ldap.Message) { errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), err, "unable to find user by uid") return } - entry := ldap.NewSearchResultEntry(fmt.Sprintf("uid=%s,ou=%s,%s", uid, "people", s.config.BaseDN)) + logger.Info().Interface("user", user).Msg("found user during search user") + entry := ldap.NewSearchResultEntry(fmt.Sprintf("uid=%s,ou=%s,%s", user["uid"], "people", s.config.BaseDN)) for k, v := range user { if k == "password" { continue @@ -216,18 +231,9 @@ func (s *LdapServer) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) ctx := logger.WithContext(context.Background()) logger.Debug().Msg("search groups request") - condition := map[string]string{} - switch filter := r.Filter().(type) { - case message.FilterAnd: - for _, f := range filter { - switch f := f.(type) { - case message.FilterEqualityMatch: - condition[string(f.AttributeDesc())] = string(f.AssertionValue()) - } - } - } + condition := parseSearchFilter(r.Filter()) - logger.Info().Interface("condition", condition).Msg("condition") + logger.Info().Interface("condition", condition).Msg("search groups condition") memberDN := condition["member"] @@ -236,12 +242,12 @@ func (s *LdapServer) handleSearchGroups(w ldap.ResponseWriter, m *ldap.Message) return } memberDNParsed := s.parseDN(memberDN) - groups, err := s.provider.FindGroups(ctx, memberDNParsed["uid"][0]) + groups, err := s.provider.FindUserGroups(ctx, memberDNParsed["uid"][0]) if err != nil { errorResponse(ctx, w, ldap.NewSearchResultDoneResponse(ldap.LDAPResultNoSuchObject), err, "unable to find group by uid") return } - log.Info().Interface("groups", groups).Msg("found groups") + log.Info().Interface("groups", groups).Msg("found user groups during search groups") for _, group := range groups { groupName := group["name"].(string) entry := ldap.NewSearchResultEntry(fmt.Sprintf("cn=%s,ou=%s,%s", groupName, "groups", s.config.BaseDN)) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 368c077..2b4a24a 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -4,9 +4,10 @@ import "context" // Provider is used to authenticate users type Provider interface { - FindByUID(ctx context.Context, uid string) (User, error) - FindGroups(ctx context.Context, uid string) ([]Group, error) - UpdateUserPassword(ctx context.Context, uid string, password string) error + FindUserPasswordByUsername(ctx context.Context, username string) ([]byte, error) + FindUserByUsernameOrEmail(ctx context.Context, username string, email string) (User, error) + FindUserGroups(ctx context.Context, username string) ([]Group, error) + UpdateUserPassword(ctx context.Context, username string, password string) error } // User is the authenticated user diff --git a/internal/provider/sql.go b/internal/provider/sql.go index 15d9a9b..8d80a5b 100644 --- a/internal/provider/sql.go +++ b/internal/provider/sql.go @@ -8,12 +8,17 @@ import ( _ "github.com/lib/pq" ) +var ( + ErrUserNotFound = fmt.Errorf("user not found") +) + type SQLProviderConfig struct { DatabaseURL string `long:"database-url" env:"DATABASE_URL"` - SQLGetUserQuery string `long:"sql-get-user-query" env:"SQL_GET_USER_QUERY" default:""` - SQLGetUserGroupsQuery string `long:"sql-get-user-groups-query" env:"SQL_GET_USER_GROUPS_QUERY" default:""` - SQLUpdatePasswordQuery string `long:"sql-update-password-query" env:"SQL_UPDATE_PASSWORD_QUERY" default:""` + SQLGetUserPasswordByUsernameQuery string `long:"sql-get-user-password-by-username-query" env:"SQL_GET_USER_PASSWORD_BY_USERNAME_QUERY" default:""` + SQLGetUserByUsernameOrEmailQuery string `long:"sql-get-user-by-username-or-email-query" env:"SQL_GET_USER_BY_USERNAME_OR_EMAIL_QUERY" default:""` + SQLGetUserGroupsQuery string `long:"sql-get-user-groups-query" env:"SQL_GET_USER_GROUPS_QUERY" default:""` + SQLUpdatePasswordQuery string `long:"sql-update-password-query" env:"SQL_UPDATE_PASSWORD_QUERY" default:""` } type SQLProvider struct { @@ -29,11 +34,33 @@ func NewSQLProvider(config SQLProviderConfig) (*SQLProvider, error) { if err := db.Ping(); err != nil { return nil, fmt.Errorf("unable to ping database: %w", err) } + if config.SQLGetUserPasswordByUsernameQuery == "" || config.SQLGetUserByUsernameOrEmailQuery == "" || config.SQLGetUserGroupsQuery == "" || config.SQLUpdatePasswordQuery == "" { + return nil, fmt.Errorf("sql queries not provided") + } return &SQLProvider{db: db, config: config}, nil } -func (p *SQLProvider) FindByUID(ctx context.Context, uid string) (User, error) { - rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserQuery, map[string]any{"uid": uid}) +func (p *SQLProvider) FindUserPasswordByUsername(ctx context.Context, uid string) ([]byte, error) { + rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserPasswordByUsernameQuery, map[string]any{"uid": uid}) + // password + if err != nil { + return nil, fmt.Errorf("unable to get user: %w", err) + } + defer rows.Close() + users, err := rowsToMap(rows) + if err != nil { + return nil, fmt.Errorf("unable to get columns: %w", err) + } + if len(users) != 1 { + return nil, ErrUserNotFound + } + user := users[0] + password := user["password"].(string) + return []byte(password), nil +} + +func (p *SQLProvider) FindUserByUsernameOrEmail(ctx context.Context, uid string, email string) (User, error) { + rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserByUsernameOrEmailQuery, map[string]any{"uid": uid, "email": email}) // givenname, sn, displayname, mail, uid, password if err != nil { return nil, fmt.Errorf("unable to get user: %w", err) @@ -43,13 +70,13 @@ func (p *SQLProvider) FindByUID(ctx context.Context, uid string) (User, error) { if err != nil { return nil, fmt.Errorf("unable to get columns: %w", err) } - if len(users) == 0 { - return nil, fmt.Errorf("user not found") + if len(users) != 1 { + return nil, ErrUserNotFound } return users[0], nil } -func (p *SQLProvider) FindGroups(ctx context.Context, uid string) ([]Group, error) { +func (p *SQLProvider) FindUserGroups(ctx context.Context, uid string) ([]Group, error) { rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserGroupsQuery, map[string]any{"uid": uid}) if err != nil { return nil, fmt.Errorf("unable to get groups: %w", err)