Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correctly handle dns messages in our dns implementation #8768

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type DNSResolveCacheController struct {
mx sync.Mutex
handler *dns.Handler
nodeHandler *dns.NodeHandler
cache *dns.Cache
rootHandler dnssrv.Handler
runners map[runnerConfig]pair.Pair[func(), <-chan struct{}]
reconcile chan struct{}
originalCtx context.Context //nolint:containedctx
Expand Down Expand Up @@ -130,7 +130,7 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run
runnerCfg := runnerConfig{net: netwk, addr: addr}

if _, ok := ctrl.runners[runnerCfg]; !ok {
runner, rErr := newDNSRunner(runnerCfg, ctrl.cache, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid())
runner, rErr := newDNSRunner(runnerCfg, ctrl.rootHandler, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid())
if rErr != nil {
return fmt.Errorf("error creating dns runner: %w", rErr)
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (ctrl *DNSResolveCacheController) init(ctx context.Context) {
ctrl.originalCtx = ctx
ctrl.handler = dns.NewHandler(ctrl.Logger)
ctrl.nodeHandler = dns.NewNodeHandler(ctrl.handler, &stateMapper{state: ctrl.State}, ctrl.Logger)
ctrl.cache = dns.NewCache(ctrl.nodeHandler, ctrl.Logger)
ctrl.rootHandler = dns.NewCache(ctrl.nodeHandler, ctrl.Logger)
ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{}
ctrl.reconcile = make(chan struct{}, 1)

Expand Down Expand Up @@ -256,7 +256,7 @@ type runnerConfig struct {
addr netip.AddrPort
}

func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) {
func newDNSRunner(cfg runnerConfig, rootHandler dnssrv.Handler, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) {
if cfg.addr.Addr().Is6() {
cfg.net += "6"
}
Expand All @@ -279,7 +279,7 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwar

serverOpts = dns.ServerOptions{
PacketConn: packetConn,
Handler: cache,
Handler: rootHandler,
Logger: logger,
}

Expand All @@ -291,7 +291,7 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwar

serverOpts = dns.ServerOptions{
Listener: listener,
Handler: cache,
Handler: rootHandler,
ReadTimeout: 3 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: func() time.Duration { return 10 * time.Second },
Expand Down
8 changes: 8 additions & 0 deletions internal/integration/api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ func (suite *CommonSuite) TestDNSResolver() {

suite.Require().Equal("", stdout)
suite.Require().Contains(stderr, "'index.html' saved")

stdout, stderr, err = suite.ExecuteCommandInPod(suite.ctx, namespace, pod, "nslookup really-long-record.dev.siderolabs.io")
suite.Require().NoError(err)

suite.Require().Contains(stdout, "really-long-record.dev.siderolabs.io")
suite.Require().NotContains(stdout, "Can't find")
suite.Require().NotContains(stdout, "No answer")
suite.Require().Equal(stderr, "")
}

func init() {
Expand Down
6 changes: 2 additions & 4 deletions internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"os"
"path/filepath"
"slices"
"testing"

"github.com/stretchr/testify/suite"
Expand Down Expand Up @@ -180,8 +181,5 @@ func init() {
flag.StringVar(&provision_test.DefaultSettings.CustomCNIURL, "talos.provision.custom-cni-url", provision_test.DefaultSettings.CustomCNIURL, "custom CNI URL for the cluster (provision tests only)")
flag.StringVar(&provision_test.DefaultSettings.CNIBundleURL, "talos.provision.cni-bundle-url", provision_test.DefaultSettings.CNIBundleURL, "URL to download CNI bundle from")

allSuites = append(allSuites, api.GetAllSuites()...)
allSuites = append(allSuites, cli.GetAllSuites()...)
allSuites = append(allSuites, k8s.GetAllSuites()...)
allSuites = append(allSuites, provision_test.GetAllSuites()...)
allSuites = slices.Concat(api.GetAllSuites(), cli.GetAllSuites(), k8s.GetAllSuites(), provision_test.GetAllSuites())
}
23 changes: 19 additions & 4 deletions internal/pkg/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewCache(next plugin.Handler, l *zap.Logger) *Cache {

// ServeDNS implements [dns.Handler].
func (c *Cache) ServeDNS(wr dns.ResponseWriter, msg *dns.Msg) {
_, err := c.cache.ServeDNS(context.Background(), wr, msg)
_, err := c.cache.ServeDNS(context.Background(), request.NewScrubWriter(msg, wr), msg)
if err != nil {
// we should probably call newProxy.Healthcheck() if there are too many errors
c.logger.Warn("error serving dns request", zap.Error(err))
Expand All @@ -77,6 +77,8 @@ func (h *Handler) Name() string {
}

// ServeDNS implements plugin.Handler.
//
//nolint:gocyclo
func (h *Handler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg *dns.Msg) (int, error) {
h.mx.RLock()
defer h.mx.RUnlock()
Expand Down Expand Up @@ -107,9 +109,21 @@ func (h *Handler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg *dns
)

for _, ups := range upstreams {
resp, err = ups.Connect(ctx, req, proxy.Options{})
if errors.Is(err, proxy.ErrCachedClosed) { // Remote side closed conn, can only happen with TCP.
continue
opts := proxy.Options{}

for {
resp, err = ups.Connect(ctx, req, opts)

switch {
case errors.Is(err, proxy.ErrCachedClosed): // Remote side closed conn, can only happen with TCP.
continue
case resp != nil && resp.Truncated && !opts.ForceTCP: // Retry with TCP if truncated
opts.ForceTCP = true

continue
}

break
}

if err == nil {
Expand Down Expand Up @@ -279,6 +293,7 @@ func NewServer(opts ServerOptions) *Server {
Listener: opts.Listener,
PacketConn: opts.PacketConn,
Handler: opts.Handler,
UDPSize: dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small.
ReadTimeout: opts.ReadTimeout,
WriteTimeout: opts.WriteTimeout,
IdleTimeout: opts.IdleTimeout,
Expand Down