diff --git a/cmd/ldap-sql-adapter/main.go b/cmd/ldap-sql-adapter/main.go index 5862b0c..253627f 100644 --- a/cmd/ldap-sql-adapter/main.go +++ b/cmd/ldap-sql-adapter/main.go @@ -43,7 +43,7 @@ func main() { defer ldapserver.Stop() // Build handler - srv := server.NewServer(config, log) + srv := server.NewServer(config, log, provider) // Start go srv.Start() diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 5374c24..19a5218 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -11,6 +11,7 @@ var ( // Provider is used to authenticate users type Provider interface { + Ping(ctx context.Context) error FindUserPasswordByUsername(ctx context.Context, username string) ([]byte, error) FindUserByUsernameOrEmail(ctx context.Context, username string, email string) (User, error) FindUserGroups(ctx context.Context, username string) ([]Group, error) diff --git a/internal/provider/sql.go b/internal/provider/sql.go index 3e16ca3..20a70d7 100644 --- a/internal/provider/sql.go +++ b/internal/provider/sql.go @@ -57,6 +57,11 @@ func NewSQLProvider(config SQLProviderConfig) (*SQLProvider, error) { return &SQLProvider{db: db, config: config}, nil } +func (p *SQLProvider) Ping(ctx context.Context) (err error) { + defer logMetric("Ping")(err) + return p.db.PingContext(ctx) +} + func (p *SQLProvider) FindUserPasswordByUsername(ctx context.Context, uid string) (passwordBytes []byte, err error) { defer logMetric("FindUserPasswordByUsername")(err) rows, err := p.db.NamedQueryContext(ctx, p.config.SQLGetUserPasswordByUsernameQuery, map[string]any{"uid": uid}) diff --git a/internal/server/server.go b/internal/server/server.go index fd379db..c108207 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,21 +13,24 @@ import ( "github.com/rs/zerolog/log" "github.com/blesswinsamuel/ldap-sql-proxy/internal/config" + "github.com/blesswinsamuel/ldap-sql-proxy/internal/provider" ) // Server contains router and handler methods type Server struct { - router *mux.Router - logger zerolog.Logger - config *config.Config + router *mux.Router + provider provider.Provider + logger zerolog.Logger + config *config.Config } // NewServer creates a new server object and builds router -func NewServer(cfg *config.Config, logger zerolog.Logger) *Server { +func NewServer(cfg *config.Config, logger zerolog.Logger, provider provider.Provider) *Server { s := &Server{ - router: mux.NewRouter(), - logger: logger, - config: cfg, + router: mux.NewRouter(), + logger: logger, + provider: provider, + config: cfg, } s.buildRoutes() @@ -69,6 +72,14 @@ func (s *Server) buildRoutes() { s.router.Use(s.loggerMiddleware) s.router.Handle("/metrics", promhttp.Handler()) + s.router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + err := s.provider.Ping(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + }) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {