From 327579c6b3496d4afbb902a4834e8447ef6f5f4d Mon Sep 17 00:00:00 2001 From: Michael Lavers Date: Thu, 27 May 2021 18:32:49 -0700 Subject: [PATCH] Use Context for Timeout Handling (#72) * Use Context for Timeout Handling * Cancel inflight request during shutdown/timeout * Refactoring timeout handling to use context timeouts * Adding lifecycle unit tests * Fix blocking logs channel issue * Adding more lifecycle tests * Tidy up go modules * Stop sharing HandlerFunc across log servers (breaks tests) * Move log server host override to config * Add tests for client error handling * Adding tests to improve coverage * Add InitError and ExitError tests * Coverage is almost there, fewmore tests * Panic on 500 Internal Server Error, otherwise continue --- checks/agent_version_check.go | 3 +- checks/agent_version_check_test.go | 8 +- checks/handler_check.go | 3 +- checks/handler_check_test.go | 8 +- checks/sanity_check.go | 5 +- checks/sanity_check_test.go | 20 +- checks/startup_check.go | 18 +- checks/startup_check_test.go | 16 +- checks/vendor_check.go | 3 +- checks/vendor_check_test.go | 11 +- config/config.go | 21 +- config/config_test.go | 11 +- coverage.sh | 6 + credentials/credentials.go | 9 +- credentials/credentials_test.go | 21 +- go.sum | 27 +- lambda/extension/api/api.go | 6 +- lambda/extension/api/api_test.go | 3 +- lambda/extension/client/client.go | 74 +++- lambda/extension/client/client_test.go | 340 +++++++++++++++- lambda/logserver/logserver.go | 32 +- lambda/logserver/logserver_test.go | 20 +- main.go | 320 +++++++++------ main_test.go | 537 +++++++++++++++++++++++++ telemetry/batch.go | 4 +- telemetry/batch_test.go | 3 +- telemetry/client.go | 9 +- telemetry/client_test.go | 25 +- telemetry/request.go | 5 +- util/logger.go | 4 + 30 files changed, 1307 insertions(+), 265 deletions(-) create mode 100644 main_test.go diff --git a/checks/agent_version_check.go b/checks/agent_version_check.go index 570ca9e..5dea9b1 100644 --- a/checks/agent_version_check.go +++ b/checks/agent_version_check.go @@ -1,6 +1,7 @@ package checks import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -18,7 +19,7 @@ type LayerAgentVersion struct { // We are only returning an error message when an out of date agent version is detected. // All other errors will result in a nil return value. -func agentVersionCheck(conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig) error { +func agentVersionCheck(ctx context.Context, conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig) error { if r.AgentVersion == "" { return nil } diff --git a/checks/agent_version_check_test.go b/checks/agent_version_check_test.go index 77fe715..7ba2379 100644 --- a/checks/agent_version_check_test.go +++ b/checks/agent_version_check_test.go @@ -1,6 +1,7 @@ package checks import ( + "context" "os" "path/filepath" "testing" @@ -14,9 +15,10 @@ func TestAgentVersion(t *testing.T) { conf := config.Configuration{} reg := api.RegistrationResponse{} r := runtimeConfig{} + ctx := context.Background() // No version set - err := agentVersionCheck(&conf, ®, r) + err := agentVersionCheck(ctx, &conf, ®, r) assert.Nil(t, err) // Error @@ -33,11 +35,11 @@ func TestAgentVersion(t *testing.T) { f, _ := os.Create(filepath.Join(testFile, r.agentVersionFile)) f.WriteString("10.1.0") - err = agentVersionCheck(&conf, ®, r) + err = agentVersionCheck(ctx, &conf, ®, r) assert.EqualError(t, err, "Agent version out of date: v10.1.0, in order access up to date features please upgrade to the latest New Relic python layer that includes agent version v10.1.2") // Success r.AgentVersion = "10.1.0" - err = agentVersionCheck(&conf, ®, r) + err = agentVersionCheck(ctx, &conf, ®, r) assert.Nil(t, err) } diff --git a/checks/handler_check.go b/checks/handler_check.go index 0384eaa..73d120a 100644 --- a/checks/handler_check.go +++ b/checks/handler_check.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "strings" @@ -16,7 +17,7 @@ type handlerConfigs struct { var handlerPath = "/var/task" -func checkHandler(conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig) error { +func handlerCheck(ctx context.Context, conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig) error { if r.language != "" { h := handlerConfigs{ handlerName: reg.Handler, diff --git a/checks/handler_check_test.go b/checks/handler_check_test.go index b125933..deed06a 100644 --- a/checks/handler_check_test.go +++ b/checks/handler_check_test.go @@ -1,6 +1,7 @@ package checks import ( + "context" "os" "path/filepath" "testing" @@ -57,15 +58,16 @@ func TestHandlerCheck(t *testing.T) { conf := config.Configuration{} reg := api.RegistrationResponse{} r := runtimeConfigs[Node] + ctx := context.Background() // No Runtime - err := checkHandler(&conf, ®, runtimeConfig{}) + err := handlerCheck(ctx, &conf, ®, runtimeConfig{}) assert.Nil(t, err) // Error reg.Handler = testHandler conf.NRHandler = config.EmptyNRWrapper - err = checkHandler(&conf, ®, r) + err = handlerCheck(ctx, &conf, ®, r) assert.EqualError(t, err, "Missing handler file path/to/app.handler (NEW_RELIC_LAMBDA_HANDLER=Undefined)") // Success @@ -79,6 +81,6 @@ func TestHandlerCheck(t *testing.T) { reg.Handler = testHandler conf.NRHandler = config.EmptyNRWrapper - err = checkHandler(&conf, ®, r) + err = handlerCheck(ctx, &conf, ®, r) assert.Nil(t, err) } diff --git a/checks/sanity_check.go b/checks/sanity_check.go index d714b38..64dd066 100644 --- a/checks/sanity_check.go +++ b/checks/sanity_check.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "github.com/newrelic/newrelic-lambda-extension/config" @@ -21,12 +22,12 @@ var ( ) // sanityCheck checks for configuration that is either misplaced or in conflict -func sanityCheck(conf *config.Configuration, res *api.RegistrationResponse, _ runtimeConfig) error { +func sanityCheck(ctx context.Context, conf *config.Configuration, res *api.RegistrationResponse, _ runtimeConfig) error { if util.AnyEnvVarsExist(awsLogIngestionEnvVars) { return fmt.Errorf("Environment varaible '%s' is used by aws-log-ingestion and has no effect here. Recommend unsetting this environment variable within this function.", util.AnyEnvVarsExistString(awsLogIngestionEnvVars)) } - if credentials.IsSecretConfigured(conf) && util.EnvVarExists("NEW_RELIC_LICENSE_KEY") { + if credentials.IsSecretConfigured(ctx, conf) && util.EnvVarExists("NEW_RELIC_LICENSE_KEY") { return fmt.Errorf("There is both a AWS Secrets Manager secret and a NEW_RELIC_LICENSE_KEY environment variable set. Recommend removing the NEW_RELIC_LICENSE_KEY environment variable and using the AWS Secrets Manager secret.") } diff --git a/checks/sanity_check_test.go b/checks/sanity_check_test.go index 7788d52..0919d9c 100644 --- a/checks/sanity_check_test.go +++ b/checks/sanity_check_test.go @@ -1,11 +1,13 @@ package checks import ( + "context" "fmt" "os" "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/newrelic/newrelic-lambda-extension/config" @@ -19,7 +21,7 @@ type mockSecretManager struct { secretsmanageriface.SecretsManagerAPI } -func (mockSecretManager) GetSecretValue(*secretsmanager.GetSecretValueInput) (*secretsmanager.GetSecretValueOutput, error) { +func (mockSecretManager) GetSecretValueWithContext(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { return &secretsmanager.GetSecretValueOutput{ SecretString: aws.String(`{"LicenseKey": "foo"}`), }, nil @@ -29,30 +31,32 @@ type mockSecretManagerErr struct { secretsmanageriface.SecretsManagerAPI } -func (mockSecretManagerErr) GetSecretValue(*secretsmanager.GetSecretValueInput) (*secretsmanager.GetSecretValueOutput, error) { +func (mockSecretManagerErr) GetSecretValueWithContext(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { return nil, fmt.Errorf("Something went wrong") } func TestSanityCheck(t *testing.T) { + ctx := context.Background() + if util.AnyEnvVarsExist(awsLogIngestionEnvVars) { - assert.Error(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Error(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) } else { - assert.Nil(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Nil(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) } os.Setenv("DEBUG_LOGGING_ENABLED", "1") - assert.Error(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Error(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) os.Unsetenv("DEBUG_LOGGING_ENABLED") os.Unsetenv("NEW_RELIC_LICENSE_KEY") credentials.OverrideSecretsManager(&mockSecretManager{}) - assert.Nil(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Nil(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") credentials.OverrideSecretsManager(&mockSecretManager{}) - assert.Error(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Error(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) credentials.OverrideSecretsManager(&mockSecretManagerErr{}) - assert.Nil(t, sanityCheck(&config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) + assert.Nil(t, sanityCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, runtimeConfig{})) } diff --git a/checks/startup_check.go b/checks/startup_check.go index ca57115..fa8d667 100644 --- a/checks/startup_check.go +++ b/checks/startup_check.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "time" @@ -10,21 +11,21 @@ import ( "github.com/newrelic/newrelic-lambda-extension/util" ) -type checkFn func(*config.Configuration, *api.RegistrationResponse, runtimeConfig) error +type checkFn func(context.Context, *config.Configuration, *api.RegistrationResponse, runtimeConfig) error type LogSender interface { - SendFunctionLogs(lines []logserver.LogLine) error + SendFunctionLogs(ctx context.Context, lines []logserver.LogLine) error } /// Register checks here var checks = []checkFn{ agentVersionCheck, - checkHandler, + handlerCheck, sanityCheck, vendorCheck, } -func RunChecks(conf *config.Configuration, reg *api.RegistrationResponse, logSender LogSender) { +func RunChecks(ctx context.Context, conf *config.Configuration, reg *api.RegistrationResponse, logSender LogSender) { runtimeConfig, err := checkAndReturnRuntime() if err != nil { errLog := fmt.Sprintf("There was an issue querying for the latest agent version: %v", err) @@ -32,19 +33,18 @@ func RunChecks(conf *config.Configuration, reg *api.RegistrationResponse, logSen } for _, check := range checks { - runCheck(conf, reg, runtimeConfig, logSender, check) + runCheck(ctx, conf, reg, runtimeConfig, logSender, check) } } -func runCheck(conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig, logSender LogSender, check checkFn) error { - err := check(conf, reg, r) - +func runCheck(ctx context.Context, conf *config.Configuration, reg *api.RegistrationResponse, r runtimeConfig, logSender LogSender, check checkFn) error { + err := check(ctx, conf, reg, r) if err != nil { errLog := fmt.Sprintf("Startup check failed: %v", err) util.Logln(errLog) //Send a log line to NR as well - logSender.SendFunctionLogs([]logserver.LogLine{ + logSender.SendFunctionLogs(ctx, []logserver.LogLine{ { Time: time.Now(), RequestID: "0", diff --git a/checks/startup_check_test.go b/checks/startup_check_test.go index aad4c3d..9b14fad 100644 --- a/checks/startup_check_test.go +++ b/checks/startup_check_test.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "testing" @@ -14,7 +15,7 @@ type TestLogSender struct { sent []logserver.LogLine } -func (c *TestLogSender) SendFunctionLogs(lines []logserver.LogLine) error { +func (c *TestLogSender) SendFunctionLogs(ctx context.Context, lines []logserver.LogLine) error { c.sent = append(c.sent, lines...) return nil } @@ -24,14 +25,15 @@ func TestRunCheck(t *testing.T) { resp := api.RegistrationResponse{} r := runtimeConfig{} client := TestLogSender{} + ctx := context.Background() tested := false - testCheck := func(conf *config.Configuration, resp *api.RegistrationResponse, r runtimeConfig) error { + testCheck := func(ctx context.Context, conf *config.Configuration, resp *api.RegistrationResponse, r runtimeConfig) error { tested = true return nil } - result := runCheck(&conf, &resp, r, &client, testCheck) + result := runCheck(ctx, &conf, &resp, r, &client, testCheck) assert.Equal(t, true, tested) assert.Nil(t, result) @@ -42,14 +44,15 @@ func TestRunCheckErr(t *testing.T) { resp := api.RegistrationResponse{} r := runtimeConfig{} logSender := TestLogSender{} + ctx := context.Background() tested := false - testCheck := func(conf *config.Configuration, resp *api.RegistrationResponse, r runtimeConfig) error { + testCheck := func(ctx context.Context, conf *config.Configuration, resp *api.RegistrationResponse, r runtimeConfig) error { tested = true return fmt.Errorf("Failure Test") } - result := runCheck(&conf, &resp, r, &logSender, testCheck) + result := runCheck(ctx, &conf, &resp, r, &logSender, testCheck) assert.Equal(t, true, tested) assert.NotNil(t, result) @@ -64,5 +67,6 @@ func TestRunChecks(t *testing.T) { client = &mockClientError{} - RunChecks(c, r, l) + ctx := context.Background() + RunChecks(ctx, c, r, l) } diff --git a/checks/vendor_check.go b/checks/vendor_check.go index 7569f32..59b12ac 100644 --- a/checks/vendor_check.go +++ b/checks/vendor_check.go @@ -1,6 +1,7 @@ package checks import ( + "context" "fmt" "github.com/newrelic/newrelic-lambda-extension/config" @@ -10,7 +11,7 @@ import ( // vendorCheck checks to see if the user included a vendored copy of the agent along // with their function while also using a layer that includes the agent -func vendorCheck(_ *config.Configuration, _ *api.RegistrationResponse, r runtimeConfig) error { +func vendorCheck(ctx context.Context, _ *config.Configuration, _ *api.RegistrationResponse, r runtimeConfig) error { if util.PathExists(r.vendorAgentPath) && util.AnyPathsExist(r.layerAgentPaths) { return fmt.Errorf("Vendored agent found at '%s', a layer already includes this agent at '%s'. Recommend using the layer agent to avoid unexpected agent behavior.", r.vendorAgentPath, util.AnyPathsExistString(r.layerAgentPaths)) diff --git a/checks/vendor_check_test.go b/checks/vendor_check_test.go index 39f43b0..9e037cd 100644 --- a/checks/vendor_check_test.go +++ b/checks/vendor_check_test.go @@ -1,6 +1,7 @@ package checks import ( + "context" "testing" "github.com/newrelic/newrelic-lambda-extension/config" @@ -10,24 +11,24 @@ import ( ) func TestVendorCheck(t *testing.T) { - n := runtimeConfigs[Node] + ctx := context.Background() if !util.AnyPathsExist(n.layerAgentPaths) && !util.PathExists(n.vendorAgentPath) { - assert.Nil(t, vendorCheck(&config.Configuration{}, &api.RegistrationResponse{}, n)) + assert.Nil(t, vendorCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, n)) } if util.PathExists(n.layerAgentPaths[0]) && util.PathExists(n.vendorAgentPath) { - assert.Error(t, vendorCheck(&config.Configuration{}, &api.RegistrationResponse{}, n)) + assert.Error(t, vendorCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, n)) } p := runtimeConfigs[Python] if !util.AnyPathsExist(p.layerAgentPaths) && !util.PathExists(p.vendorAgentPath) { - assert.Nil(t, vendorCheck(&config.Configuration{}, &api.RegistrationResponse{}, n)) + assert.Nil(t, vendorCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, n)) } if util.AnyPathsExist(p.layerAgentPaths) && util.PathExists(p.vendorAgentPath) { - assert.Error(t, vendorCheck(&config.Configuration{}, &api.RegistrationResponse{}, n)) + assert.Error(t, vendorCheck(ctx, &config.Configuration{}, &api.RegistrationResponse{}, n)) } } diff --git a/config/config.go b/config/config.go index 6a94b31..e7bd586 100644 --- a/config/config.go +++ b/config/config.go @@ -7,10 +7,11 @@ import ( ) const ( - DefaultRipeMillis = 7_000 - DefaultRotMillis = 12_000 - DefaultLogLevel = "INFO" - DebugLogLevel = "DEBUG" + DefaultRipeMillis = 7_000 + DefaultRotMillis = 12_000 + DefaultLogLevel = "INFO" + DebugLogLevel = "DEBUG" + defaultLogServerHost = "sandbox.localdomain" ) var EmptyNRWrapper = "Undefined" @@ -26,9 +27,10 @@ type Configuration struct { RotMillis uint32 LogLevel string SendFunctionLogs bool + LogServerHost string } -func ConfigurationFromEnvironment() Configuration { +func ConfigurationFromEnvironment() *Configuration { enabledStr, extensionEnabledOverride := os.LookupEnv("NEW_RELIC_LAMBDA_EXTENSION_ENABLED") licenseKey, lkOverride := os.LookupEnv("NEW_RELIC_LICENSE_KEY") licenseKeySecretId, lkSecretOverride := os.LookupEnv("NEW_RELIC_LICENSE_KEY_SECRET") @@ -39,12 +41,13 @@ func ConfigurationFromEnvironment() Configuration { rotMillisStr, rotMillisOverride := os.LookupEnv("NEW_RELIC_HARVEST_ROT_MILLIS") logLevelStr, logLevelOverride := os.LookupEnv("NEW_RELIC_EXTENSION_LOG_LEVEL") sendFunctionLogsStr, sendFunctionLogsOverride := os.LookupEnv("NEW_RELIC_EXTENSION_SEND_FUNCTION_LOGS") + logServerHostStr, logServerHostOverride := os.LookupEnv("NEW_RELIC_LOG_SERVER_HOST") extensionEnabled := true if extensionEnabledOverride && strings.ToLower(enabledStr) == "false" { extensionEnabled = false } - ret := Configuration{ExtensionEnabled: extensionEnabled} + ret := &Configuration{ExtensionEnabled: extensionEnabled} if lkOverride { ret.LicenseKey = licenseKey @@ -94,6 +97,12 @@ func ConfigurationFromEnvironment() Configuration { ret.LogLevel = DefaultLogLevel } + if logServerHostOverride { + ret.LogServerHost = logServerHostStr + } else { + ret.LogServerHost = defaultLogServerHost + } + if sendFunctionLogsOverride && sendFunctionLogsStr == "true" { ret.SendFunctionLogs = true } diff --git a/config/config_test.go b/config/config_test.go index 8616fde..f1a2536 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -9,12 +9,13 @@ import ( func TestConfigurationFromEnvironmentZero(t *testing.T) { conf := ConfigurationFromEnvironment() - expected := Configuration{ + expected := &Configuration{ ExtensionEnabled: true, RipeMillis: DefaultRipeMillis, RotMillis: DefaultRotMillis, LogLevel: DefaultLogLevel, NRHandler: EmptyNRWrapper, + LogServerHost: defaultLogServerHost, } assert.Equal(t, expected, conf) } @@ -71,3 +72,11 @@ func TestConfigurationFromEnvironmentSecretId(t *testing.T) { conf := ConfigurationFromEnvironment() assert.Equal(t, "secretId", conf.LicenseKeySecretId) } + +func TestConfigurationFromEnvironmentLogServerHost(t *testing.T) { + os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "foobar") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + conf := ConfigurationFromEnvironment() + assert.Equal(t, "foobar", conf.LogServerHost) +} diff --git a/coverage.sh b/coverage.sh index 494b176..930857e 100755 --- a/coverage.sh +++ b/coverage.sh @@ -3,6 +3,12 @@ set -e echo "" > coverage.txt +go test -coverprofile=profile.out -covermode=atomic main.go +if [ -f profile.out ]; then + cat profile.out >> coverage.txt + rm profile.out +fi + for d in $(go list ./... | grep -v vendor); do go test -coverprofile=profile.out -covermode=atomic $d if [ -f profile.out ]; then diff --git a/credentials/credentials.go b/credentials/credentials.go index 84ce39f..bc520ed 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -1,6 +1,7 @@ package credentials import ( + "context" "encoding/json" "fmt" "os" @@ -56,11 +57,11 @@ func decodeLicenseKey(rawJson *string) (string, error) { // IsSecretConfigured returns true if the Secrets Maanger secret is configured, false // otherwise -func IsSecretConfigured(conf *config.Configuration) bool { +func IsSecretConfigured(ctx context.Context, conf *config.Configuration) bool { secretId := getLicenseKeySecretId(conf) secretValueInput := secretsmanager.GetSecretValueInput{SecretId: &secretId} - _, err := secrets.GetSecretValue(&secretValueInput) + _, err := secrets.GetSecretValueWithContext(ctx, &secretValueInput) if err != nil { return false } @@ -70,7 +71,7 @@ func IsSecretConfigured(conf *config.Configuration) bool { // GetNewRelicLicenseKey fetches the license key from AWS Secrets Manager, falling back // to the NEW_RELIC_LICENSE_KEY environment variable if set. -func GetNewRelicLicenseKey(conf *config.Configuration) (string, error) { +func GetNewRelicLicenseKey(ctx context.Context, conf *config.Configuration) (string, error) { if conf.LicenseKey != "" { util.Logln("Using license key from environment variable") return conf.LicenseKey, nil @@ -79,7 +80,7 @@ func GetNewRelicLicenseKey(conf *config.Configuration) (string, error) { secretId := getLicenseKeySecretId(conf) secretValueInput := secretsmanager.GetSecretValueInput{SecretId: &secretId} - secretValueOutput, err := secrets.GetSecretValue(&secretValueInput) + secretValueOutput, err := secrets.GetSecretValueWithContext(ctx, &secretValueInput) if err != nil { envLicenseKey, found := os.LookupEnv(defaultSecretId) if found { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index 477defb..baead03 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -1,6 +1,7 @@ package credentials import ( + "context" "fmt" "os" "testing" @@ -8,6 +9,7 @@ import ( "github.com/newrelic/newrelic-lambda-extension/config" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/stretchr/testify/assert" @@ -27,7 +29,7 @@ type mockSecretManager struct { secretsmanageriface.SecretsManagerAPI } -func (mockSecretManager) GetSecretValue(*secretsmanager.GetSecretValueInput) (*secretsmanager.GetSecretValueOutput, error) { +func (mockSecretManager) GetSecretValueWithContext(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { return &secretsmanager.GetSecretValueOutput{ SecretString: aws.String(`{"LicenseKey": "foo"}`), }, nil @@ -37,40 +39,43 @@ type mockSecretManagerErr struct { secretsmanageriface.SecretsManagerAPI } -func (mockSecretManagerErr) GetSecretValue(*secretsmanager.GetSecretValueInput) (*secretsmanager.GetSecretValueOutput, error) { +func (mockSecretManagerErr) GetSecretValueWithContext(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { return nil, fmt.Errorf("Something went wrong") } func TestIsSecretConfigured(t *testing.T) { OverrideSecretsManager(mockSecretManager{}) - assert.True(t, IsSecretConfigured(&config.Configuration{})) + ctx := context.Background() + assert.True(t, IsSecretConfigured(ctx, &config.Configuration{})) OverrideSecretsManager(mockSecretManagerErr{}) - assert.False(t, IsSecretConfigured(&config.Configuration{})) + assert.False(t, IsSecretConfigured(ctx, &config.Configuration{})) } func TestGetNewRelicLicenseKey(t *testing.T) { OverrideSecretsManager(mockSecretManager{}) - lk, err := GetNewRelicLicenseKey(&config.Configuration{}) + ctx := context.Background() + lk, err := GetNewRelicLicenseKey(ctx, &config.Configuration{}) assert.Nil(t, err) assert.Equal(t, "foo", lk) os.Unsetenv("NEW_RELIC_LICENSE_KEY") OverrideSecretsManager(mockSecretManagerErr{}) - lk, err = GetNewRelicLicenseKey(&config.Configuration{}) + lk, err = GetNewRelicLicenseKey(ctx, &config.Configuration{}) assert.Error(t, err) assert.Empty(t, lk) os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") - lk, err = GetNewRelicLicenseKey(&config.Configuration{}) + lk, err = GetNewRelicLicenseKey(ctx, &config.Configuration{}) assert.Nil(t, err) assert.Equal(t, "foobar", lk) } func TestGetNewRelicLicenseKeyConfigValue(t *testing.T) { licenseKey := "test_value" - resultKey, err := GetNewRelicLicenseKey(&config.Configuration{ + ctx := context.Background() + resultKey, err := GetNewRelicLicenseKey(ctx, &config.Configuration{ LicenseKey: licenseKey, }) diff --git a/go.sum b/go.sum index af16189..b65d7a1 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,8 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/aws/aws-lambda-go v1.11.0 h1:8nkgvOfMLeKMxglSR+sAkjqLGK0pWFK9e5qyu9rlf0s= github.com/aws/aws-lambda-go v1.11.0/go.mod h1:Rr2SMTLeSMKgD45uep9V/NP8tnbCcySgu04cx0k/6cw= github.com/aws/aws-lambda-go v1.19.1 h1:5iUHbIZ2sG6Yq/J1IN3sWm3+vAB1CWwhI21NffLNuNI= github.com/aws/aws-lambda-go v1.19.1/go.mod h1:jJmlefzPfGnckuHdXX7/80O3BvUUi12XOkbv4w9SGLU= -github.com/aws/aws-sdk-go v1.34.6 h1:2aPXQGkR6xeheN5dns13mSoDWeUlj4wDmfZ+8ZDHauw= -github.com/aws/aws-sdk-go v1.34.6/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.34.21 h1:M97FXuiJgDHwD4mXhrIZ7RJ4xXV6uZVPvIC2qb+HfYE= github.com/aws/aws-sdk-go v1.34.21/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -13,7 +10,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -22,12 +18,10 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -37,29 +31,19 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= github.com/newrelic/go-agent/v3 v3.4.0/go.mod h1:H28zDNUC0U/b7kLoY4EFOhuth10Xu/9dchozUiOseQQ= -github.com/newrelic/go-agent/v3 v3.8.1 h1:PzM7tOO7ojBxHxEXY/AQJ8bAKrM9vFeFbHPkTaVo8+Q= -github.com/newrelic/go-agent/v3 v3.8.1/go.mod h1:1A1dssWBwzB7UemzRU6ZVaGDsI+cEn5/bNxI0wiYlIc= -github.com/newrelic/go-agent/v3 v3.8.2-0.20200810213557-525b608484c8 h1:NEkGx0g/XmAfJlAqWU2SC2YdEssF6KlgSxbhNd6aMTA= -github.com/newrelic/go-agent/v3 v3.8.2-0.20200810213557-525b608484c8/go.mod h1:1A1dssWBwzB7UemzRU6ZVaGDsI+cEn5/bNxI0wiYlIc= github.com/newrelic/go-agent/v3 v3.9.0 h1:5bcTbdk/Up5cIYIkQjCG92Y+uNoett9wmhuz4kPiFlM= github.com/newrelic/go-agent/v3 v3.9.0/go.mod h1:1A1dssWBwzB7UemzRU6ZVaGDsI+cEn5/bNxI0wiYlIc= -github.com/newrelic/go-agent/v3/integrations/nrlambda v1.1.0 h1:EIklCcrNtE2gQtTELx0fgfuyvmOxpID1GRFsR3MBr9Y= -github.com/newrelic/go-agent/v3/integrations/nrlambda v1.1.0/go.mod h1:IZemD4LiJXNBAV652z2x3Awa1Z9Rlx7hEO4OUyqnr+U= -github.com/newrelic/go-agent/v3/integrations/nrlambda v1.1.1-0.20200810213557-525b608484c8 h1:EvAYDJU3XOb3bWd7L5mPRyxqFNVyky6o5/ERIGL4jy4= -github.com/newrelic/go-agent/v3/integrations/nrlambda v1.1.1-0.20200810213557-525b608484c8/go.mod h1:IZemD4LiJXNBAV652z2x3Awa1Z9Rlx7hEO4OUyqnr+U= github.com/newrelic/go-agent/v3/integrations/nrlambda v1.2.0 h1:pVLK1gx8YsOoI3EpEZ44HOL5GAnOVNkFx50ZJNKxUBk= github.com/newrelic/go-agent/v3/integrations/nrlambda v1.2.0/go.mod h1:IZemD4LiJXNBAV652z2x3Awa1Z9Rlx7hEO4OUyqnr+U= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -90,8 +74,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= -golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200904194848-62affa334b73 h1:MXfv8rhZWmFeqX3GNZRsd6vOLoaCHjYEX3qkRo3YBUA= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -101,11 +83,9 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009 h1:W0lCpv29Hv0UaM1LXb9QlBHLNP8UFfcKjblhVCWftOM= golang.org/x/sys v0.0.0-20200909081042-eff7692f9009/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -117,11 +97,11 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 h1:gSJIx1SDwno+2ElGhA4+qG2zF97qiUzTM+rQ0klBOcE= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200910191746-8ad3c7ee2cd1 h1:Oi/dETbxPPblvoi4hgkzJun62A4dwuBsTM0UcZYpN3U= @@ -129,7 +109,6 @@ google.golang.org/genproto v0.0.0-20200910191746-8ad3c7ee2cd1/go.mod h1:FWY/as6D google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.32.0 h1:zWTV+LMdc3kaiJMSTOFz2UgSBgx8RNQoTGiZu3fR9S0= google.golang.org/grpc v1.32.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= @@ -141,7 +120,6 @@ google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzi google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.24.0 h1:UhZDfRO8JRQru4/+LlLE0BRKGF8L+PICnvYZmx/fEGA= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= @@ -149,7 +127,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/urfave/cli.v1 v1.20.0/go.mod h1:vuBzUtMdQeixQj8LVd+/98pzhxNGQoyuPBlsXHOQNO0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lambda/extension/api/api.go b/lambda/extension/api/api.go index ae9a9ac..9e1c77a 100644 --- a/lambda/extension/api/api.go +++ b/lambda/extension/api/api.go @@ -68,15 +68,15 @@ type LogSubscription struct { Types []LogEventType `json:"types"` } -func NewLogSubscription(bufferingCfg BufferingCfg, destinationCfg DestinationCfg, types []LogEventType) LogSubscription { - return LogSubscription{ +func NewLogSubscription(bufferingCfg BufferingCfg, destinationCfg DestinationCfg, types []LogEventType) *LogSubscription { + return &LogSubscription{ Buffering: bufferingCfg, Destination: destinationCfg, Types: types, } } -func DefaultLogSubscription(types []LogEventType, port uint16) LogSubscription { +func DefaultLogSubscription(types []LogEventType, port uint16) *LogSubscription { endpoint := formatLogsEndpoint(port) return NewLogSubscription( diff --git a/lambda/extension/api/api_test.go b/lambda/extension/api/api_test.go index 296d2aa..c539b56 100644 --- a/lambda/extension/api/api_test.go +++ b/lambda/extension/api/api_test.go @@ -1,8 +1,9 @@ package api import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func Test_formatLogsEndpoint(t *testing.T) { diff --git a/lambda/extension/client/client.go b/lambda/extension/client/client.go index f29eefc..61d4b6d 100644 --- a/lambda/extension/client/client.go +++ b/lambda/extension/client/client.go @@ -6,6 +6,7 @@ package client import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -57,20 +58,20 @@ func (rc *RegistrationClient) getRegisterURL() string { } // RegisterDefault registers for Invoke and Shutdown events, with no configuration parameters. -func (rc *RegistrationClient) RegisterDefault() (*InvocationClient, *api.RegistrationResponse, error) { +func (rc *RegistrationClient) RegisterDefault(ctx context.Context) (*InvocationClient, *api.RegistrationResponse, error) { defaultEvents := []api.LifecycleEvent{api.Invoke, api.Shutdown} defaultRequest := api.RegistrationRequest{Events: defaultEvents} - return rc.Register(defaultRequest) + return rc.Register(ctx, defaultRequest) } // Register registers, with custom registration parameters. -func (rc *RegistrationClient) Register(registrationRequest api.RegistrationRequest) (*InvocationClient, *api.RegistrationResponse, error) { +func (rc *RegistrationClient) Register(ctx context.Context, registrationRequest api.RegistrationRequest) (*InvocationClient, *api.RegistrationResponse, error) { registrationRequestJson, err := json.Marshal(registrationRequest) if err != nil { return nil, nil, fmt.Errorf("error occurred while marshaling registration request %s", err) } - req, err := http.NewRequest("POST", rc.getRegisterURL(), bytes.NewBuffer(registrationRequestJson)) + req, err := http.NewRequestWithContext(ctx, "POST", rc.getRegisterURL(), bytes.NewBuffer(registrationRequestJson)) if err != nil { return nil, nil, fmt.Errorf("error occurred while creating registration request %s", err) } @@ -84,10 +85,19 @@ func (rc *RegistrationClient) Register(registrationRequest api.RegistrationReque defer util.Close(res.Body) + if res.StatusCode == http.StatusInternalServerError { + util.Panic("error occurred while making registration request: ", res.Status) + } + + if res.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("error occurred while making registration request: %s", res.Status) + } + bodyBytes, err := ioutil.ReadAll(res.Body) if err != nil { return nil, nil, err } + util.Debugf("Registration response: %s", bodyBytes) var registrationResponse api.RegistrationResponse @@ -110,22 +120,23 @@ func (ic *InvocationClient) getNextEventURL() string { return fmt.Sprintf("http://%s/%s/extension/event/next", ic.baseUrl, ic.version) } -// getInitErrorURL returns the Lambda Extension next event URL +// getInitErrorURL returns the Lambda Extension initialization error URL func (ic *InvocationClient) getInitErrorURL() string { return fmt.Sprintf("http://%s/%s/extension/init/error", ic.baseUrl, ic.version) } -// getExitErrorURL returns the Lambda Extension next event URL +// getExitErrorURL returns the Lambda exit error URL func (ic *InvocationClient) getExitErrorURL() string { return fmt.Sprintf("http://%s/%s/extension/exit/error", ic.baseUrl, ic.version) } +// getLogRegistrationURL returns the Lambda Log Registration URL func (ic *InvocationClient) getLogRegistrationURL() string { return fmt.Sprintf("http://%s/%s/logs", ic.baseUrl, api.LogsApiVersion) } // LogRegister registers for log events -func (ic *InvocationClient) LogRegister(subscriptionRequest *api.LogSubscription) error { +func (ic *InvocationClient) LogRegister(ctx context.Context, subscriptionRequest *api.LogSubscription) error { subscriptionRequestJson, err := json.Marshal(subscriptionRequest) if err != nil { return fmt.Errorf("error occurred while marshaling subscription request %s", err) @@ -133,7 +144,7 @@ func (ic *InvocationClient) LogRegister(subscriptionRequest *api.LogSubscription util.Debugln("Log registration with request ", string(subscriptionRequestJson)) - req, err := http.NewRequest("PUT", ic.getLogRegistrationURL(), bytes.NewBuffer(subscriptionRequestJson)) + req, err := http.NewRequestWithContext(ctx, "PUT", ic.getLogRegistrationURL(), bytes.NewBuffer(subscriptionRequestJson)) if err != nil { return fmt.Errorf("error occurred while creating subscription request %s", err) } @@ -147,6 +158,14 @@ func (ic *InvocationClient) LogRegister(subscriptionRequest *api.LogSubscription defer util.Close(res.Body) + if res.StatusCode == http.StatusInternalServerError { + util.Panic("error occurred while making log subscription request: ", res.Status) + } + + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted { + return fmt.Errorf("error occurred while making log subscription request: %s", res.Status) + } + responseBody, err := ioutil.ReadAll(res.Body) if err != nil { return err @@ -158,8 +177,8 @@ func (ic *InvocationClient) LogRegister(subscriptionRequest *api.LogSubscription } // NextEvent awaits the next event. -func (ic *InvocationClient) NextEvent() (*api.InvocationEvent, error) { - req, err := http.NewRequest("GET", ic.getNextEventURL(), nil) +func (ic *InvocationClient) NextEvent(ctx context.Context) (*api.InvocationEvent, error) { + req, err := http.NewRequestWithContext(ctx, "GET", ic.getNextEventURL(), nil) if err != nil { return nil, fmt.Errorf("error occurred when creating next request %s", err) } @@ -173,6 +192,14 @@ func (ic *InvocationClient) NextEvent() (*api.InvocationEvent, error) { defer util.Close(res.Body) + if res.StatusCode == http.StatusInternalServerError { + util.Panic("error occurred when calling extension/event/next: ", res.Status) + } + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error occurred when calling extension/event/next: %s", res.Status) + } + body, err := ioutil.ReadAll(res.Body) if err != nil { return nil, fmt.Errorf("error occurred while reading extension/event/next response body %s", err) @@ -187,9 +214,11 @@ func (ic *InvocationClient) NextEvent() (*api.InvocationEvent, error) { return &event, nil } -func (ic *InvocationClient) InitError(errorEnum string, initError error) error { +// InitError sends an initialization error to the lambda platform +func (ic *InvocationClient) InitError(ctx context.Context, errorEnum string, initError error) error { errorBuf := bytes.NewBufferString(initError.Error()) - req, err := http.NewRequest("POST", ic.getInitErrorURL(), errorBuf) + + req, err := http.NewRequestWithContext(ctx, "POST", ic.getInitErrorURL(), errorBuf) if err != nil { return fmt.Errorf("error occurred when creating init error request %s", err) } @@ -204,12 +233,21 @@ func (ic *InvocationClient) InitError(errorEnum string, initError error) error { defer util.Close(res.Body) + if res.StatusCode == http.StatusInternalServerError { + util.Panic("error occurred while making init error request: ", res.Status) + } + + if res.StatusCode != http.StatusAccepted { + return fmt.Errorf("error occurred while making init error request: %s", res.Status) + } + return nil } -func (ic *InvocationClient) ExitError(errorEnum string, exitError error) error { +// ExitError sends an exit error to the lambda platform +func (ic *InvocationClient) ExitError(ctx context.Context, errorEnum string, exitError error) error { errorBuf := bytes.NewBufferString(exitError.Error()) - req, err := http.NewRequest("POST", ic.getExitErrorURL(), errorBuf) + req, err := http.NewRequestWithContext(ctx, "POST", ic.getExitErrorURL(), errorBuf) if err != nil { return fmt.Errorf("error occurred when creating exit error request %s", err) } @@ -224,5 +262,13 @@ func (ic *InvocationClient) ExitError(errorEnum string, exitError error) error { defer util.Close(res.Body) + if res.StatusCode == http.StatusInternalServerError { + util.Panic("error occurred while making exit error request: ", res.Status) + } + + if res.StatusCode != http.StatusAccepted { + return fmt.Errorf("error occurred while making exit error request: %s", res.Status) + } + return nil } diff --git a/lambda/extension/client/client_test.go b/lambda/extension/client/client_test.go index 09c9667..c4e49c2 100644 --- a/lambda/extension/client/client_test.go +++ b/lambda/extension/client/client_test.go @@ -1,7 +1,9 @@ package client import ( + "context" "encoding/json" + "errors" "io/ioutil" "net/http" "net/http/httptest" @@ -38,7 +40,8 @@ func TestRegistrationClient_GetRegisterURL(t *testing.T) { func TestRegistrationClient_RegisterDefault(t *testing.T) { rc := RegistrationClient{} - ic, res, err := rc.RegisterDefault() + ctx := context.Background() + ic, res, err := rc.RegisterDefault(ctx) assert.Nil(t, ic) assert.Nil(t, res) assert.Error(t, err) @@ -70,7 +73,7 @@ func TestRegistrationClient_RegisterDefault(t *testing.T) { defer os.Unsetenv(api.LambdaHostPortEnvVar) client := New(*srv.Client()) - invocationClient, rr, err := client.RegisterDefault() + invocationClient, rr, err := client.RegisterDefault(ctx) assert.NoError(t, err) assert.Equal(t, "test-ext-id", invocationClient.extensionId) @@ -80,6 +83,287 @@ func TestRegistrationClient_RegisterDefault(t *testing.T) { assert.NotEmpty(t, invocationClient.getLogRegistrationURL()) } +func TestRegistrationClient_RegisterError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(400) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + client := New(*srv.Client()) + ctx := context.Background() + ic, rr, err := client.RegisterDefault(ctx) + + assert.Nil(t, ic) + assert.Nil(t, rr) + assert.Error(t, err) +} + +func TestRegistrationClient_RegisterPanic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(500) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + client := New(*srv.Client()) + ctx := context.Background() + + assert.Panics(t, func() { + client.RegisterDefault(ctx) + }) +} + +func TestInvocationClient_LogRegister(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Method, http.MethodPut) + + assert.NotEmpty(t, r.Header.Get(api.ExtensionIdHeader)) + defer util.Close(r.Body) + + w.WriteHeader(200) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + eventTypes := []api.LogEventType{api.Platform} + subscriptionRequest := api.DefaultLogSubscription(eventTypes, 12345) + + ctx := context.Background() + err := client.LogRegister(ctx, subscriptionRequest) + + assert.NoError(t, err) +} + +func TestInvocationClient_LogRegisterError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(400) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + eventTypes := []api.LogEventType{api.Platform} + subscriptionRequest := api.DefaultLogSubscription(eventTypes, 12345) + + ctx := context.Background() + err := client.LogRegister(ctx, subscriptionRequest) + + assert.Error(t, err) +} + +func TestInvocationClient_LogRegisterEPanic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(500) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + eventTypes := []api.LogEventType{api.Platform} + subscriptionRequest := api.DefaultLogSubscription(eventTypes, 12345) + + ctx := context.Background() + assert.Panics(t, func() { + client.LogRegister(ctx, subscriptionRequest) + }) +} + +func TestInvocationClient_InitError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Method, http.MethodPost) + + assert.NotEmpty(t, r.Header.Get(api.ExtensionIdHeader)) + defer util.Close(r.Body) + + w.WriteHeader(202) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + err := client.InitError(ctx, "foo.bar", errors.New("something went wrong")) + + assert.NoError(t, err) +} + +func TestInvocationClient_InitErrorError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(400) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + err := client.InitError(ctx, "foo.bar", errors.New("something went wrong")) + + assert.Error(t, err) +} + +func TestInvocationClient_InitErrorPanic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(500) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + assert.Panics(t, func() { + client.InitError(ctx, "foo.bar", errors.New("something went wrong")) + }) +} + +func TestInvocationClient_ExitError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, r.Method, http.MethodPost) + + assert.NotEmpty(t, r.Header.Get(api.ExtensionIdHeader)) + defer util.Close(r.Body) + + w.WriteHeader(202) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + err := client.ExitError(ctx, "foo.bar", errors.New("something went wrong")) + + assert.NoError(t, err) +} + +func TestInvocationClient_ExitErrorError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(400) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + err := client.ExitError(ctx, "foo.bar", errors.New("something went wrong")) + + assert.Error(t, err) +} + +func TestInvocationClient_ExitErrorPanic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(500) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + assert.Panics(t, func() { + client.ExitError(ctx, "foo.bar", errors.New("something went wrong")) + }) +} + func TestInvocationClient_NextEvent(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, r.Method, http.MethodGet) @@ -107,8 +391,58 @@ func TestInvocationClient_NextEvent(t *testing.T) { httpClient: *srv.Client(), extensionId: "test-ext-id", } - invocationEvent, err := client.NextEvent() + ctx := context.Background() + invocationEvent, err := client.NextEvent(ctx) assert.NoError(t, err) assert.NotNil(t, invocationEvent) } + +func TestInvocationClient_NextEventError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(400) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + event, err := client.NextEvent(ctx) + + assert.Error(t, err) + assert.Nil(t, event) +} + +func TestInvocationClient_NextEventPanic(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + w.WriteHeader(500) + _, _ = w.Write(nil) + })) + defer srv.Close() + + url := srv.URL[7:] + + client := InvocationClient{ + version: api.Version, + baseUrl: url, + httpClient: *srv.Client(), + extensionId: "test-ext-id", + } + + ctx := context.Background() + assert.Panics(t, func() { + client.NextEvent(ctx) + }) +} diff --git a/lambda/logserver/logserver.go b/lambda/logserver/logserver.go index 63b95af..3d79985 100644 --- a/lambda/logserver/logserver.go +++ b/lambda/logserver/logserver.go @@ -9,13 +9,13 @@ import ( "strconv" "time" + "github.com/newrelic/newrelic-lambda-extension/config" "github.com/newrelic/newrelic-lambda-extension/lambda/extension/api" "github.com/newrelic/newrelic-lambda-extension/util" ) const ( platformLogBufferSize = 100 - defaultHost = "sandbox.localdomain" ) type LogLine struct { @@ -49,6 +49,7 @@ func (ls *LogServer) Close() error { func (ls *LogServer) PollPlatformChannel() []LogLine { var ret []LogLine + for { select { case report, more := <-ls.platformLogChan: @@ -95,6 +96,8 @@ func formatReport(metrics map[string]interface{}) string { } func (ls *LogServer) handler(res http.ResponseWriter, req *http.Request) { + defer util.Close(req.Body) + bodyBytes, err := ioutil.ReadAll(req.Body) if err != nil { util.Logf("Error processing log request: %v", err) @@ -106,8 +109,10 @@ func (ls *LogServer) handler(res http.ResponseWriter, req *http.Request) { util.Logf("Error parsing log payload: %v", err) } - var functionLogs []LogLine - var lastRequestId string + var ( + functionLogs []LogLine + lastRequestId string + ) for _, event := range logEvents { switch event.Type { @@ -138,9 +143,10 @@ func (ls *LogServer) handler(res http.ResponseWriter, req *http.Request) { Content: []byte(record), }) default: - //util.Logln("Ignored log event of type ", event.Type, string(bodyBytes)) + //util.Debugln("Ignored log event of type ", event.Type, string(bodyBytes)) } } + if len(functionLogs) > 0 { ls.functionLogChan <- functionLogs } @@ -148,8 +154,8 @@ func (ls *LogServer) handler(res http.ResponseWriter, req *http.Request) { _, _ = res.Write(nil) } -func Start() (*LogServer, error) { - return startInternal(defaultHost) +func Start(conf *config.Configuration) (*LogServer, error) { + return startInternal(conf.LogServerHost) } func startInternal(host string) (*LogServer, error) { @@ -158,23 +164,23 @@ func startInternal(host string) (*LogServer, error) { return nil, err } - server := http.Server{} + server := &http.Server{} - logServer := LogServer{ + logServer := &LogServer{ listenString: listener.Addr().String(), - server: &server, + server: server, platformLogChan: make(chan LogLine, platformLogBufferSize), functionLogChan: make(chan []LogLine), } - http.HandleFunc("/", func(res http.ResponseWriter, req *http.Request) { - logServer.handler(res, req) - }) + mux := http.NewServeMux() + mux.HandleFunc("/", logServer.handler) + server.Handler = mux go func() { util.Logln("Starting log server.") util.Logf("Log server terminating: %v\n", server.Serve(listener)) }() - return &logServer, nil + return logServer, nil } diff --git a/lambda/logserver/logserver_test.go b/lambda/logserver/logserver_test.go index fec56f2..f906dd2 100644 --- a/lambda/logserver/logserver_test.go +++ b/lambda/logserver/logserver_test.go @@ -4,24 +4,18 @@ import ( "bytes" "encoding/json" "fmt" - "log" "net/http" "testing" "time" + "github.com/newrelic/newrelic-lambda-extension/config" "github.com/newrelic/newrelic-lambda-extension/lambda/extension/api" "github.com/stretchr/testify/assert" ) -func Test_Logserver(t *testing.T) { +func TestLogServer(t *testing.T) { logs, err := startInternal("localhost") - if err != nil { - log.Println("Failed to start logs HTTP server", err) - if err != nil { - log.Fatal(err) - } - return - } + assert.NoError(t, err) testEvents := []api.LogEvent{ { @@ -49,8 +43,8 @@ func Test_Logserver(t *testing.T) { client := http.Client{} res, err := client.Do(req) - assert.NoError(t, err) + assert.NoError(t, err) assert.Equal(t, 200, res.StatusCode) assert.Equal(t, http.NoBody, res.Body) @@ -61,3 +55,9 @@ func Test_Logserver(t *testing.T) { assert.Nil(t, logs.Close()) } + +func TestLogServerStart(t *testing.T) { + logs, err := Start(&config.Configuration{LogServerHost: "localhost"}) + assert.NoError(t, err) + assert.Nil(t, logs.Close()) +} diff --git a/main.go b/main.go index b329cea..27eed19 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,14 @@ package main import ( + "context" "encoding/base64" "fmt" "net/http" + "os" + "os/signal" "sync" + "syscall" "time" "github.com/newrelic/newrelic-lambda-extension/checks" @@ -18,37 +22,68 @@ import ( "github.com/newrelic/newrelic-lambda-extension/telemetry" ) +var rootCtx context.Context + +func init() { + rootCtx = context.Background() +} + func main() { extensionStartup := time.Now() + ctx, cancel := context.WithCancel(rootCtx) + defer cancel() + + // exit cleanly on SIGTERM or SIGINT + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT) + go func() { + s := <-sigs + cancel() + util.Logf("Received %v Exiting", s) + }() + + // Allow extension to be interrupted with CTRL-C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + go func() { + for _ = range c { + cancel() + util.Fatal("Exiting...") + } + }() + // Parse various env vars for our config conf := config.ConfigurationFromEnvironment() + // Optionally enable debug logging, disabled by default util.ConfigLogger(conf.LogLevel == config.DebugLogLevel) // Extensions must register registrationClient := client.New(http.Client{}) + regReq := api.RegistrationRequest{ Events: []api.LifecycleEvent{api.Invoke, api.Shutdown}, } - invocationClient, registrationResponse, err := registrationClient.Register(regReq) + invocationClient, registrationResponse, err := registrationClient.Register(ctx, regReq) if err != nil { - util.Fatal(err) + util.Panic(err) } + // If extension disabled, go into no op mode if !conf.ExtensionEnabled { util.Logln("Extension telemetry processing disabled") - noopLoop(invocationClient) + noopLoop(ctx, invocationClient) return } // Attempt to find the license key for telemetry sending - licenseKey, err := credentials.GetNewRelicLicenseKey(&conf) + licenseKey, err := credentials.GetNewRelicLicenseKey(ctx, conf) if err != nil { - util.Logln("Failed to retrieve license key", err) + util.Logln("Failed to retrieve New Relic license key", err) // We fail open; telemetry will go to CloudWatch instead - noopLoop(invocationClient) + noopLoop(ctx, invocationClient) return } @@ -56,52 +91,67 @@ func main() { batch := telemetry.NewBatch(int64(conf.RipeMillis), int64(conf.RotMillis)) // Start the Logs API server, and register it - logServer, err := logserver.Start() + logServer, err := logserver.Start(conf) if err != nil { - util.Logln("Failed to start logs HTTP server", err) - err = invocationClient.InitError("logServer.start", err) - if err != nil { - util.Fatal(err) + err2 := invocationClient.InitError(ctx, "logServer.start", err) + if err2 != nil { + util.Logln(err2) } - return + util.Panic("Failed to start logs HTTP server", err) } + eventTypes := []api.LogEventType{api.Platform} if conf.SendFunctionLogs { eventTypes = append(eventTypes, api.Function) } subscriptionRequest := api.DefaultLogSubscription(eventTypes, logServer.Port()) - err = invocationClient.LogRegister(&subscriptionRequest) + err = invocationClient.LogRegister(ctx, subscriptionRequest) if err != nil { - util.Logln("Failed to register with Logs API", err) - err = invocationClient.InitError("logServer.register", err) - if err != nil { - util.Fatal(err) + err2 := invocationClient.InitError(ctx, "logServer.register", err) + if err2 != nil { + util.Logln(err2) } - return + util.Panic("Failed to register with Logs API", err) } // Init the telemetry sending client telemetryClient := telemetry.New(registrationResponse.FunctionName, licenseKey, conf.TelemetryEndpoint, conf.LogEndpoint) - telemetryChan, err := telemetry.InitTelemetryChannel() if err != nil { - util.Fatal("telemetry pipe init failed: ", err) + err2 := invocationClient.InitError(ctx, "telemetryClient.init", err) + if err2 != nil { + util.Logln(err2) + } + util.Panic("telemetry pipe init failed: ", err) } + // Run startup checks go func() { - checks.RunChecks(&conf, registrationResponse, telemetryClient) + checks.RunChecks(ctx, conf, registrationResponse, telemetryClient) }() // Send function logs as they arrive. When disabled, function logs aren't delivered to the extension. - var backgroundTasks sync.WaitGroup + backgroundTasks := &sync.WaitGroup{} + backgroundTasks.Add(1) + go func() { - backgroundTasks.Add(1) defer backgroundTasks.Done() - functionLogShipLoop(logServer, telemetryClient) + logShipLoop(ctx, logServer, telemetryClient) }() // Call next, and process telemetry, until we're shut down - mainLoop(invocationClient, &batch, telemetryChan, logServer, telemetryClient) + eventCounter, invokedFunctionARN := mainLoop(ctx, invocationClient, batch, telemetryChan, logServer, telemetryClient) + + util.Logf("New Relic Extension shutting down after %v events\n", eventCounter) + + err = logServer.Close() + if err != nil { + util.Logln("Error shutting down Log API server", err) + } + + pollLogServer(logServer, batch) + finalHarvest := batch.Close() + shipHarvest(ctx, finalHarvest, telemetryClient, invokedFunctionARN) util.Debugln("Waiting for background tasks to complete") backgroundTasks.Wait() @@ -111,14 +161,15 @@ func main() { util.Logf("Extension shutdown after %vms", ranFor.Milliseconds()) } -// functionLogShipLoop ships function logs to New Relic as they arrive. -func functionLogShipLoop(logServer *logserver.LogServer, telemetryClient *telemetry.Client) { +// logShipLoop ships function logs to New Relic as they arrive. +func logShipLoop(ctx context.Context, logServer *logserver.LogServer, telemetryClient *telemetry.Client) { for { functionLogs, more := logServer.AwaitFunctionLogs() if !more { return } - err := telemetryClient.SendFunctionLogs(functionLogs) + + err := telemetryClient.SendFunctionLogs(ctx, functionLogs) if err != nil { util.Logf("Failed to send %d function logs", len(functionLogs)) } @@ -126,108 +177,117 @@ func functionLogShipLoop(logServer *logserver.LogServer, telemetryClient *teleme } // mainLoop repeatedly calls the /next api, and processes telemetry and platform logs. The timing is rather complicated. -func mainLoop(invocationClient *client.InvocationClient, batch *telemetry.Batch, telemetryChan chan []byte, logServer *logserver.LogServer, telemetryClient *telemetry.Client) { - counter := 0 - var invokedFunctionARN string - var lastRequestId string - var lastEventStart time.Time +func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, batch *telemetry.Batch, telemetryChan chan []byte, logServer *logserver.LogServer, telemetryClient *telemetry.Client) (int, string) { + var ( + invokedFunctionARN string + lastEventStart time.Time + lastRequestId string + ) + + eventCounter := 0 probablyTimeout := false + for { - // Our call to next blocks. It is likely that the container is frozen immediately after we call NextEvent. - event, err := invocationClient.NextEvent() - // We've thawed. - eventStart := time.Now() - if err != nil { - errErr := invocationClient.ExitError("NextEventError.Main", err) - if errErr != nil { - util.Logln(errErr) + select { + case <-ctx.Done(): + // We're already done + util.Logln(ctx.Err()) + return eventCounter, "" + default: + // Our call to next blocks. It is likely that the container is frozen immediately after we call NextEvent. + event, err := invocationClient.NextEvent(ctx) + + // We've thawed. + eventStart := time.Now() + + if err != nil { + + util.Logln(err) + err = invocationClient.ExitError(ctx, "NextEventError.Main", err) + if err != nil { + util.Logln(err) + } + continue } - util.Fatal(err) - } - counter++ + eventCounter++ + + if probablyTimeout { + // We suspect a timeout. Either way, we've gotten to the next event, so telemetry will + // have arrived for the last request if it's going to. Non-blocking poll for telemetry. + // If we have indeed timed out, there's a chance we got telemetry out anyway. If we haven't + // timed out, this will catch us up to the current state of telemetry, allowing us to resume. + select { + case telemetryBytes := <-telemetryChan: + // We received telemetry + batch.AddTelemetry(lastRequestId, telemetryBytes) + util.Logf("We suspected a timeout for request %s but got telemetry anyway", lastRequestId) + default: + } - if probablyTimeout { - // We suspect a timeout. Either way, we've gotten to the next event, so telemetry will - // have arrived for the last request if it's going to. Non-blocking poll for telemetry. - // If we have indeed timed out, there's a chance we got telemetry out anyway. If we haven't - // timed out, this will catch us up to the current state of telemetry, allowing us to resume. - select { - case telemetryBytes := <-telemetryChan: - // We received telemetry - batch.AddTelemetry(lastRequestId, telemetryBytes) - util.Logf("We suspected a timeout for request %s but got telemetry anyway", lastRequestId) - default: } - } - if event.EventType == api.Shutdown { - if event.ShutdownReason == api.Timeout && lastRequestId != "" { - // Synthesize the timeout error message that the platform produces, and LLC parses - timestamp := eventStart.UTC() - timeoutSecs := eventStart.Sub(lastEventStart).Seconds() - timeoutMessage := fmt.Sprintf( - "%s %s Task timed out after %.2f seconds", - timestamp.Format(time.RFC3339), - lastRequestId, - timeoutSecs, - ) - batch.AddTelemetry(lastRequestId, []byte(timeoutMessage)) - } else if event.ShutdownReason == api.Failure && lastRequestId != "" { - // Synthesize a generic platform error. Probably an OOM, though it could be any runtime crash. - errorMessage := fmt.Sprintf("RequestId: %s A platform error caused a shutdown", lastRequestId) - batch.AddTelemetry(lastRequestId, []byte(errorMessage)) + invokedFunctionARN = event.InvokedFunctionARN + lastRequestId = event.RequestID + + if event.EventType == api.Shutdown { + if event.ShutdownReason == api.Timeout && lastRequestId != "" { + // Synthesize the timeout error message that the platform produces, and LLC parses + timestamp := eventStart.UTC() + timeoutSecs := eventStart.Sub(lastEventStart).Seconds() + timeoutMessage := fmt.Sprintf( + "%s %s Task timed out after %.2f seconds", + timestamp.Format(time.RFC3339), + lastRequestId, + timeoutSecs, + ) + batch.AddTelemetry(lastRequestId, []byte(timeoutMessage)) + } else if event.ShutdownReason == api.Failure && lastRequestId != "" { + // Synthesize a generic platform error. Probably an OOM, though it could be any runtime crash. + errorMessage := fmt.Sprintf("RequestId: %s A platform error caused a shutdown", lastRequestId) + batch.AddTelemetry(lastRequestId, []byte(errorMessage)) + } + + return eventCounter, invokedFunctionARN } - break - } + // Create an invocation record to hold telemetry + batch.AddInvocation(lastRequestId, eventStart) - invokedFunctionARN = event.InvokedFunctionARN - lastRequestId = event.RequestID - // Create an invocation record to hold telemetry - batch.AddInvocation(lastRequestId, eventStart) + // Await agent telemetry. This may time out + // timeoutInstant is when the invocation will time out + timeoutInstant := time.Unix(0, event.DeadlineMs*int64(time.Millisecond)) - // Await agent telemetry. This may time out, so we race the timeout against the telemetry channel - // timeoutInstant is when the invocation will time out - timeoutInstant := time.Unix(0, event.DeadlineMs*int64(time.Millisecond)) + // Set the timeout timer for a smidge before the actual timeout; + // we can recover from early. + timeoutWatchBegins := time.Millisecond * 100 + timeout := timeoutInstant.Sub(time.Now()) - timeoutWatchBegins - // Set the timeout timer for a smidge before the actual timeout; we can recover from early. - timeoutWatchBegins := time.Millisecond * 100 - timeout := time.NewTimer(timeoutInstant.Sub(time.Now()) - timeoutWatchBegins) - select { - case telemetryBytes := <-telemetryChan: - // We received telemetry - util.Debugf("Agent telemetry bytes: %s", base64.URLEncoding.EncodeToString(telemetryBytes)) - inv := batch.AddTelemetry(lastRequestId, telemetryBytes) - if inv == nil { - util.Logf("Failed to add telemetry for request %v", lastRequestId) - } - // Tear down the timer - if !timeout.Stop() { - <-timeout.C - } + invCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() - pollLogServer(logServer, batch) + select { + case <-invCtx.Done(): + // We are about to timeout + util.Debugln("Timeout suspected: ", invCtx.Err()) + probablyTimeout = true + continue + case telemetryBytes := <-telemetryChan: + // We received telemetry + util.Debugf("Agent telemetry bytes: %s", base64.URLEncoding.EncodeToString(telemetryBytes)) + inv := batch.AddTelemetry(lastRequestId, telemetryBytes) + if inv == nil { + util.Logf("Failed to add telemetry for request %v", lastRequestId) + } + + pollLogServer(logServer, batch) + harvested := batch.Harvest(time.Now()) + shipHarvest(ctx, harvested, telemetryClient, invokedFunctionARN) + } - harvested := batch.Harvest(time.Now()) - shipHarvest(harvested, telemetryClient, invokedFunctionARN) - case <-timeout.C: - // Function is timing out - util.Debugln("Timeout suspected") - probablyTimeout = true + lastEventStart = eventStart } - lastEventStart = eventStart } - util.Logf("New Relic Extension shutting down after %v events\n", counter) - - err := logServer.Close() - if err != nil { - util.Logln("Error shutting down Log API server", err) - } - - pollLogServer(logServer, batch) - finalHarvest := batch.Close() - shipHarvest(finalHarvest, telemetryClient, invokedFunctionARN) } // pollLogServer polls for platform logs, and annotates telemetry @@ -240,33 +300,41 @@ func pollLogServer(logServer *logserver.LogServer, batch *telemetry.Batch) { } } -func shipHarvest(harvested []*telemetry.Invocation, telemetryClient *telemetry.Client, invokedFunctionARN string) { +func shipHarvest(ctx context.Context, harvested []*telemetry.Invocation, telemetryClient *telemetry.Client, invokedFunctionARN string) { if len(harvested) > 0 { telemetrySlice := make([][]byte, 0, 2*len(harvested)) for _, inv := range harvested { telemetrySlice = append(telemetrySlice, inv.Telemetry...) } - err, _ := telemetryClient.SendTelemetry(invokedFunctionARN, telemetrySlice) + err, _ := telemetryClient.SendTelemetry(ctx, invokedFunctionARN, telemetrySlice) if err != nil { util.Logf("Failed to send harvested telemetry for %d invocations %s", len(harvested), err) } } } -func noopLoop(invocationClient *client.InvocationClient) { +func noopLoop(ctx context.Context, invocationClient *client.InvocationClient) { + util.Logln("Starting no-op mode, no telemetry will be sent") + for { - event, err := invocationClient.NextEvent() - if err != nil { - errErr := invocationClient.ExitError("NextEventError.Noop", err) - if errErr != nil { - util.Logln(errErr) + select { + case <-ctx.Done(): + return + default: + event, err := invocationClient.NextEvent(ctx) + if err != nil { + util.Logln(err) + errErr := invocationClient.ExitError(ctx, "NextEventError.Noop", err) + if errErr != nil { + util.Logln(errErr) + } + continue } - util.Fatal(err) - } - if event.EventType == api.Shutdown { - return + if event.EventType == api.Shutdown { + return + } } } } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..74b0900 --- /dev/null +++ b/main_test.go @@ -0,0 +1,537 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/newrelic/newrelic-lambda-extension/lambda/extension/api" + "github.com/newrelic/newrelic-lambda-extension/util" + + "github.com/stretchr/testify/assert" +) + +// TODO: These tests are very repetitive. Helpers would be useful here. + +func TestMainRegisterFail(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(400) + _, _ = w.Write(nil) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + assert.Panics(t, main) +} + +func TestMainLogServerInitFail(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + // Shouldn't be able to bind to this locally + _ = os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "sandbox.localdomain") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.Panics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 1, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 0, logRegisterRequestCount) +} + +func TestMainLogServerRegisterFail(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(400) + _, _ = w.Write(nil) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + _ = os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "localhost") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.Panics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 1, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 1, logRegisterRequestCount) +} + +func TestMainShutdown(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + nextEventRequestCount int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + util.Logln("Path: ", r.URL.Path) + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/event/next" { + nextEventRequestCount++ + + w.WriteHeader(200) + res, err := json.Marshal(api.InvocationEvent{ + EventType: api.Shutdown, + DeadlineMs: 1, + RequestID: "12345", + InvokedFunctionARN: "arn:aws:lambda:us-east-1:12345:foobar", + ShutdownReason: api.Timeout, + Tracing: nil, + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + _ = os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "localhost") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.NotPanics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 0, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 1, logRegisterRequestCount) + assert.Equal(t, 1, nextEventRequestCount) +} + +func TestMainNoLicenseKey(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + nextEventRequestCount int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + util.Logln("Path: ", r.URL.Path) + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/event/next" { + nextEventRequestCount++ + + w.WriteHeader(200) + res, err := json.Marshal(api.InvocationEvent{ + EventType: api.Shutdown, + DeadlineMs: 1, + RequestID: "12345", + InvokedFunctionARN: "arn:aws:lambda:us-east-1:12345:foobar", + ShutdownReason: api.Timeout, + Tracing: nil, + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.NotPanics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 0, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 0, logRegisterRequestCount) + assert.Equal(t, 1, nextEventRequestCount) +} + +func TestMainExtensionDisabled(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + nextEventRequestCount int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + util.Logln("Path: ", r.URL.Path) + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/event/next" { + nextEventRequestCount++ + + w.WriteHeader(200) + res, err := json.Marshal(api.InvocationEvent{ + EventType: api.Shutdown, + DeadlineMs: 1, + RequestID: "12345", + InvokedFunctionARN: "arn:aws:lambda:us-east-1:12345:foobar", + ShutdownReason: api.Timeout, + Tracing: nil, + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + _ = os.Setenv("NEW_RELIC_LAMBDA_EXTENSION_ENABLED", "false") + defer os.Unsetenv("NEW_RELIC_LAMBDA_EXTENSION_ENABLED") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.NotPanics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 0, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 0, logRegisterRequestCount) + assert.Equal(t, 1, nextEventRequestCount) +} + +func TestMainTimeout(t *testing.T) { + var ( + registerRequestCount int + initErrorRequestCount int + exitErrorRequestCount int + logRegisterRequestCount int + nextEventRequestCount int + ) + + ctx, cancel := context.WithCancel(context.Background()) + overrideContext(ctx) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + registerRequestCount++ + + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + initErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + exitErrorRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-08-15/logs" { + logRegisterRequestCount++ + + w.WriteHeader(200) + _, _ = w.Write(nil) + + } + + if r.URL.Path == "/2020-01-01/extension/event/next" { + nextEventRequestCount++ + + w.WriteHeader(200) + res, err := json.Marshal(api.InvocationEvent{ + EventType: api.Invoke, + DeadlineMs: 1000, + RequestID: "12345", + InvokedFunctionARN: "arn:aws:lambda:us-east-1:12345:foobar", + ShutdownReason: "", + Tracing: nil, + }) + assert.Nil(t, err) + _, _ = w.Write(res) + + cancel() + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + _ = os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "localhost") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + assert.NotPanics(t, main) + + assert.Equal(t, 1, registerRequestCount) + assert.Equal(t, 0, initErrorRequestCount) + assert.Equal(t, 0, exitErrorRequestCount) + assert.Equal(t, 1, logRegisterRequestCount) + assert.Equal(t, 1, nextEventRequestCount) +} + +func overrideContext(ctx context.Context) { + rootCtx = ctx +} diff --git a/telemetry/batch.go b/telemetry/batch.go index ee8c99a..5de603f 100644 --- a/telemetry/batch.go +++ b/telemetry/batch.go @@ -20,9 +20,9 @@ type Batch struct { } // NewBatch constructs a new batch. -func NewBatch(ripeMillis int64, rotMillis int64) Batch { +func NewBatch(ripeMillis int64, rotMillis int64) *Batch { initialSize := uint32(math.Min(float64(ripeMillis)/100, 100)) - return Batch{ + return &Batch{ lastHarvest: epochStart, eldest: epochStart, invocations: make(map[string]*Invocation, initialSize), diff --git a/telemetry/batch_test.go b/telemetry/batch_test.go index 55aefa9..bfef2b0 100644 --- a/telemetry/batch_test.go +++ b/telemetry/batch_test.go @@ -2,9 +2,10 @@ package telemetry import ( "bytes" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" ) const ( diff --git a/telemetry/client.go b/telemetry/client.go index ec2af8d..5070834 100644 --- a/telemetry/client.go +++ b/telemetry/client.go @@ -2,6 +2,7 @@ package telemetry import ( "bytes" + "context" "io/ioutil" "net/http" "net/url" @@ -78,7 +79,7 @@ func getLogEndpointURL(licenseKey string, logEndpointOverride string) string { return LogEndpointUS } -func (c *Client) SendTelemetry(invokedFunctionARN string, telemetry [][]byte) (error, int) { +func (c *Client) SendTelemetry(ctx context.Context, invokedFunctionARN string, telemetry [][]byte) (error, int) { start := time.Now() logEvents := make([]LogsEvent, 0, len(telemetry)) for _, payload := range telemetry { @@ -92,7 +93,7 @@ func (c *Client) SendTelemetry(invokedFunctionARN string, telemetry [][]byte) (e } var builder requestBuilder = func(buffer *bytes.Buffer) (*http.Request, error) { - return BuildVortexRequest(c.telemetryEndpoint, buffer, util.Name, c.licenseKey) + return BuildVortexRequest(ctx, c.telemetryEndpoint, buffer, util.Name, c.licenseKey) } transmitStart := time.Now() @@ -179,7 +180,7 @@ func (c *Client) sendPayloads(compressedPayloads []*bytes.Buffer, builder reques return successCount, sentBytes, nil } -func (c *Client) SendFunctionLogs(lines []logserver.LogLine) error { +func (c *Client) SendFunctionLogs(ctx context.Context, lines []logserver.LogLine) error { start := time.Now() common := map[string]interface{}{ @@ -204,7 +205,7 @@ func (c *Client) SendFunctionLogs(lines []logserver.LogLine) error { compressedPayloads := []*bytes.Buffer{compressedPayload} var builder requestBuilder = func(buffer *bytes.Buffer) (*http.Request, error) { - req, err := BuildVortexRequest(c.logEndpoint, buffer, util.Name, c.licenseKey) + req, err := BuildVortexRequest(ctx, c.logEndpoint, buffer, util.Name, c.licenseKey) if err != nil { return nil, err } diff --git a/telemetry/client_test.go b/telemetry/client_test.go index 91ff416..6611361 100644 --- a/telemetry/client_test.go +++ b/telemetry/client_test.go @@ -1,6 +1,7 @@ package telemetry import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -43,8 +44,9 @@ func TestClientSend(t *testing.T) { client := NewWithHTTPClient(srv.Client(), "", "a mock license key", srv.URL, srv.URL) + ctx := context.Background() bytes := []byte("foobar") - err, successCount := client.SendTelemetry("arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) + err, successCount := client.SendTelemetry(ctx, "arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) assert.NoError(t, err) assert.Equal(t, 1, successCount) @@ -93,8 +95,9 @@ func TestClientSendRetry(t *testing.T) { httpClient.Timeout = 200 * time.Millisecond client := NewWithHTTPClient(httpClient, "", "a mock license key", srv.URL, srv.URL) + ctx := context.Background() bytes := []byte("foobar") - err, successCount := client.SendTelemetry("arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) + err, successCount := client.SendTelemetry(ctx, "arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) assert.NoError(t, err) assert.Equal(t, 1, successCount) @@ -114,14 +117,30 @@ func TestClientSendOutOfRetries(t *testing.T) { httpClient.Timeout = 200 * time.Millisecond client := NewWithHTTPClient(httpClient, "", "a mock license key", srv.URL, srv.URL) + ctx := context.Background() bytes := []byte("foobar") - err, successCount := client.SendTelemetry("arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) + err, successCount := client.SendTelemetry(ctx, "arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) assert.NoError(t, err) assert.Equal(t, 0, successCount) assert.Equal(t, int32(retries), atomic.LoadInt32(&count)) } +func TestClientUnreachableEndpoint(t *testing.T) { + httpClient := &http.Client{ + Timeout: time.Millisecond * 1, + } + + client := NewWithHTTPClient(httpClient, "", "a mock license key", "http://10.123.123.123:12345", "http://10.123.123.123:12345") + + ctx := context.Background() + bytes := []byte("foobar") + err, successCount := client.SendTelemetry(ctx, "arn:aws:lambda:us-east-1:1234:function:newrelic-example-go", [][]byte{bytes}) + + assert.Nil(t, err) + assert.Equal(t, 0, successCount) +} + func TestGetInfraEndpointURL(t *testing.T) { assert.Equal(t, "barbaz", getInfraEndpointURL("foobar", "barbaz")) assert.Equal(t, InfraEndpointUS, getInfraEndpointURL("us license key", "")) diff --git a/telemetry/request.go b/telemetry/request.go index 8a39272..40f62c0 100644 --- a/telemetry/request.go +++ b/telemetry/request.go @@ -2,6 +2,7 @@ package telemetry import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -133,8 +134,8 @@ func CompressedPayloadsForLogEvents(logsEvents []LogsEvent, functionName string, } // BuildVortexRequest builds a Vortex HTTP request -func BuildVortexRequest(url string, compressed *bytes.Buffer, userAgent string, licenseKey string) (*http.Request, error) { - req, err := http.NewRequest("POST", url, compressed) +func BuildVortexRequest(ctx context.Context, url string, compressed *bytes.Buffer, userAgent string, licenseKey string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, "POST", url, compressed) if err != nil { return nil, fmt.Errorf("error creating request: %v", err) } diff --git a/util/logger.go b/util/logger.go index e8f50ad..e60a616 100644 --- a/util/logger.go +++ b/util/logger.go @@ -63,3 +63,7 @@ func Logln(v ...interface{}) { func Fatal(v ...interface{}) { log.Fatal(v...) } + +func Panic(v ...interface{}) { + log.Panic(v...) +}