diff --git a/instrumentation_http.go b/instrumentation_http.go index 583fa60cc..bc6959e4d 100644 --- a/instrumentation_http.go +++ b/instrumentation_http.go @@ -63,15 +63,34 @@ func TracingHandlerFunc(sensor *Sensor, pathTemplate string, handler http.Handle span := tracer.StartSpan("g.http", opts...) defer span.Finish() + var collectableHTTPHeaders []string if t, ok := tracer.(Tracer); ok { - params := collectHTTPParams(req, t.Options().Secrets) + opts := t.Options() + collectableHTTPHeaders = opts.CollectableHTTPHeaders + + params := collectHTTPParams(req, opts.Secrets) if len(params) > 0 { span.SetTag("http.params", params.Encode()) } } + collectedHeaders := make(map[string]string) + // make sure collected headers are sent in case of panic/error defer func() { - // Be sure to capture any kind of panic / error + if len(collectedHeaders) > 0 { + span.SetTag("http.header", collectedHeaders) + } + }() + + // collect request headers + for _, h := range collectableHTTPHeaders { + if v := req.Header.Get(h); v != "" { + collectedHeaders[h] = v + } + } + + defer func() { + // Be sure to capture any kind of panic/error if err := recover(); err != nil { if e, ok := err.(error); ok { span.SetTag("http.error", e.Error()) @@ -94,6 +113,13 @@ func TracingHandlerFunc(sensor *Sensor, pathTemplate string, handler http.Handle ctx = ContextWithSpan(ctx, span) w3ctrace.TracingHandlerFunc(handler)(wrapped, req.WithContext(ctx)) + // collect response headers + for _, h := range collectableHTTPHeaders { + if v := wrapped.Header().Get(h); v != "" { + collectedHeaders[h] = v + } + } + if wrapped.Status > 0 { if wrapped.Status > http.StatusInternalServerError { span.SetTag("http.error", http.StatusText(wrapped.Status)) @@ -136,13 +162,32 @@ func RoundTripper(sensor *Sensor, original http.RoundTripper) http.RoundTripper req = cloneRequest(ContextWithSpan(ctx, span), req) sensor.Tracer().Inject(span.Context(), ot.HTTPHeaders, ot.HTTPHeadersCarrier(req.Header)) + var collectableHTTPHeaders []string if t, ok := sensor.Tracer().(Tracer); ok { - params := collectHTTPParams(req, t.Options().Secrets) + opts := t.Options() + collectableHTTPHeaders = opts.CollectableHTTPHeaders + + params := collectHTTPParams(req, opts.Secrets) if len(params) > 0 { span.SetTag("http.params", params.Encode()) } } + collectedHeaders := make(map[string]string) + // make sure collected headers are sent in case of panic/error + defer func() { + if len(collectedHeaders) > 0 { + span.SetTag("http.header", collectedHeaders) + } + }() + + // collect request headers + for _, h := range collectableHTTPHeaders { + if v := req.Header.Get(h); v != "" { + collectedHeaders[h] = v + } + } + resp, err := original.RoundTrip(req) if err != nil { span.SetTag("http.error", err.Error()) @@ -150,6 +195,13 @@ func RoundTripper(sensor *Sensor, original http.RoundTripper) http.RoundTripper return resp, err } + // collect response headers + for _, h := range collectableHTTPHeaders { + if v := resp.Header.Get(h); v != "" { + collectedHeaders[h] = v + } + } + span.SetTag(string(ext.HTTPStatusCode), resp.StatusCode) return resp, err diff --git a/instrumentation_http_test.go b/instrumentation_http_test.go index 789bbc848..19c8a0f65 100644 --- a/instrumentation_http_test.go +++ b/instrumentation_http_test.go @@ -21,10 +21,14 @@ func TestTracingHandlerFunc_Write(t *testing.T) { }, recorder)) h := instana.TracingHandlerFunc(s, "/{action}", func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("X-Response", "true") + w.Header().Set("X-Custom-Header-2", "response") fmt.Fprintln(w, "Ok") }) req := httptest.NewRequest(http.MethodGet, "/test?q=term", nil) + req.Header.Set("Authorization", "Basic blah") + req.Header.Set("X-Custom-Header-1", "request") rec := httptest.NewRecorder() h.ServeHTTP(rec, req) @@ -48,11 +52,15 @@ func TestTracingHandlerFunc_Write(t *testing.T) { data := span.Data.(instana.HTTPSpanData) assert.Equal(t, instana.HTTPSpanTags{ - Host: "example.com", - Status: http.StatusOK, - Method: "GET", - Path: "/test", - Params: "q=term", + Host: "example.com", + Status: http.StatusOK, + Method: "GET", + Path: "/test", + Params: "q=term", + Headers: map[string]string{ + "x-custom-header-1": "request", + "x-custom-header-2": "response", + }, PathTemplate: "/{action}", }, data.Tags) @@ -337,11 +345,17 @@ func TestRoundTripper(t *testing.T) { return &http.Response{ Status: http.StatusText(http.StatusNotImplemented), StatusCode: http.StatusNotImplemented, + Header: http.Header{ + "X-Response": []string{"true"}, + "X-Custom-Header-2": []string{"response"}, + }, }, nil })) ctx := instana.ContextWithSpan(context.Background(), parentSpan) req := httptest.NewRequest("GET", "http://user:password@example.com/hello?q=term&sensitive_key=s3cr3t&myPassword=qwerty&SECRET_VALUE=1", nil) + req.Header.Set("X-Custom-Header-1", "request") + req.Header.Set("Authorization", "Basic blah") _, err := rt.RoundTrip(req.WithContext(ctx)) require.NoError(t, err) @@ -369,6 +383,10 @@ func TestRoundTripper(t *testing.T) { Status: http.StatusNotImplemented, URL: "http://example.com/hello", Params: "SECRET_VALUE=%3Credacted%3E&myPassword=%3Credacted%3E&q=term&sensitive_key=%3Credacted%3E", + Headers: map[string]string{ + "x-custom-header-1": "request", + "x-custom-header-2": "response", + }, }, data.Tags) } diff --git a/sensor_test.go b/sensor_test.go index f7d423f15..3968971d2 100644 --- a/sensor_test.go +++ b/sensor_test.go @@ -12,6 +12,9 @@ const TestServiceName = "test_service" func TestMain(m *testing.M) { instana.InitSensor(&instana.Options{ Service: TestServiceName, + Tracer: instana.TracerOptions{ + CollectableHTTPHeaders: []string{"x-custom-header-1", "x-custom-header-2"}, + }, }) os.Exit(m.Run())