diff --git a/main.go b/main.go index 27eed19..1de492e 100644 --- a/main.go +++ b/main.go @@ -191,7 +191,6 @@ func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, ba select { case <-ctx.Done(): // We're already done - util.Logln(ctx.Err()) return eventCounter, "" default: // Our call to next blocks. It is likely that the container is frozen immediately after we call NextEvent. @@ -201,7 +200,6 @@ func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, ba eventStart := time.Now() if err != nil { - util.Logln(err) err = invocationClient.ExitError(ctx, "NextEventError.Main", err) if err != nil { @@ -224,7 +222,6 @@ func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, ba util.Logf("We suspected a timeout for request %s but got telemetry anyway", lastRequestId) default: } - } invokedFunctionARN = event.InvokedFunctionARN @@ -260,16 +257,19 @@ func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, ba // Set the timeout timer for a smidge before the actual timeout; // we can recover from early. - timeoutWatchBegins := time.Millisecond * 100 - timeout := timeoutInstant.Sub(time.Now()) - timeoutWatchBegins + timeoutWatchBegins := 100 * time.Millisecond + hardTimeout := timeoutInstant.Sub(time.Now()) + softTimeout := hardTimeout - timeoutWatchBegins + + hardCtx, hardCancel := context.WithTimeout(ctx, hardTimeout) + defer hardCancel() - invCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() + softCtx, softCancel := context.WithTimeout(hardCtx, softTimeout) + defer softCancel() select { - case <-invCtx.Done(): + case <-softCtx.Done(): // We are about to timeout - util.Debugln("Timeout suspected: ", invCtx.Err()) probablyTimeout = true continue case telemetryBytes := <-telemetryChan: @@ -282,7 +282,7 @@ func mainLoop(ctx context.Context, invocationClient *client.InvocationClient, ba pollLogServer(logServer, batch) harvested := batch.Harvest(time.Now()) - shipHarvest(ctx, harvested, telemetryClient, invokedFunctionARN) + shipHarvest(hardCtx, harvested, telemetryClient, invokedFunctionARN) } lastEventStart = eventStart diff --git a/main_test.go b/main_test.go index 74b0900..7115f96 100644 --- a/main_test.go +++ b/main_test.go @@ -3,10 +3,12 @@ package main import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "os" "testing" + "time" "github.com/newrelic/newrelic-lambda-extension/lambda/extension/api" "github.com/newrelic/newrelic-lambda-extension/util" @@ -532,6 +534,117 @@ func TestMainTimeout(t *testing.T) { assert.Equal(t, 1, nextEventRequestCount) } +func TestMainTimeoutUnreachable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(200*time.Millisecond)) + defer cancel() + overrideContext(ctx) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer util.Close(r.Body) + + if r.URL.Path == "/2020-01-01/extension/register" { + w.Header().Add(api.ExtensionIdHeader, "test-ext-id") + w.WriteHeader(200) + res, err := json.Marshal(api.RegistrationResponse{ + FunctionName: "foobar", + FunctionVersion: "$latest", + Handler: "lambda.handler", + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/2020-01-01/extension/init/error" { + w.WriteHeader(200) + _, _ = w.Write([]byte("")) + } + + if r.URL.Path == "/2020-01-01/extension/exit/error" { + w.WriteHeader(200) + _, _ = w.Write(nil) + } + + if r.URL.Path == "/2020-08-15/logs" { + w.WriteHeader(200) + _, _ = w.Write(nil) + } + + if r.URL.Path == "/2020-01-01/extension/event/next" { + time.Sleep(25 * time.Millisecond) + + w.WriteHeader(200) + res, err := json.Marshal(api.InvocationEvent{ + EventType: api.Invoke, + DeadlineMs: 100, + RequestID: "12345", + InvokedFunctionARN: "arn:aws:lambda:us-east-1:12345:foobar", + ShutdownReason: "", + Tracing: nil, + }) + assert.Nil(t, err) + _, _ = w.Write(res) + } + + if r.URL.Path == "/aws/lambda/v1" { + time.Sleep(5 * time.Second) + + w.WriteHeader(200) + _, _ = w.Write(nil) + } + })) + defer srv.Close() + + url := srv.URL[7:] + + _ = os.Setenv(api.LambdaHostPortEnvVar, url) + defer os.Unsetenv(api.LambdaHostPortEnvVar) + + _ = os.Setenv("NEW_RELIC_LICENSE_KEY", "foobar") + defer os.Unsetenv("NEW_RELIC_LICENSE_KEY") + + _ = os.Setenv("NEW_RELIC_LOG_SERVER_HOST", "localhost") + defer os.Unsetenv("NEW_RELIC_LOG_SERVER_HOST") + + _ = os.Setenv("NEW_RELIC_EXTENSION_LOG_LEVEL", "DEBUG") + defer os.Unsetenv("NEW_RELIC_EXTENSION_LOG_LEVEL") + + _ = os.Setenv("NEW_RELIC_TELEMETRY_ENDPOINT", fmt.Sprintf("%s/aws/lambda/v1", srv.URL)) + defer os.Unsetenv("NEW_RELIC_TELEMETRY_ENDPOINT") + + _ = os.Remove("/tmp/newrelic-telemetry") + + go func() { + pipeOpened := false + + for { + select { + case <-ctx.Done(): + return + default: + if _, err := os.Stat("/tmp/newrelic-telemetry"); os.IsNotExist(err) { + if pipeOpened { + return + } else { + continue + } + } else { + pipeOpened = true + } + + pipe, err := os.OpenFile("/tmp/newrelic-telemetry", os.O_WRONLY, 0) + assert.Nil(t, err) + defer pipe.Close() + + pipe.WriteString("foobar\n") + pipe.Close() + time.Sleep(100 * time.Millisecond) + } + } + }() + + assert.NotPanics(t, main) +} + func overrideContext(ctx context.Context) { rootCtx = ctx } diff --git a/telemetry/client.go b/telemetry/client.go index 5070834..8d88147 100644 --- a/telemetry/client.go +++ b/telemetry/client.go @@ -177,6 +177,7 @@ func (c *Client) sendPayloads(compressedPayloads []*bytes.Buffer, builder reques successCount += 1 } } + return successCount, sentBytes, nil }