Skip to content

Commit

Permalink
Retry proxy on connection refused error (#368)
Browse files Browse the repository at this point in the history
  • Loading branch information
benbjohnson authored Jul 26, 2023
1 parent 1db7517 commit e007fd1
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion http/proxy_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package http

import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"regexp"
"syscall"
"time"

"github.com/superfly/litefs"
Expand Down Expand Up @@ -61,6 +63,8 @@ type ProxyServer struct {

// Time before cookie expires on client.
CookieExpiry time.Duration

HTTPTransport *http.Transport
}

// NewProxyServer returns a new instance of ProxyServer.
Expand All @@ -79,6 +83,19 @@ func NewProxyServer(store *litefs.Store) *ProxyServer {
Handler: http.HandlerFunc(s.serveHTTP),
}

s.HTTPTransport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContextWithRetry(&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}

return s
}

Expand Down Expand Up @@ -238,7 +255,7 @@ func (s *ProxyServer) proxyToTarget(w http.ResponseWriter, r *http.Request, pass
r.URL.Scheme = "http"
r.URL.Host = s.Target

resp, err := http.DefaultTransport.RoundTrip(r)
resp, err := s.HTTPTransport.RoundTrip(r)
if err != nil {
http.Error(w, "Proxy error: "+err.Error(), http.StatusBadGateway)
return
Expand Down Expand Up @@ -295,3 +312,26 @@ func (s *ProxyServer) logf(format string, v ...any) {
log.Printf(format, v...)
}
}

// dialContextWithRetry returns a function that will retry
func dialContextWithRetry(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
return func(ctx context.Context, network, address string) (net.Conn, error) {
timeout := time.NewTimer(dialer.Timeout)
defer timeout.Stop()

for {
conn, err := dialer.DialContext(ctx, network, address)
if !errors.Is(err, syscall.ECONNREFUSED) {
return conn, err
}

select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-timeout.C:
return nil, err
case <-time.After(100 * time.Millisecond):
}
}
}
}

0 comments on commit e007fd1

Please sign in to comment.