Skip to content

Commit

Permalink
Expose a Registry type in health package, so unit tests can stay isol…
Browse files Browse the repository at this point in the history
…ated from each other

Update docs.

Change health_test.go tests to create their own registries and register
the checks there. The tests now call CheckStatus directly instead of
polling the HTTP handler, which returns results from the default
registry.

Signed-off-by: Aaron Lehmann <[email protected]>
  • Loading branch information
aaronlehmann committed Aug 20, 2015
1 parent 79959f5 commit b9b9caf
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 123 deletions.
2 changes: 1 addition & 1 deletion health/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
//
// The recommended way of registering checks is using a periodic Check.
// PeriodicChecks run on a certain schedule and asynchronously update the
// status of the check. This allows "CheckStatus()" to return without blocking
// status of the check. This allows CheckStatus to return without blocking
// on an expensive check.
//
// A trivial example of a check that runs every 5 seconds and shuts down our
Expand Down
101 changes: 68 additions & 33 deletions health/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,26 @@ import (
"github.com/docker/distribution/registry/api/errcode"
)

var (
mutex sync.RWMutex
registeredChecks = make(map[string]Checker)
)
// A Registry is a collection of checks. Most applications will use the global
// registry defined in DefaultRegistry. However, unit tests may need to create
// separate registries to isolate themselves from other tests.
type Registry struct {
mu sync.RWMutex
registeredChecks map[string]Checker
}

// NewRegistry creates a new registry. This isn't necessary for normal use of
// the package, but may be useful for unit tests so individual tests have their
// own set of checks.
func NewRegistry() *Registry {
return &Registry{
registeredChecks: make(map[string]Checker),
}
}

// DefaultRegistry is the default registry where checks are registered. It is
// the registry used by the HTTP handler.
var DefaultRegistry *Registry

// Checker is the interface for a Health Checker
type Checker interface {
Expand Down Expand Up @@ -144,11 +160,11 @@ func PeriodicThresholdChecker(check Checker, period time.Duration, threshold int
}

// CheckStatus returns a map with all the current health check errors
func CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type
mutex.RLock()
defer mutex.RUnlock()
func (registry *Registry) CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper type
registry.mu.RLock()
defer registry.mu.RUnlock()
statusKeys := make(map[string]string)
for k, v := range registeredChecks {
for k, v := range registry.registeredChecks {
err := v.Check()
if err != nil {
statusKeys[k] = err.Error()
Expand All @@ -158,48 +174,66 @@ func CheckStatus() map[string]string { // TODO(stevvooe) this needs a proper typ
return statusKeys
}

// Register associates the checker with the provided name. We allow
// overwrites to a specific check status.
func Register(name string, check Checker) {
mutex.Lock()
defer mutex.Unlock()
_, ok := registeredChecks[name]
// CheckStatus returns a map with all the current health check errors from the
// default registry.
func CheckStatus() map[string]string {
return DefaultRegistry.CheckStatus()
}

// Register associates the checker with the provided name.
func (registry *Registry) Register(name string, check Checker) {
if registry == nil {
registry = DefaultRegistry
}
registry.mu.Lock()
defer registry.mu.Unlock()
_, ok := registry.registeredChecks[name]
if ok {
panic("Check already exists: " + name)
}
registeredChecks[name] = check
registry.registeredChecks[name] = check
}

// Unregister removes the named checker.
func Unregister(name string) {
mutex.Lock()
defer mutex.Unlock()
delete(registeredChecks, name)
// Register associates the checker with the provided name in the default
// registry.
func Register(name string, check Checker) {
DefaultRegistry.Register(name, check)
}

// UnregisterAll removes all registered checkers.
func UnregisterAll() {
mutex.Lock()
defer mutex.Unlock()
registeredChecks = make(map[string]Checker)
// RegisterFunc allows the convenience of registering a checker directly from
// an arbitrary func() error.
func (registry *Registry) RegisterFunc(name string, check func() error) {
registry.Register(name, CheckFunc(check))
}

// RegisterFunc allows the convenience of registering a checker directly
// from an arbitrary func() error
// RegisterFunc allows the convenience of registering a checker in the default
// registry directly from an arbitrary func() error.
func RegisterFunc(name string, check func() error) {
Register(name, CheckFunc(check))
DefaultRegistry.RegisterFunc(name, check)
}

// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker
// from an arbitrary func() error.
func (registry *Registry) RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) {
registry.Register(name, PeriodicChecker(CheckFunc(check), period))
}

// RegisterPeriodicFunc allows the convenience of registering a PeriodicChecker
// from an arbitrary func() error
// in the default registry from an arbitrary func() error.
func RegisterPeriodicFunc(name string, period time.Duration, check CheckFunc) {
Register(name, PeriodicChecker(CheckFunc(check), period))
DefaultRegistry.RegisterPeriodicFunc(name, period, check)
}

// RegisterPeriodicThresholdFunc allows the convenience of registering a
// PeriodicChecker from an arbitrary func() error.
func (registry *Registry) RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) {
registry.Register(name, PeriodicThresholdChecker(CheckFunc(check), period, threshold))
}

// RegisterPeriodicThresholdFunc allows the convenience of registering a
// PeriodicChecker from an arbitrary func() error
// PeriodicChecker in the default registry from an arbitrary func() error.
func RegisterPeriodicThresholdFunc(name string, period time.Duration, threshold int, check CheckFunc) {
Register(name, PeriodicThresholdChecker(CheckFunc(check), period, threshold))
DefaultRegistry.RegisterPeriodicThresholdFunc(name, period, threshold, check)
}

// StatusHandler returns a JSON blob with all the currently registered Health Checks
Expand Down Expand Up @@ -265,7 +299,8 @@ func statusResponse(w http.ResponseWriter, r *http.Request, status int, checks m
}
}

// Registers global /debug/health api endpoint
// Registers global /debug/health api endpoint, creates default registry
func init() {
DefaultRegistry = NewRegistry()
http.HandleFunc("/debug/health", StatusHandler)
}
2 changes: 1 addition & 1 deletion health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestReturns503IfThereAreErrorChecks(t *testing.T) {
// the web application when things aren't so healthy.
func TestHealthHandler(t *testing.T) {
// clear out existing checks.
registeredChecks = make(map[string]Checker)
DefaultRegistry = NewRegistry()

// protect an http server
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
22 changes: 15 additions & 7 deletions registry/handlers/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,15 @@ func NewApp(ctx context.Context, configuration configuration.Configuration) *App
// process. Because the configuration and app are tightly coupled,
// implementing this properly will require a refactor. This method may panic
// if called twice in the same process.
func (app *App) RegisterHealthChecks() {
func (app *App) RegisterHealthChecks(healthRegistries ...*health.Registry) {
if len(healthRegistries) > 1 {
panic("RegisterHealthChecks called with more than one registry")
}
healthRegistry := health.DefaultRegistry
if len(healthRegistries) == 1 {
healthRegistry = healthRegistries[0]
}

if app.Config.Health.StorageDriver.Enabled {
interval := app.Config.Health.StorageDriver.Interval
if interval == 0 {
Expand All @@ -247,9 +255,9 @@ func (app *App) RegisterHealthChecks() {
}

if app.Config.Health.StorageDriver.Threshold != 0 {
health.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck)
healthRegistry.RegisterPeriodicThresholdFunc("storagedriver_"+app.Config.Storage.Type(), interval, app.Config.Health.StorageDriver.Threshold, storageDriverCheck)
} else {
health.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck)
healthRegistry.RegisterPeriodicFunc("storagedriver_"+app.Config.Storage.Type(), interval, storageDriverCheck)
}
}

Expand All @@ -260,10 +268,10 @@ func (app *App) RegisterHealthChecks() {
}
if fileChecker.Threshold != 0 {
ctxu.GetLogger(app).Infof("configuring file health check path=%s, interval=%d, threshold=%d", fileChecker.File, interval/time.Second, fileChecker.Threshold)
health.Register(fileChecker.File, health.PeriodicThresholdChecker(checks.FileChecker(fileChecker.File), interval, fileChecker.Threshold))
healthRegistry.Register(fileChecker.File, health.PeriodicThresholdChecker(checks.FileChecker(fileChecker.File), interval, fileChecker.Threshold))
} else {
ctxu.GetLogger(app).Infof("configuring file health check path=%s, interval=%d", fileChecker.File, interval/time.Second)
health.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval))
healthRegistry.Register(fileChecker.File, health.PeriodicChecker(checks.FileChecker(fileChecker.File), interval))
}
}

Expand All @@ -274,10 +282,10 @@ func (app *App) RegisterHealthChecks() {
}
if httpChecker.Threshold != 0 {
ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d, threshold=%d", httpChecker.URI, interval/time.Second, httpChecker.Threshold)
health.Register(httpChecker.URI, health.PeriodicThresholdChecker(checks.HTTPChecker(httpChecker.URI), interval, httpChecker.Threshold))
healthRegistry.Register(httpChecker.URI, health.PeriodicThresholdChecker(checks.HTTPChecker(httpChecker.URI), interval, httpChecker.Threshold))
} else {
ctxu.GetLogger(app).Infof("configuring HTTP health check uri=%s, interval=%d", httpChecker.URI, interval/time.Second)
health.Register(httpChecker.URI, health.PeriodicChecker(checks.HTTPChecker(httpChecker.URI), interval))
healthRegistry.Register(httpChecker.URI, health.PeriodicChecker(checks.HTTPChecker(httpChecker.URI), interval))
}
}
}
Expand Down
100 changes: 19 additions & 81 deletions registry/handlers/health_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package handlers

import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand All @@ -15,9 +14,6 @@ import (
)

func TestFileHealthCheck(t *testing.T) {
// In case other tests registered checks before this one
health.UnregisterAll()

interval := time.Second

tmpfile, err := ioutil.TempFile(os.TempDir(), "healthcheck")
Expand All @@ -43,60 +39,29 @@ func TestFileHealthCheck(t *testing.T) {
ctx := context.Background()

app := NewApp(ctx, config)
app.RegisterHealthChecks()

debugServer := httptest.NewServer(nil)
healthRegistry := health.NewRegistry()
app.RegisterHealthChecks(healthRegistry)

// Wait for health check to happen
<-time.After(2 * interval)

resp, err := http.Get(debugServer.URL + "/debug/health")
if err != nil {
t.Fatalf("error performing HTTP GET: %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading HTTP body: %v", err)
}
resp.Body.Close()
var decoded map[string]string
err = json.Unmarshal(body, &decoded)
if err != nil {
t.Fatalf("error unmarshaling json: %v", err)
status := healthRegistry.CheckStatus()
if len(status) != 1 {
t.Fatal("expected 1 item in health check results")
}
if len(decoded) != 1 {
t.Fatal("expected 1 item in returned json")
}
if decoded[tmpfile.Name()] != "file exists" {
if status[tmpfile.Name()] != "file exists" {
t.Fatal(`did not get "file exists" result for health check`)
}

os.Remove(tmpfile.Name())

<-time.After(2 * interval)
resp, err = http.Get(debugServer.URL + "/debug/health")
if err != nil {
t.Fatalf("error performing HTTP GET: %v", err)
}
body, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading HTTP body: %v", err)
}
resp.Body.Close()
var decoded2 map[string]string
err = json.Unmarshal(body, &decoded2)
if err != nil {
t.Fatalf("error unmarshaling json: %v", err)
}
if len(decoded2) != 0 {
t.Fatal("expected 0 items in returned json")
if len(healthRegistry.CheckStatus()) != 0 {
t.Fatal("expected 0 items in health check results")
}
}

func TestHTTPHealthCheck(t *testing.T) {
// In case other tests registered checks before this one
health.UnregisterAll()

interval := time.Second
threshold := 3

Expand Down Expand Up @@ -132,32 +97,18 @@ func TestHTTPHealthCheck(t *testing.T) {
ctx := context.Background()

app := NewApp(ctx, config)
app.RegisterHealthChecks()

debugServer := httptest.NewServer(nil)
healthRegistry := health.NewRegistry()
app.RegisterHealthChecks(healthRegistry)

for i := 0; ; i++ {
<-time.After(interval)

resp, err := http.Get(debugServer.URL + "/debug/health")
if err != nil {
t.Fatalf("error performing HTTP GET: %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading HTTP body: %v", err)
}
resp.Body.Close()
var decoded map[string]string
err = json.Unmarshal(body, &decoded)
if err != nil {
t.Fatalf("error unmarshaling json: %v", err)
}
status := healthRegistry.CheckStatus()

if i < threshold-1 {
// definitely shouldn't have hit the threshold yet
if len(decoded) != 0 {
t.Fatal("expected 1 items in returned json")
if len(status) != 0 {
t.Fatal("expected 1 item in health check results")
}
continue
}
Expand All @@ -166,10 +117,10 @@ func TestHTTPHealthCheck(t *testing.T) {
continue
}

if len(decoded) != 1 {
t.Fatal("expected 1 item in returned json")
if len(status) != 1 {
t.Fatal("expected 1 item in health check results")
}
if decoded[checkedServer.URL] != "downstream service returned unexpected status: 500" {
if status[checkedServer.URL] != "downstream service returned unexpected status: 500" {
t.Fatal("did not get expected result for health check")
}

Expand All @@ -180,21 +131,8 @@ func TestHTTPHealthCheck(t *testing.T) {
close(stopFailing)

<-time.After(2 * interval)
resp, err := http.Get(debugServer.URL + "/debug/health")
if err != nil {
t.Fatalf("error performing HTTP GET: %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("error reading HTTP body: %v", err)
}
resp.Body.Close()
var decoded map[string]string
err = json.Unmarshal(body, &decoded)
if err != nil {
t.Fatalf("error unmarshaling json: %v", err)
}
if len(decoded) != 0 {
t.Fatal("expected 0 items in returned json")

if len(healthRegistry.CheckStatus()) != 0 {
t.Fatal("expected 0 items in health check results")
}
}

0 comments on commit b9b9caf

Please sign in to comment.