diff --git a/internal/server/server.go b/internal/server/server.go index a25fed2..38b6a52 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -152,44 +152,55 @@ func (s *Server) doRequests(incomingRequestHeaders http.Header) { var wg sync.WaitGroup for _, request := range s.requests { - wg.Add(1) - go func(srv *Server, request Request) { - defer wg.Done() - request.Do(incomingRequestHeaders, srv) - }(s, request) + var i uint + for i = 0; i < request.Count; i++ { + wg.Add(1) + go func(srv *Server, request Request) { + defer wg.Done() + request.Do(incomingRequestHeaders, srv) + }(s, request) + } } wg.Wait() } func (request Request) Do(incomingRequestHeaders http.Header, s *Server) { - var i uint + s.logger.WithFields(log.Fields{ + "url": request.URL, + }).Info("outgoing request") + httpClient := &http.Client{} - for i = 0; i < request.Count; i++ { + httpReq, err := http.NewRequest("GET", request.URL, nil) + if err != nil { s.logger.WithFields(log.Fields{ "url": request.URL, - }).Info("outgoing request") - httpReq, _ := http.NewRequest("GET", request.URL, nil) - propagateHeaders(incomingRequestHeaders, httpReq) - response, err := httpClient.Do(httpReq) - if err != nil { - s.logger.WithFields(log.Fields{ - "url": request.URL, - }).Error(err.Error()) - continue - } + }).Error(err.Error()) + return + } + + propagateHeaders(incomingRequestHeaders, httpReq) + + response, err := httpClient.Do(httpReq) + if err != nil { s.logger.WithFields(log.Fields{ - "url": request.URL, - "responseCode": response.StatusCode, - }).Info("response to outgoing request") + "url": request.URL, + }).Error(err.Error()) + return } + + defer response.Body.Close() + s.logger.WithFields(log.Fields{ + "url": request.URL, + "responseCode": response.StatusCode, + }).Info("response to outgoing request") } func propagateHeaders(incomingRequestHeaders http.Header, httpReq *http.Request) { for _, header := range headersToPropagate { - val := incomingRequestHeaders[header] - if len(val) != 0 { - httpReq.Header[header] = val + val := incomingRequestHeaders.Get(header) + if val != "" { + httpReq.Header.Set(header, val) } } }