diff --git a/colly.go b/colly.go index fdca9451..4a957d6c 100644 --- a/colly.go +++ b/colly.go @@ -671,6 +671,10 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct ID: atomic.AddUint32(&c.requestCount, 1), } + if req.Header.Get("Accept") == "" { + req.Header.Set("Accept", "*/*") + } + c.handleOnRequest(request) if request.abort { @@ -681,10 +685,6 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct req.Header.Add("Content-Type", "application/x-www-form-urlencoded") } - if req.Header.Get("Accept") == "" { - req.Header.Set("Accept", "*/*") - } - var hTrace *HTTPTrace if c.TraceHTTP { hTrace = &HTTPTrace{} diff --git a/colly_test.go b/colly_test.go index 4358b63e..e330fc2e 100644 --- a/colly_test.go +++ b/colly_test.go @@ -147,6 +147,11 @@ func newUnstartedTestServer() *httptest.Server { w.Write([]byte(r.Host)) }) + mux.HandleFunc("/accept_header", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(r.Header.Get("Accept"))) + }) + mux.HandleFunc("/custom_header", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.Write([]byte(r.Header.Get("Test"))) @@ -424,6 +429,39 @@ var newCollectorTests = map[string]func(*testing.T){ }, } +func TestNoAcceptHeader(t *testing.T) { + ts := newTestServer() + defer ts.Close() + + var receivedHeader string + // checks if Accept is enabled by default + func() { + c := NewCollector() + c.OnResponse(func(resp *Response) { + receivedHeader = string(resp.Body) + }) + c.Visit(ts.URL + "/accept_header") + if receivedHeader != "*/*" { + t.Errorf("default Accept header isn't */*. got: %v", receivedHeader) + } + }() + + // checks if Accept can be disabled + func() { + c := NewCollector() + c.OnRequest(func(r *Request) { + r.Headers.Del("Accept") + }) + c.OnResponse(func(resp *Response) { + receivedHeader = string(resp.Body) + }) + c.Visit(ts.URL + "/accept_header") + if receivedHeader != "" { + t.Errorf("failed to pass request with no Accept header. got: %v", receivedHeader) + } + }() +} + func TestNewCollector(t *testing.T) { t.Run("Functional Options", func(t *testing.T) { for name, test := range newCollectorTests {