diff --git a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go index 763b4007ab..d93ccaaf48 100644 --- a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go +++ b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go @@ -37,7 +37,7 @@ type DNSResolveCacheController struct { mx sync.Mutex handler *dns.Handler nodeHandler *dns.NodeHandler - cache *dns.Cache + rootHandler dnssrv.Handler runners map[runnerConfig]pair.Pair[func(), <-chan struct{}] reconcile chan struct{} originalCtx context.Context //nolint:containedctx @@ -130,7 +130,7 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run runnerCfg := runnerConfig{net: netwk, addr: addr} if _, ok := ctrl.runners[runnerCfg]; !ok { - runner, rErr := newDNSRunner(runnerCfg, ctrl.cache, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid()) + runner, rErr := newDNSRunner(runnerCfg, ctrl.rootHandler, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid()) if rErr != nil { return fmt.Errorf("error creating dns runner: %w", rErr) } @@ -200,7 +200,7 @@ func (ctrl *DNSResolveCacheController) init(ctx context.Context) { ctrl.originalCtx = ctx ctrl.handler = dns.NewHandler(ctrl.Logger) ctrl.nodeHandler = dns.NewNodeHandler(ctrl.handler, &stateMapper{state: ctrl.State}, ctrl.Logger) - ctrl.cache = dns.NewCache(ctrl.nodeHandler, ctrl.Logger) + ctrl.rootHandler = dns.NewCache(ctrl.nodeHandler, ctrl.Logger) ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{} ctrl.reconcile = make(chan struct{}, 1) @@ -256,7 +256,7 @@ type runnerConfig struct { addr netip.AddrPort } -func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) { +func newDNSRunner(cfg runnerConfig, rootHandler dnssrv.Handler, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) { if cfg.addr.Addr().Is6() { cfg.net += "6" } @@ -279,7 +279,7 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwar serverOpts = dns.ServerOptions{ PacketConn: packetConn, - Handler: cache, + Handler: rootHandler, Logger: logger, } @@ -291,7 +291,7 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger, forwar serverOpts = dns.ServerOptions{ Listener: listener, - Handler: cache, + Handler: rootHandler, ReadTimeout: 3 * time.Second, WriteTimeout: 5 * time.Second, IdleTimeout: func() time.Duration { return 10 * time.Second }, diff --git a/internal/integration/api/common.go b/internal/integration/api/common.go index bc5f8f9654..715bc3fe22 100644 --- a/internal/integration/api/common.go +++ b/internal/integration/api/common.go @@ -166,6 +166,14 @@ func (suite *CommonSuite) TestDNSResolver() { suite.Require().Equal("", stdout) suite.Require().Contains(stderr, "'index.html' saved") + + stdout, stderr, err = suite.ExecuteCommandInPod(suite.ctx, namespace, pod, "nslookup really-long-record.dev.siderolabs.io") + suite.Require().NoError(err) + + suite.Require().Contains(stdout, "really-long-record.dev.siderolabs.io") + suite.Require().NotContains(stdout, "Can't find") + suite.Require().NotContains(stdout, "No answer") + suite.Require().Equal(stderr, "") } func init() { diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 89c19448bb..6f48a585ca 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -13,6 +13,7 @@ import ( "fmt" "os" "path/filepath" + "slices" "testing" "github.com/stretchr/testify/suite" @@ -180,8 +181,5 @@ func init() { flag.StringVar(&provision_test.DefaultSettings.CustomCNIURL, "talos.provision.custom-cni-url", provision_test.DefaultSettings.CustomCNIURL, "custom CNI URL for the cluster (provision tests only)") flag.StringVar(&provision_test.DefaultSettings.CNIBundleURL, "talos.provision.cni-bundle-url", provision_test.DefaultSettings.CNIBundleURL, "URL to download CNI bundle from") - allSuites = append(allSuites, api.GetAllSuites()...) - allSuites = append(allSuites, cli.GetAllSuites()...) - allSuites = append(allSuites, k8s.GetAllSuites()...) - allSuites = append(allSuites, provision_test.GetAllSuites()...) + allSuites = slices.Concat(api.GetAllSuites(), cli.GetAllSuites(), k8s.GetAllSuites(), provision_test.GetAllSuites()) } diff --git a/internal/pkg/dns/dns.go b/internal/pkg/dns/dns.go index 5e8e5bf3a1..5f3703575a 100644 --- a/internal/pkg/dns/dns.go +++ b/internal/pkg/dns/dns.go @@ -50,7 +50,7 @@ func NewCache(next plugin.Handler, l *zap.Logger) *Cache { // ServeDNS implements [dns.Handler]. func (c *Cache) ServeDNS(wr dns.ResponseWriter, msg *dns.Msg) { - _, err := c.cache.ServeDNS(context.Background(), wr, msg) + _, err := c.cache.ServeDNS(context.Background(), request.NewScrubWriter(msg, wr), msg) if err != nil { // we should probably call newProxy.Healthcheck() if there are too many errors c.logger.Warn("error serving dns request", zap.Error(err)) @@ -77,6 +77,8 @@ func (h *Handler) Name() string { } // ServeDNS implements plugin.Handler. +// +//nolint:gocyclo func (h *Handler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg *dns.Msg) (int, error) { h.mx.RLock() defer h.mx.RUnlock() @@ -107,9 +109,21 @@ func (h *Handler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg *dns ) for _, ups := range upstreams { - resp, err = ups.Connect(ctx, req, proxy.Options{}) - if errors.Is(err, proxy.ErrCachedClosed) { // Remote side closed conn, can only happen with TCP. - continue + opts := proxy.Options{} + + for { + resp, err = ups.Connect(ctx, req, opts) + + switch { + case errors.Is(err, proxy.ErrCachedClosed): // Remote side closed conn, can only happen with TCP. + continue + case resp != nil && resp.Truncated && !opts.ForceTCP: // Retry with TCP if truncated + opts.ForceTCP = true + + continue + } + + break } if err == nil { @@ -279,6 +293,7 @@ func NewServer(opts ServerOptions) *Server { Listener: opts.Listener, PacketConn: opts.PacketConn, Handler: opts.Handler, + UDPSize: dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small. ReadTimeout: opts.ReadTimeout, WriteTimeout: opts.WriteTimeout, IdleTimeout: opts.IdleTimeout,