diff --git a/client.go b/client.go
index 8b0d6fc..d533b48 100644
--- a/client.go
+++ b/client.go
@@ -148,6 +148,14 @@ type Marathon interface {
 	Leader() (string, error)
 	// cause the current leader to abdicate
 	AbdicateLeader() (string, error)
+
+	// Extra APIs not mapping to any Marathon REST endpoint.
+
+	// Stop terminates any library-local processes. For now, this covers
+	// notifying all running health check routines to terminate.
+	// This method is thread-safe and returns once all processes have
+	// stopped.
+	Stop()
 }
 
 var (
@@ -270,6 +278,10 @@ func (r *marathonClient) Ping() (bool, error) {
 	return true, nil
 }
 
+func (r *marathonClient) Stop() {
+	r.hosts.Stop()
+}
+
 func (r *marathonClient) apiGet(path string, post, result interface{}) error {
 	return r.apiCall("GET", path, post, result)
 }
diff --git a/client_test.go b/client_test.go
index 51de622..3608a6f 100644
--- a/client_test.go
+++ b/client_test.go
@@ -17,12 +17,15 @@ limitations under the License.
 package marathon
 
 import (
+	"sync/atomic"
 	"testing"
 
 	"net/http"
+	"net/http/httptest"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestNewClient(t *testing.T) {
@@ -272,3 +275,41 @@ func TestAPIRequestDCOS(t *testing.T) {
 		endpoint.Close()
 	}
 }
+
+func TestStop(t *testing.T) {
+	var reqCount uint32
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		atomic.AddUint32(&reqCount, 1)
+		http.Error(w, "I'm down", 503)
+	}))
+	defer ts.Close()
+
+	client, err := NewClient(Config{URL: ts.URL})
+	require.NoError(t, err)
+	client.(*marathonClient).hosts.healthCheckInterval = 50 * time.Millisecond
+
+	_, err = client.Ping()
+	require.Equal(t, ErrMarathonDown, err)
+
+	// Expect some health checks to fail.
+	time.Sleep(150 * time.Millisecond)
+	count := int(atomic.LoadUint32(&reqCount))
+	require.True(t, count > 0, "expected non-zero request count")
+
+	// Stop all health check goroutines.
+	// Should be okay to call the method multiple times.
+	client.Stop()
+	client.Stop()
+
+	// Wait for all health checks to terminate.
+	time.Sleep(100 * time.Millisecond)
+
+	// Reset request counter.
+	atomic.StoreUint32(&reqCount, 0)
+
+	// Wait another small period, not expecting any further health checks to
+	// fire.
+	time.Sleep(100 * time.Millisecond)
+	count = int(atomic.LoadUint32(&reqCount))
+	assert.Equal(t, 0, count, "expected zero request count")
+}
diff --git a/cluster.go b/cluster.go
index 992b4e0..5839b0c 100644
--- a/cluster.go
+++ b/cluster.go
@@ -42,6 +42,14 @@ type cluster struct {
 	// healthCheckInterval is the interval by which we probe down nodes for
 	// availability again.
 	healthCheckInterval time.Duration
+	// done is a channel signaling to all pending health-checking routines
+	// that it's time to shut down.
+	done chan struct{}
+	// isDone is used to guarantee thread-safety when calling Stop().
+	isDone bool
+	// healthCheckWg is a sync.WaitGroup used to wait for all pending
+	// health-check routines to terminate.
+	healthCheckWg sync.WaitGroup
 }
 
 // member represents an individual endpoint
@@ -100,9 +108,23 @@ func newCluster(client *httpClient, marathonURL string, isDCOS bool) (*cluster,
 		client:              client,
 		members:             members,
 		healthCheckInterval: 5 * time.Second,
+		done:                make(chan struct{}),
 	}, nil
 }
 
+// Stop gracefully terminates the cluster. It returns once all health-checking
+// goroutines have finished.
+func (c *cluster) Stop() {
+	c.Lock()
+	defer c.Unlock()
+	if c.isDone {
+		return
+	}
+	c.isDone = true
+	close(c.done)
+	c.healthCheckWg.Wait()
+}
+
 // retrieve the current member, i.e. the current endpoint in use
 func (c *cluster) getMember() (string, error) {
 	c.RLock()
@@ -125,7 +147,11 @@ func (c *cluster) markDown(endpoint string) {
 		// nodes status ensures the multiple calls don't create multiple checks
 		if n.status == memberStatusUp && n.endpoint == endpoint {
 			n.status = memberStatusDown
-			go c.healthCheckNode(n)
+			c.healthCheckWg.Add(1)
+			go func() {
+				defer c.healthCheckWg.Done()
+				c.healthCheckNode(n)
+			}()
 			break
 		}
 	}
@@ -136,16 +162,21 @@ func (c *cluster) healthCheckNode(node *member) {
 	// step: wait for the node to become active ... we are assuming a /ping is enough here
 	ticker := time.NewTicker(c.healthCheckInterval)
 	defer ticker.Stop()
-	for range ticker.C {
-		req, err := c.client.buildMarathonRequest("GET", node.endpoint, "ping", nil)
-		if err == nil {
-			res, err := c.client.Do(req)
-			if err == nil && res.StatusCode == 200 {
-				// step: mark the node as active again
-				c.Lock()
-				node.status = memberStatusUp
-				c.Unlock()
-				break
+	for {
+		select {
+		case <-c.done:
+			return
+		case <-ticker.C:
+			req, err := c.client.buildMarathonRequest("GET", node.endpoint, "ping", nil)
+			if err == nil {
+				res, err := c.client.Do(req)
+				if err == nil && res.StatusCode == 200 {
+					// step: mark the node as active again
+					c.Lock()
+					node.status = memberStatusUp
+					c.Unlock()
+					// return (not break) so the goroutine exits; a bare break
+					// would only leave the select and keep the loop running.
+					return
+				}
 			}
 		}
 	}
diff --git a/subscription_test.go b/subscription_test.go
index c940285..509904d 100644
--- a/subscription_test.go
+++ b/subscription_test.go
@@ -385,7 +385,7 @@ func TestConnectToSSEFailure(t *testing.T) {
 
 	config := configContainer{client: &clientCfg}
 	endpoint := newFakeMarathonEndpoint(t, &config)
-	endpoint.Close()
+	endpoint.CloseServer()
 
 	client := endpoint.Client.(*marathonClient)
 
@@ -425,7 +425,7 @@ func TestRegisterSEESubscriptionReconnectsStreamOnError(t *testing.T) {
 	time.Sleep(SSEConnectWaitTime)
 
 	// This should make the SSE subscription fail and reconnect to another cluster member
-	endpoint1.Close()
+	endpoint1.CloseServer()
 
 	// Give it a bit of time so that the subscription can reconnect
 	time.Sleep(SSEConnectWaitTime)
diff --git a/testing_utils_test.go b/testing_utils_test.go
index a427f68..0822f41 100644
--- a/testing_utils_test.go
+++ b/testing_utils_test.go
@@ -319,5 +319,10 @@ func (s *fakeServer) Close() {
 }
 
 func (e *endpoint) Close() {
+	e.Client.Stop()
+	e.CloseServer()
+}
+
+func (e *endpoint) CloseServer() {
 	e.Server.Close()
 }
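
For reference, a minimal caller-side sketch (not part of the diff) of how the new Stop method is meant to be used; the import path and endpoint URL below are illustrative assumptions, while NewClient, Config, Ping, and Stop come from the changes above:

package main

import (
	"log"

	marathon "github.com/gambol99/go-marathon"
)

func main() {
	// Illustrative endpoint; replace with a real Marathon URL.
	client, err := marathon.NewClient(marathon.Config{URL: "http://marathon.example.com:8080"})
	if err != nil {
		log.Fatalf("failed to create Marathon client: %v", err)
	}
	// Stop shuts down the background health-check goroutines started for
	// members that were marked down. Calling it more than once is safe.
	defer client.Stop()

	if _, err := client.Ping(); err != nil {
		log.Printf("Marathon ping failed: %v", err)
	}
}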