diff --git a/client.go b/client.go index 0879c4c..a2f9779 100644 --- a/client.go +++ b/client.go @@ -3,13 +3,14 @@ package metadata import ( "context" "fmt" - "github.com/go-resty/resty/v2" "net/http" "net/url" "os" "path" "strconv" "time" + + "github.com/go-resty/resty/v2" ) const APIHost = "169.254.169.254" @@ -17,16 +18,14 @@ const APIProto = "http" const APIVersion = "v1" type Client struct { - resty *resty.Client - - apiBaseURL string - apiProtocol string - apiVersion string - userAgent string - - managedToken bool - managedTokenOpts []TokenOption managedTokenExpiry time.Time + resty *resty.Client + apiBaseURL string + apiProtocol string + apiVersion string + userAgent string + managedTokenOpts []TokenOption + managedToken bool } // NewClient creates a new Metadata API client configured diff --git a/test/integration/go.mod b/test/integration/go.mod index 6285dce..b20074d 100644 --- a/test/integration/go.mod +++ b/test/integration/go.mod @@ -15,6 +15,7 @@ require github.com/stretchr/testify v1.8.4 // indirect require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-resty/resty/v2 v2.7.0 // indirect + github.com/jarcoal/httpmock v1.3.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/net v0.15.0 // indirect golang.org/x/text v0.13.0 // indirect diff --git a/test/integration/go.sum b/test/integration/go.sum index 72335d0..242c71c 100644 --- a/test/integration/go.sum +++ b/test/integration/go.sum @@ -4,6 +4,8 @@ github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPr github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/linode/linodego v1.23.0 h1:s0ReCZtuN9Z1IoUN9w1RLeYO1dMZUGPwOQ/IBFsBHtU= github.com/linode/linodego v1.23.0/go.mod h1:0U7wj/UQOqBNbKv1FYTXiBUXueR8DY4HvIotwE0ENgg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/test/integration/helper_test.go b/test/integration/helper_test.go index 6c16260..7828c3b 100644 --- a/test/integration/helper_test.go +++ b/test/integration/helper_test.go @@ -2,10 +2,11 @@ package integration import ( "context" - "github.com/linode/go-metadata" - "github.com/linode/linodego" "log" "os" + + "github.com/linode/go-metadata" + "github.com/linode/linodego" ) var testToken = os.Getenv("LINODE_TOKEN") diff --git a/test/integration/watcher_test.go b/test/integration/watcher_test.go new file mode 100644 index 0000000..755879a --- /dev/null +++ b/test/integration/watcher_test.go @@ -0,0 +1,131 @@ +package integration + +import ( + "context" + "math/rand" + "net/http" + "testing" + "time" + + "github.com/jarcoal/httpmock" + "github.com/linode/go-metadata" + "github.com/stretchr/testify/assert" +) + +func TestNetworkWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + httpClient := &http.Client{} + httpmock.ActivateNonDefault(httpClient) + defer httpmock.DeactivateAndReset() + // since we use a hacked httpClient, we need to mock all calls we make + httpmock.RegisterResponder("PUT", "http://169.254.169.254/v1/token", func(req *http.Request) (*http.Response, error) { + return httpmock.NewJsonResponse(200, []string{ + "4fa1a6d669087162e7d65b36f8750c994ce4395b3e9cccea8924466819811004", + }) + }) + + httpmock.RegisterResponder("GET", "http://169.254.169.254/v1/network", + func(req *http.Request) (*http.Response, error) { + randomNumber := rand.Int() + response := map[string]any{ + "interfaces": []string{}, + "ipv4": map[string]any{ + "public": []string{"172.233.211.141/32"}, + "private": []string{}, + "shared": []string{}, + }, + "ipv6": map[string]any{ + "slaac": "2600:3c06::f03c:93ff:fe98:0e4c/128", + "ranges": []string{}, + "link_local": "fe80::f03c:93ff:fe98:0e4c/128", + "shared_ranges": []string{}, + }, + } + if randomNumber%2 == 0 { + response["ipv4"].(map[string]any)["public"] = []string{"172.233.211.142/32"} + return httpmock.NewJsonResponse(200, response) + } else { + response["ipv4"].(map[string]any)["public"] = []string{"172.233.211.141/32"} + return httpmock.NewJsonResponse(200, response) + } + return httpmock.NewJsonResponse(200, response) + }) + + metadataClient, err := metadata.NewClient(ctx, metadata.ClientWithHTTPClient(httpClient)) + assert.NoError(t, err) + + watcher := metadataClient.NewNetworkWatcher(metadata.WatcherWithInterval(1 * time.Second)) + watcher.Start(ctx) + numUpdates := 0 + for i := 1; i <= 5; i++ { + updateData := <-watcher.Updates + if updateData != nil { + t.Logf("Changed IPv4: %s", updateData.IPv4.Public[0].String()) + numUpdates += 1 + } + time.Sleep(1 * time.Second) + } + assert.GreaterOrEqual(t, numUpdates, 3) // interval is 1 sec + watcher.Close() +} + +func TestInstanceWatcher(t *testing.T) { + t.Parallel() + ctx := context.Background() + httpClient := &http.Client{} + httpmock.ActivateNonDefault(httpClient) + defer httpmock.DeactivateAndReset() + // since we use a hacked httpClient, we need to mock all calls we make + httpmock.RegisterResponder("PUT", "http://169.254.169.254/v1/token", func(req *http.Request) (*http.Response, error) { + return httpmock.NewJsonResponse(200, []string{ + "4fa1a6d669087162e7d65b36f8750c994ce4395b3e9cccea8924466819811004", + }) + }) + httpmock.RegisterResponder("GET", "http://169.254.169.254/v1/instance", + func(req *http.Request) (*http.Response, error) { + randomNumber := rand.Int() + response := map[string]any{ + "backups": map[string]any{ + "enabled": true, + "status": "completed", + }, + "host_uuid": "isthisauuid", + "id": 51438702, + "label": "dev-us-ord", + "region": "us-ord", + "specs": map[string]int{ + "disk": 327680, + "gpus": 0, + "memory": 16384, + "transfer": 6000, + "vcpus": 8, + }, + "type": "g6-dedicated-8", + } + if randomNumber%2 == 0 { + response["label"] = "even" + return httpmock.NewJsonResponse(200, response) + } else { + response["label"] = "odd" + return httpmock.NewJsonResponse(200, response) + } + }) + + metadataClient, err := metadata.NewClient(ctx, metadata.ClientWithHTTPClient(httpClient)) + assert.NoError(t, err) + + watcher := metadataClient.NewInstanceWatcher(metadata.WatcherWithInterval(1 * time.Second)) + watcher.Start(ctx) + numUpdates := 0 + for i := 1; i <= 5; i++ { + updateData := <-watcher.Updates + if updateData != nil { + t.Logf("Changed Label: %s", updateData.Label) + numUpdates += 1 + } + time.Sleep(1 * time.Second) + } + assert.GreaterOrEqual(t, numUpdates, 4) // interval is 1 sec + watcher.Close() +} diff --git a/watcher.go b/watcher.go new file mode 100644 index 0000000..2c7d18a --- /dev/null +++ b/watcher.go @@ -0,0 +1,143 @@ +package metadata + +import ( + "context" + "reflect" + "time" +) + +const DefaultWatcherInterval = 5 * time.Minute + +type NetworkWatcher struct { + Updates chan *NetworkData + Errors chan error + cancel chan struct{} + client *Client + interval time.Duration + ticker *time.Ticker +} + +func (watcher *NetworkWatcher) Start(ctx context.Context) { + go func() { + var oldNetworkData *NetworkData + watcher.ticker = time.NewTicker(watcher.interval) + defer watcher.ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-watcher.cancel: + return + case <-watcher.ticker.C: + networkData, err := watcher.client.GetNetwork(ctx) + if err != nil { + watcher.Errors <- err + } + if !reflect.DeepEqual(networkData, oldNetworkData) { + watcher.Updates <- networkData + oldNetworkData = networkData + } + } + } + }() +} + +func (watcher *NetworkWatcher) Close() error { + close(watcher.cancel) + close(watcher.Errors) + close(watcher.Updates) + watcher.ticker.Stop() + return nil +} + +type InstanceWatcher struct { + Updates chan *InstanceData + Errors chan error + cancel chan struct{} + client *Client + interval time.Duration + ticker *time.Ticker +} + +func (watcher *InstanceWatcher) Start(ctx context.Context) { + go func() { + var oldInstanceData *InstanceData + watcher.ticker = time.NewTicker(watcher.interval) + defer watcher.ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-watcher.cancel: + return + case <-watcher.ticker.C: + instanceData, err := watcher.client.GetInstance(ctx) + if err != nil { + watcher.Errors <- err + } + if !reflect.DeepEqual(instanceData, oldInstanceData) { // Todo Testing + watcher.Updates <- instanceData + oldInstanceData = instanceData + } + } + } + }() +} + +func (watcher *InstanceWatcher) Close() error { + close(watcher.cancel) + close(watcher.Errors) + close(watcher.Updates) + watcher.ticker.Stop() + return nil +} + +type WatcherOption func(options *watcherConfig) + +type watcherConfig struct { + Interval time.Duration +} + +func (c *Client) NewInstanceWatcher(opts ...WatcherOption) *InstanceWatcher { + watcherOpts := watcherConfig{ + Interval: DefaultWatcherInterval, + } + + for _, opt := range opts { + opt(&watcherOpts) + } + + return &InstanceWatcher{ + Updates: make(chan *InstanceData), + Errors: make(chan error), + cancel: make(chan struct{}), + interval: watcherOpts.Interval, + client: c, + } +} + +func (c *Client) NewNetworkWatcher(opts ...WatcherOption) *NetworkWatcher { + watcherOpts := watcherConfig{ + Interval: DefaultWatcherInterval, + } + + for _, opt := range opts { + opt(&watcherOpts) + } + + return &NetworkWatcher{ + Updates: make(chan *NetworkData), + Errors: make(chan error), + cancel: make(chan struct{}), + interval: watcherOpts.Interval, + client: c, + } +} + +func WatcherWithInterval(duration time.Duration) WatcherOption { + return func(options *watcherConfig) { + options.Interval = duration + } +}