From ddc819000fa55599f9dad20e150902cfd46322f3 Mon Sep 17 00:00:00 2001 From: darkweak Date: Sat, 30 Dec 2023 01:29:54 +0100 Subject: [PATCH] fix(chore): POST request rewrite body --- context/cache.go | 4 +++ context/graphql.go | 26 +++++++++++++++ context/key.go | 4 +++ context/method.go | 4 +++ context/mode.go | 6 +++- context/now.go | 4 +++ context/timeout.go | 6 +++- context/types.go | 5 +-- context/types_test.go | 2 +- pkg/middleware/middleware.go | 65 ++++++++++++++++++------------------ pkg/middleware/writer.go | 1 - 11 files changed, 89 insertions(+), 38 deletions(-) diff --git a/context/cache.go b/context/cache.go index 4f64710e5..a2ff6ff33 100644 --- a/context/cache.go +++ b/context/cache.go @@ -19,6 +19,10 @@ type cacheContext struct { cacheName string } +func (*cacheContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *cacheContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { cc.cacheName = defaultCacheName if c.GetDefaultCache().GetCacheName() != "" { diff --git a/context/graphql.go b/context/graphql.go index b04cc4475..625c6746d 100644 --- a/context/graphql.go +++ b/context/graphql.go @@ -21,6 +21,32 @@ type graphQLContext struct { custom bool } +func (g *graphQLContext) SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request { + ctx := req.Context() + ctx = context.WithValue(ctx, GraphQL, g.custom) + ctx = context.WithValue(ctx, HashBody, "") + ctx = context.WithValue(ctx, IsMutationRequest, false) + + if g.custom && req.Body != nil { + b := bytes.NewBuffer([]byte{}) + _, _ = io.Copy(b, req.Body) + req.Body = io.NopCloser(b) + baseRq.Body = io.NopCloser(b) + + if b.Len() > 0 { + if isMutation(b.Bytes()) { + ctx = context.WithValue(ctx, IsMutationRequest, true) + } else { + h := sha256.New() + h.Write(b.Bytes()) + ctx = context.WithValue(ctx, HashBody, fmt.Sprintf("-%x", h.Sum(nil))) + } + } + } + + return req.WithContext(ctx) +} + func (g *graphQLContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { g.custom = true diff --git a/context/key.go b/context/key.go index 4fb731f56..40e2c821d 100644 --- a/context/key.go +++ b/context/key.go @@ -24,6 +24,10 @@ type keyContext struct { overrides []map[*regexp.Regexp]keyContext } +func (*keyContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { k := c.GetDefaultCache().GetKey() g.disable_body = k.DisableBody diff --git a/context/method.go b/context/method.go index 95d20cfac..74becc87b 100644 --- a/context/method.go +++ b/context/method.go @@ -16,6 +16,10 @@ type methodContext struct { custom bool } +func (*methodContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (m *methodContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { m.allowedVerbs = defaultVerbs if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { diff --git a/context/mode.go b/context/mode.go index 0a1f903fe..dbe682d1c 100644 --- a/context/mode.go +++ b/context/mode.go @@ -13,6 +13,10 @@ type ModeContext struct { Strict, Bypass_request, Bypass_response bool } +func (*ModeContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (mc *ModeContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { mode := c.GetDefaultCache().GetMode() mc.Bypass_request = mode == "bypass" || mode == "bypass_request" @@ -25,4 +29,4 @@ func (mc *ModeContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(req.Context(), Mode, mc)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*ModeContext)(nil) diff --git a/context/now.go b/context/now.go index 898cc18fe..d0d4e0f3b 100644 --- a/context/now.go +++ b/context/now.go @@ -12,6 +12,10 @@ const Now ctxKey = "souin_ctx.NOW" type nowContext struct{} +func (*nowContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *nowContext) SetupContext(_ configurationtypes.AbstractConfigurationInterface) {} func (cc *nowContext) SetContext(req *http.Request) *http.Request { diff --git a/context/timeout.go b/context/timeout.go index e6468aef0..e2260bdf0 100644 --- a/context/timeout.go +++ b/context/timeout.go @@ -22,6 +22,10 @@ type timeoutContext struct { timeoutCache, timeoutBackend time.Duration } +func (*timeoutContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (t *timeoutContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { t.timeoutBackend = defaultTimeoutBackend t.timeoutCache = defaultTimeoutCache @@ -40,4 +44,4 @@ func (t *timeoutContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(context.WithValue(ctx, TimeoutCancel, cancel), TimeoutCache, t.timeoutCache)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*timeoutContext)(nil) diff --git a/context/types.go b/context/types.go index 38bf5ed19..34e56363f 100644 --- a/context/types.go +++ b/context/types.go @@ -12,6 +12,7 @@ type ( ctx interface { SetupContext(c configurationtypes.AbstractConfigurationInterface) SetContext(req *http.Request) *http.Request + SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request } Context struct { @@ -53,6 +54,6 @@ func (c *Context) SetBaseContext(req *http.Request) *http.Request { return c.Mode.SetContext(c.Timeout.SetContext(c.Method.SetContext(c.CacheName.SetContext(c.Now.SetContext(req))))) } -func (c *Context) SetContext(req *http.Request) *http.Request { - return c.Key.SetContext(c.GraphQL.SetContext(req)) +func (c *Context) SetContext(req *http.Request, baseRq *http.Request) *http.Request { + return c.Key.SetContext(c.GraphQL.SetContextWithBaseRequest(req, baseRq)) } diff --git a/context/types_test.go b/context/types_test.go index 120b8c548..9333bec8c 100644 --- a/context/types_test.go +++ b/context/types_test.go @@ -37,7 +37,7 @@ func Test_Context_SetContext(t *testing.T) { co.Init(&c) req := httptest.NewRequest(http.MethodGet, "http://domain.com", nil) - req = co.SetContext(req) + req = co.SetContext(req, req) if req.Context().Value(Key) != "GET-http-domain.com-" { t.Errorf("The Key context must be equal to GET-http-domain.com-, %s given.", req.Context().Value(Key)) } diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index a99b88b37..3af89621e 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -302,6 +302,7 @@ func (s *SouinBaseHandler) Upstream( sfValue, err, _ := s.singleflightPool.Do(cachedKey, func() (interface{}, error) { shared = false if e := next(customWriter, rq); e != nil { + s.Configuration.GetLogger().Sugar().Warnf("%#v", e) customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) return nil, e } @@ -427,61 +428,61 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return nil } - rq = s.context.SetBaseContext(rq) - cacheName := rq.Context().Value(context.CacheName).(string) + req := s.context.SetBaseContext(rq) + cacheName := req.Context().Value(context.CacheName).(string) if rq.Header.Get("Upgrade") == "websocket" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI") - return next(rw, rq) + return next(rw, req) } - if !rq.Context().Value(context.SupportedMethod).(bool) { + if !req.Context().Value(context.SupportedMethod).(bool) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=UNSUPPORTED-METHOD") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - requestCc, coErr := cacheobject.ParseRequestCacheControl(rq.Header.Get("Cache-Control")) + requestCc, coErr := cacheobject.ParseRequestCacheControl(req.Header.Get("Cache-Control")) - modeContext := rq.Context().Value(context.Mode).(*context.ModeContext) + modeContext := req.Context().Value(context.Mode).(*context.ModeContext) if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - rq = s.context.SetContext(rq) - if rq.Context().Value(context.IsMutationRequest).(bool) { + req = s.context.SetContext(req, rq) + if req.Context().Value(context.IsMutationRequest).(bool) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=IS-MUTATION-REQUEST") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - cachedKey := rq.Context().Value(context.Key).(string) + cachedKey := req.Context().Value(context.Key).(string) bufPool := s.bufPool.Get().(*bytes.Buffer) bufPool.Reset() defer s.bufPool.Put(bufPool) - customWriter := NewCustomWriter(rq, rw, bufPool) + customWriter := NewCustomWriter(req, rw, bufPool) go func(req *http.Request, crw *CustomWriter) { <-req.Context().Done() crw.mutex.Lock() crw.headersSent = true crw.mutex.Unlock() - }(rq, customWriter) + }(req, customWriter) s.Configuration.GetLogger().Sugar().Debugf("Request cache-control %+v", requestCc) if modeContext.Bypass_request || !requestCc.NoCache { - validator := rfc.ParseRequest(rq) + validator := rfc.ParseRequest(req) var response *http.Response for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(cachedKey, rq, validator) + response = currentStorer.Prefix(cachedKey, req, validator) if response != nil { s.Configuration.GetLogger().Sugar().Debugf("Found response in the %s storage", currentStorer.Name()) break @@ -508,14 +509,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NeedRevalidation { - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey) _, _ = customWriter.Send() return err } if resCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")); resCc.NoCachePresent { prometheus.Increment(prometheus.NoCachedResponseCounter) - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey) _, _ = customWriter.Send() return err @@ -524,7 +525,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n if !modeContext.Strict || rfc.ValidateMaxAgeCachedResponse(requestCc, response) != nil { customWriter.Headers = response.Header customWriter.statusCode = response.StatusCode - s.Configuration.GetLogger().Sugar().Debugf("Serve from cache %+v", rq) + s.Configuration.GetLogger().Sugar().Debugf("Serve from cache %+v", req) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() prometheus.Increment(prometheus.CachedResponseCounter) @@ -533,7 +534,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } } else if response == nil && !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(storage.StalePrefix+cachedKey, rq, validator) + response = currentStorer.Prefix(storage.StalePrefix+cachedKey, req, validator) if response != nil { break } @@ -549,10 +550,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n rfc.HitStaleCache(&response.Header) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() - customWriter = NewCustomWriter(rq, rw, bufPool) + customWriter = NewCustomWriter(req, rw, bufPool) go func(v *rfc.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) { _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk) - }(validator, customWriter, rq, next, requestCc, cachedKey) + }(validator, customWriter, req, next, requestCc, cachedKey) buf := s.bufPool.Get().(*bytes.Buffer) buf.Reset() defer s.bufPool.Put(buf) @@ -561,8 +562,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { - rq.Header["If-None-Match"] = append(rq.Header["If-None-Match"], validator.ResponseETag) - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey) if err != nil { if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 { code := fmt.Sprintf("; fwd-status=%d", customWriter.statusCode) @@ -623,13 +624,13 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } errorCacheCh := make(chan error) - go func() { - errorCacheCh <- s.Upstream(customWriter, rq, next, requestCc, cachedKey) - }() + go func(vr *http.Request) { + errorCacheCh <- s.Upstream(customWriter, vr, next, requestCc, cachedKey) + }(req) select { - case <-rq.Context().Done(): - switch rq.Context().Err() { + case <-req.Context().Done(): + switch req.Context().Err() { case baseCtx.DeadlineExceeded: customWriter.WriteHeader(http.StatusGatewayTimeout) rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=DEADLINE-EXCEEDED") diff --git a/pkg/middleware/writer.go b/pkg/middleware/writer.go index 916dfef07..e46c2c868 100644 --- a/pkg/middleware/writer.go +++ b/pkg/middleware/writer.go @@ -84,7 +84,6 @@ func (r *CustomWriter) Send() (int, error) { r.Header().Del(rfc.StoredTTLHeader) if !r.headersSent { - // r.Rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(b))) r.Rw.WriteHeader(r.statusCode) r.headersSent = true