Skip to content

Commit

Permalink
Fix NTLM and Kerberos
Browse files Browse the repository at this point in the history
  • Loading branch information
juliens authored Feb 6, 2024
1 parent 8f9ad16 commit e11ff98
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 3 deletions.
11 changes: 11 additions & 0 deletions pkg/server/server_entrypoint_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/traefik/traefik/v2/pkg/safe"
"github.com/traefik/traefik/v2/pkg/server/router"
tcprouter "github.com/traefik/traefik/v2/pkg/server/router/tcp"
"github.com/traefik/traefik/v2/pkg/server/service"
"github.com/traefik/traefik/v2/pkg/tcp"
"github.com/traefik/traefik/v2/pkg/types"
"golang.org/x/net/http2"
Expand Down Expand Up @@ -613,6 +614,16 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
}
}

prevConnContext := serverHTTP.ConnContext
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
// This adds an empty struct in order to store a RoundTripper in the ConnContext in case of Kerberos or NTLM.
ctx = service.AddTransportOnContext(ctx)
if prevConnContext != nil {
return prevConnContext(ctx, c)
}
return ctx
}

// ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server.
// Also keeping behavior the same as
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/http/server.go;l=3262
Expand Down
67 changes: 65 additions & 2 deletions pkg/server/service/roundtripper.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package service

import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -149,10 +151,71 @@ func createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error

// Return directly HTTP/1.1 transport when HTTP/2 is disabled
if cfg.DisableHTTP2 {
return transport, nil
return &KerberosRoundTripper{
OriginalRoundTripper: transport,
new: func() http.RoundTripper {
return transport.Clone()
},
}, nil
}

return newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
rt, err := newSmartRoundTripper(transport, cfg.ForwardingTimeouts)
if err != nil {
return nil, err
}
return &KerberosRoundTripper{
OriginalRoundTripper: rt,
new: func() http.RoundTripper {
return rt.Clone()
},
}, nil
}

type KerberosRoundTripper struct {
new func() http.RoundTripper
OriginalRoundTripper http.RoundTripper
}

type stickyRoundTripper struct {
RoundTripper http.RoundTripper
}

type transportKeyType string

var transportKey transportKeyType = "transport"

func AddTransportOnContext(ctx context.Context) context.Context {
return context.WithValue(ctx, transportKey, &stickyRoundTripper{})
}

func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
value, ok := request.Context().Value(transportKey).(*stickyRoundTripper)
if !ok {
return k.OriginalRoundTripper.RoundTrip(request)
}

if value.RoundTripper != nil {
return value.RoundTripper.RoundTrip(request)
}

resp, err := k.OriginalRoundTripper.RoundTrip(request)

// If we found that we are authenticating with Kerberos (Negotiate) or NTLM.
// We put a dedicated roundTripper in the ConnContext.
// This will stick the next calls to the same connection with the backend.
if err == nil && containsNTLMorNegotiate(resp.Header.Values("WWW-Authenticate")) {
value.RoundTripper = k.new()
}
return resp, err
}

func containsNTLMorNegotiate(h []string) bool {
for _, s := range h {
if strings.HasPrefix(s, "NTLM") || strings.HasPrefix(s, "Negotiate") {
return true
}
}
return false
}

func createRootCACertPool(rootCAs []traefiktls.FileOrContent) *x509.CertPool {
Expand Down
78 changes: 78 additions & 0 deletions pkg/server/service/roundtripper_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"context"
"crypto/tls"
"crypto/x509"
"net"
Expand Down Expand Up @@ -293,3 +294,80 @@ func TestDisableHTTP2(t *testing.T) {
})
}
}

type roundTripperFn func(req *http.Request) (*http.Response, error)

func (r roundTripperFn) RoundTrip(request *http.Request) (*http.Response, error) {
return r(request)
}

func TestKerberosRoundTripper(t *testing.T) {
testCases := []struct {
desc string

originalRoundTripperHeaders map[string][]string

expectedStatusCode []int
expectedDedicatedCount int
expectedOriginalCount int
}{
{
desc: "without special header",
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusUnauthorized, http.StatusUnauthorized},
expectedOriginalCount: 3,
},
{
desc: "with Negotiate (Kerberos)",
originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"Negotiate"}},
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK},
expectedOriginalCount: 1,
expectedDedicatedCount: 2,
},
{
desc: "with NTLM",
originalRoundTripperHeaders: map[string][]string{"Www-Authenticate": {"NTLM"}},
expectedStatusCode: []int{http.StatusUnauthorized, http.StatusOK, http.StatusOK},
expectedOriginalCount: 1,
expectedDedicatedCount: 2,
},
}

for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

origCount := 0
dedicatedCount := 0
rt := KerberosRoundTripper{
new: func() http.RoundTripper {
return roundTripperFn(func(req *http.Request) (*http.Response, error) {
dedicatedCount++
return &http.Response{
StatusCode: http.StatusOK,
}, nil
})
},
OriginalRoundTripper: roundTripperFn(func(req *http.Request) (*http.Response, error) {
origCount++
return &http.Response{
StatusCode: http.StatusUnauthorized,
Header: test.originalRoundTripperHeaders,
}, nil
}),
}

ctx := AddTransportOnContext(context.Background())
for _, expected := range test.expectedStatusCode {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1", http.NoBody)
require.NoError(t, err)
resp, err := rt.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, expected, resp.StatusCode)
}

require.Equal(t, test.expectedOriginalCount, origCount)
require.Equal(t, test.expectedDedicatedCount, dedicatedCount)
})
}
}
8 changes: 7 additions & 1 deletion pkg/server/service/smart_roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"golang.org/x/net/http2"
)

func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (http.RoundTripper, error) {
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
transportHTTP1 := transport.Clone()

transportHTTP2, err := http2.ConfigureTransports(transport)
Expand Down Expand Up @@ -53,6 +53,12 @@ type smartRoundTripper struct {
http *http.Transport
}

func (m *smartRoundTripper) Clone() http.RoundTripper {
h := m.http.Clone()
h2 := m.http2.Clone()
return &smartRoundTripper{http: h, http2: h2}
}

func (m *smartRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// If we have a connection upgrade, we don't use HTTP/2
if httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
Expand Down

0 comments on commit e11ff98

Please sign in to comment.