diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 1084cbefa..ae08de9d1 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -14,20 +14,42 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "golang.org/x/exp/slices" +) + +// Network is a network type for use in [Resolver]'s methods. +type Network = string + +const ( + // NetworkIP is a network type for both address families. + NetworkIP Network = "ip" + + // NetworkIP4 is a network type for IPv4 address family. + NetworkIP4 Network = "ip4" + + // NetworkIP6 is a network type for IPv6 address family. + NetworkIP6 Network = "ip6" + + // NetworkTCP is a network type for TCP connections. + NetworkTCP Network = "tcp" + + // NetworkUDP is a network type for UDP connections. + NetworkUDP Network = "udp" ) // DialHandler is a dial function for creating unencrypted network connections // to the upstream server. It establishes the connection to the server -// specified at initialization and ignores the addr. -type DialHandler func(ctx context.Context, network, addr string) (conn net.Conn, err error) +// specified at initialization and ignores the addr. network must be one of +// [NetworkTCP] or [NetworkUDP]. +type DialHandler func(ctx context.Context, network Network, addr string) (conn net.Conn, err error) // ResolveDialContext returns a DialHandler that uses addresses resolved from u // using resolver. u must not be nil. func ResolveDialContext( u *url.URL, timeout time.Duration, - resolver Resolver, - preferIPv6 bool, + r Resolver, + preferV6 bool, ) (h DialHandler, err error) { defer func() { err = errors.Annotate(err, "dialing %q: %w", u.Host) }() @@ -38,7 +60,7 @@ func ResolveDialContext( return nil, err } - if resolver == nil { + if r == nil { return nil, fmt.Errorf("resolver is nil: %w", ErrNoResolvers) } @@ -49,21 +71,20 @@ func ResolveDialContext( defer cancel() } - ips, err := resolver.LookupNetIP(ctx, "ip", host) + ips, err := r.LookupNetIP(ctx, NetworkIP, host) if err != nil { return nil, fmt.Errorf("resolving hostname: %w", err) } - proxynetutil.SortNetIPAddrs(ips, preferIPv6) + if preferV6 { + slices.SortStableFunc(ips, proxynetutil.PreferIPv6) + } else { + slices.SortStableFunc(ips, proxynetutil.PreferIPv4) + } addrs := make([]string, 0, len(ips)) for _, ip := range ips { - if !ip.IsValid() { - // All invalid addresses should be in the tail after sorting. - break - } - - addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)).String()) + addrs = append(addrs, netip.AddrPortFrom(ip, port).String()) } return NewDialContext(timeout, addrs...), nil @@ -71,14 +92,7 @@ func ResolveDialContext( // NewDialContext returns a DialHandler that dials addrs and returns the first // successful connection. At least a single addr should be specified. -// -// TODO(e.burkov): Consider using [Resolver] instead of -// [upstream.Options.Bootstrap] and [upstream.Options.ServerIPAddrs]. func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { - dialer := &net.Dialer{ - Timeout: timeout, - } - l := len(addrs) if l == 0 { log.Debug("bootstrap: no addresses to dial") @@ -88,9 +102,11 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { } } - // TODO(e.burkov): Check IPv6 preference here. + dialer := &net.Dialer{ + Timeout: timeout, + } - return func(ctx context.Context, network, _ string) (conn net.Conn, err error) { + return func(ctx context.Context, network Network, _ string) (conn net.Conn, err error) { var errs []error // Return first succeeded connection. Note that we're using addrs @@ -101,17 +117,18 @@ func NewDialContext(timeout time.Duration, addrs ...string) (h DialHandler) { start := time.Now() conn, err = dialer.DialContext(ctx, network, addr) elapsed := time.Since(start) - if err == nil { - log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed) + if err != nil { + log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err) + errs = append(errs, err) - return conn, nil + continue } - log.Debug("bootstrap: connection to %s failed in %s: %s", addr, elapsed, err) - errs = append(errs, err) + log.Debug("bootstrap: connection to %s succeeded in %s", addr, elapsed) + + return conn, nil } - // TODO(e.burkov): Use errors.Join in Go 1.20. - return nil, errors.List("all dialers failed", errs...) + return nil, errors.Join(errs...) } } diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index e25f45c5f..53a0db11d 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -87,7 +87,7 @@ func TestResolveDialContext(t *testing.T) { network string, host string, ) (addrs []netip.Addr, err error) { - require.Equal(pt, "ip", network) + require.Equal(pt, bootstrap.NetworkIP, network) require.Equal(pt, hostname, host) return tc.addresses, nil @@ -103,7 +103,7 @@ func TestResolveDialContext(t *testing.T) { ) require.NoError(t, err) - conn, err := dialContext(context.Background(), "tcp", "") + conn, err := dialContext(context.Background(), bootstrap.NetworkTCP, "") require.NoError(t, err) expected, ok := testutil.RequireReceive(t, sig, testTimeout) @@ -120,7 +120,7 @@ func TestResolveDialContext(t *testing.T) { network string, host string, ) (addrs []netip.Addr, err error) { - require.Equal(pt, "ip", network) + require.Equal(pt, bootstrap.NetworkIP, network) require.Equal(pt, hostname, host) return nil, nil @@ -135,7 +135,7 @@ func TestResolveDialContext(t *testing.T) { ) require.NoError(t, err) - _, err = dialContext(context.Background(), "tcp", "") + _, err = dialContext(context.Background(), bootstrap.NetworkTCP, "") testutil.AssertErrorMsg(t, "no addresses", err) }) diff --git a/internal/bootstrap/error.go b/internal/bootstrap/error.go new file mode 100644 index 000000000..9f65e8226 --- /dev/null +++ b/internal/bootstrap/error.go @@ -0,0 +1,6 @@ +package bootstrap + +import "github.com/AdguardTeam/golibs/errors" + +// ErrNoResolvers is returned when zero resolvers specified. +const ErrNoResolvers errors.Error = "no resolvers specified" diff --git a/internal/bootstrap/resolver.go b/internal/bootstrap/resolver.go index b2c57c76f..9891adc25 100644 --- a/internal/bootstrap/resolver.go +++ b/internal/bootstrap/resolver.go @@ -8,22 +8,21 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "golang.org/x/exp/slices" ) -// Resolver resolves the hostnames to IP addresses. +// Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] +// from standard library also implements this interface. type Resolver interface { - // LookupNetIP looks up the IP addresses for the given host. network must - // be one of "ip", "ip4" or "ip6". The response may be empty even if err is - // nil. - LookupNetIP(ctx context.Context, network, host string) (addrs []netip.Addr, err error) + // LookupNetIP looks up the IP addresses for the given host. network should + // be one of [NetworkIP], [NetworkIP4] or [NetworkIP6]. The response may be + // empty even if err is nil. All the addrs must be valid. + LookupNetIP(ctx context.Context, network Network, host string) (addrs []netip.Addr, err error) } // type check var _ Resolver = &net.Resolver{} -// ErrNoResolvers is returned when zero resolvers specified. -const ErrNoResolvers errors.Error = "no resolvers specified" - // ParallelResolver is a slice of resolvers that are queried concurrently. The // first successful response is returned. type ParallelResolver []Resolver @@ -34,7 +33,7 @@ var _ Resolver = ParallelResolver(nil) // LookupNetIP implements the [Resolver] interface for ParallelResolver. func (r ParallelResolver) LookupNetIP( ctx context.Context, - network string, + network Network, host string, ) (addrs []netip.Addr, err error) { resolversNum := len(r) @@ -48,7 +47,7 @@ func (r ParallelResolver) LookupNetIP( } // Size of channel must accommodate results of lookups from all resolvers, - // sending into channel will be block otherwise. + // sending into channel will block otherwise. ch := make(chan any, resolversNum) for _, rslv := range r { go lookupAsync(ctx, rslv, network, host, ch) @@ -97,3 +96,50 @@ func lookup(ctx context.Context, r Resolver, network, host string) (addrs []neti return addrs, err } + +// ConsequentResolver is a slice of resolvers that are queried in order until +// the first successful non-empty response, as opposed to just successful +// response requirement in [ParallelResolver]. +type ConsequentResolver []Resolver + +// type check +var _ Resolver = ConsequentResolver(nil) + +// LookupNetIP implements the [Resolver] interface for ConsequentResolver. +func (resolvers ConsequentResolver) LookupNetIP( + ctx context.Context, + network Network, + host string, +) (addrs []netip.Addr, err error) { + if len(resolvers) == 0 { + return nil, ErrNoResolvers + } + + var errs []error + for _, r := range resolvers { + addrs, err = r.LookupNetIP(ctx, network, host) + if err == nil && len(addrs) > 0 { + return addrs, nil + } + + errs = append(errs, err) + } + + return nil, errors.Join(errs...) +} + +// StaticResolver is a resolver which always responds with an underlying slice +// of IP addresses regardless of host and network. +type StaticResolver []netip.Addr + +// type check +var _ Resolver = StaticResolver(nil) + +// LookupNetIP implements the [Resolver] interface for StaticResolver. +func (r StaticResolver) LookupNetIP( + _ context.Context, + _ Network, + _ string, +) (addrs []netip.Addr, err error) { + return slices.Clone(r), nil +} diff --git a/internal/netutil/netutil.go b/internal/netutil/netutil.go index 6a1719841..21faae213 100644 --- a/internal/netutil/netutil.go +++ b/internal/netutil/netutil.go @@ -12,6 +12,42 @@ import ( "golang.org/x/exp/slices" ) +// PreferIPv4 compares two addresses, preferring IPv4 addresses over IPv6 ones. +// Invalid addresses are sorted near the end. +func PreferIPv4(a, b netip.Addr) (res int) { + if !a.IsValid() { + return 1 + } else if !b.IsValid() { + return -1 + } + + if aIs4 := a.Is4(); aIs4 == b.Is4() { + return a.Compare(b) + } else if aIs4 { + return -1 + } + + return 1 +} + +// PreferIPv6 compares two addresses, preferring IPv6 addresses over IPv4 ones. +// Invalid addresses are sorted near the end. +func PreferIPv6(a, b netip.Addr) (res int) { + if !a.IsValid() { + return 1 + } else if !b.IsValid() { + return -1 + } + + if aIs6 := a.Is6(); aIs6 == b.Is6() { + return a.Compare(b) + } else if aIs6 { + return -1 + } + + return 1 +} + // SortNetIPAddrs sorts addrs in accordance with the protocol preferences. // Invalid addresses are sorted near the end. Zones are ignored. func SortNetIPAddrs(addrs []netip.Addr, preferIPv6 bool) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index ba85db0f2..5a813ee70 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -329,7 +329,7 @@ func TestExchangeWithReservedDomains(t *testing.T) { upstreams, &upstream.Options{ InsecureSkipVerify: false, - Bootstrap: googleRslv, + Bootstrap: upstream.NewCachingResolver(googleRslv), Timeout: 1 * time.Second, }, ) @@ -412,7 +412,7 @@ func TestOneByOneUpstreamsExchange(t *testing.T) { u, err = upstream.AddressToUpstream( line, &upstream.Options{ - Bootstrap: googleRslv, + Bootstrap: upstream.NewCachingResolver(googleRslv), Timeout: timeOut, }, ) diff --git a/upstream/resolver.go b/upstream/resolver.go index ab017d975..5f98c800d 100644 --- a/upstream/resolver.go +++ b/upstream/resolver.go @@ -5,20 +5,35 @@ import ( "fmt" "net/netip" "net/url" + "strings" + "sync" + "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" "github.com/AdguardTeam/dnsproxy/proxyutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" - "golang.org/x/exp/slices" ) -// Resolver is an alias for the internal [bootstrap.Resolver] to allow custom -// implementations. Note, that the [net.Resolver] from standard library also -// implements this interface. +// Resolver resolves the hostnames to IP addresses. Note, that [net.Resolver] +// from standard library also implements this interface. type Resolver = bootstrap.Resolver +// StaticResolver is a resolver which always responds with an underlying slice +// of IP addresses. +type StaticResolver = bootstrap.StaticResolver + +// ParallelResolver is a slice of resolvers that are queried concurrently until +// the first successful response is returned, as opposed to all resolvers being +// queried in order in [ConsequentResolver]. +type ParallelResolver = bootstrap.ParallelResolver + +// ConsequentResolver is a slice of resolvers that are queried in order until +// the first successful non-empty response, as opposed to just successful +// response requirement in [ParallelResolver]. +type ConsequentResolver = bootstrap.ConsequentResolver + // UpstreamResolver is a wrapper around Upstream that implements the // [bootstrap.Resolver] interface. type UpstreamResolver struct { @@ -105,57 +120,107 @@ func validateBootstrap(u Upstream) (err error) { // type check var _ Resolver = &UpstreamResolver{} -// LookupNetIP implements the [Resolver] interface for upstreamResolver. +// LookupNetIP implements the [Resolver] interface for *UpstreamResolver. It +// doesn't consider the TTL of the DNS records. // -// TODO(e.burkov): Use context. +// TODO(e.burkov): Investigate why the empty slice is returned instead of nil. func (r *UpstreamResolver) LookupNetIP( - _ context.Context, - network string, + ctx context.Context, + network bootstrap.Network, host string, ) (ips []netip.Addr, err error) { if host == "" { return nil, nil } - switch network { - case "ip4", "ip6": - host = dns.Fqdn(host) - ips, err = r.resolve(host, network) - case "ip": - host = dns.Fqdn(host) - resCh := make(chan any, 2) - go r.resolveAsync(resCh, host, "ip4") - go r.resolveAsync(resCh, host, "ip6") - - var errs []error - for i := 0; i < 2; i++ { - switch res := <-resCh; res := res.(type) { - case error: - errs = append(errs, res) - case []netip.Addr: - ips = append(ips, res...) - } + host = dns.Fqdn(strings.ToLower(host)) + + rr, err := r.resolveIP(ctx, network, host) + if err != nil { + return []netip.Addr{}, err + } + + for _, ip := range rr { + ips = append(ips, ip.addr) + } + + return ips, err +} + +// ipResult reflects a single A/AAAA record from the DNS response. It's used +// to cache the results of lookups. +type ipResult struct { + addr netip.Addr + expire time.Time +} + +// filterExpired returns the addresses from res that are not expired yet. It +// returns nil if all the addresses are expired. +func filterExpired(res []ipResult, now time.Time) (filtered []netip.Addr) { + for _, r := range res { + if r.expire.After(now) { + filtered = append(filtered, r.addr) } + } - err = errors.Join(errs...) + return filtered +} + +// resolveIP performs a DNS lookup of host and returns the result. network must +// be either [bootstrap.NetworkIP4], [bootstrap.NetworkIP6] or +// [bootstrap.NetworkIP]. host must be in a lower-case FQDN form. +// +// TODO(e.burkov): Use context. +func (r *UpstreamResolver) resolveIP( + _ context.Context, + network bootstrap.Network, + host string, +) (rr []ipResult, err error) { + switch network { + case bootstrap.NetworkIP4, bootstrap.NetworkIP6: + return r.resolve(host, network) + case bootstrap.NetworkIP: + // Go on. default: - return []netip.Addr{}, fmt.Errorf("unsupported network %s", network) + return nil, fmt.Errorf("unsupported network %s", network) } - if len(ips) == 0 { - ips = []netip.Addr{} + resCh := make(chan any, 2) + go r.resolveAsync(resCh, host, bootstrap.NetworkIP4) + go r.resolveAsync(resCh, host, bootstrap.NetworkIP6) + + var errs []error + + for i := 0; i < 2; i++ { + switch res := <-resCh; res := res.(type) { + case error: + errs = append(errs, res) + case []ipResult: + rr = append(rr, res...) + } } - return ips, err + return rr, errors.Join(errs...) } // resolve performs a single DNS lookup of host and returns all the valid // addresses from the answer section of the response. network must be either -// "ip4" or "ip6". -func (r *UpstreamResolver) resolve(host, network string) (addrs []netip.Addr, err error) { - qtype := dns.TypeA - if network == "ip6" { +// "ip4" or "ip6". host must be in a lower-case FQDN form. +// +// TODO(e.burkov): Consider NS and Extra sections when setting TTL. Check out +// what RFCs say about it. +func (r *UpstreamResolver) resolve( + host string, + n bootstrap.Network, +) (res []ipResult, err error) { + var qtype uint16 + switch n { + case bootstrap.NetworkIP4: + qtype = dns.TypeA + case bootstrap.NetworkIP6: qtype = dns.TypeAAAA + default: + panic(fmt.Sprintf("unsupported network %q", n)) } req := &dns.Msg{ @@ -170,78 +235,107 @@ func (r *UpstreamResolver) resolve(host, network string) (addrs []netip.Addr, er }}, } - resp, err := r.Upstream.Exchange(req) - if err != nil || resp == nil { + // As per [upstream.Exchange] documentation, the response is always returned + // if no error occurred. + resp, err := r.Exchange(req) + if err != nil { return nil, err } + now := time.Now() for _, rr := range resp.Answer { - if addr := proxyutil.IPFromRR(rr); addr.IsValid() { - addrs = append(addrs, addr) + ip := proxyutil.IPFromRR(rr) + if !ip.IsValid() { + continue } + + res = append(res, ipResult{ + addr: ip, + expire: now.Add(time.Duration(rr.Header().Ttl) * time.Second), + }) } - return addrs, nil + return res, nil } // resolveAsync performs a single DNS lookup and sends the result to ch. It's // intended to be used as a goroutine. func (r *UpstreamResolver) resolveAsync(resCh chan<- any, host, network string) { - resp, err := r.resolve(host, network) + res, err := r.resolve(host, network) if err != nil { resCh <- err } else { - resCh <- resp + resCh <- res } } -// StaticResolver is a resolver which always responds with an underlying slice -// of IP addresses. -type StaticResolver []netip.Addr +// CachingResolver is a [Resolver] that caches the results of lookups. It's +// required to be created with [NewCachingResolver]. +type CachingResolver struct { + // resolver is the underlying resolver to use for lookups. + resolver *UpstreamResolver -// type check -var _ Resolver = StaticResolver(nil) + // mu protects cached and it's elements. + mu *sync.RWMutex -// LookupNetIP implements the [Resolver] interface for StaticResolver. -func (r StaticResolver) LookupNetIP( - ctx context.Context, - network string, - host string, -) (addrs []netip.Addr, err error) { - return slices.Clone(r), nil + // cached is the set of cached results sorted by [resolveResult.name]. + cached map[string][]ipResult } -// ConsequentResolver is a slice of resolvers that are queried in order until -// the first successful non-empty response, as opposed to just successful -// response requirement in [ParallelResolver]. -type ConsequentResolver []Resolver +// NewCachingResolver creates a new caching resolver that uses r for lookups. +func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { + return &CachingResolver{ + resolver: r, + mu: &sync.RWMutex{}, + cached: map[string][]ipResult{}, + } +} // type check -var _ Resolver = ConsequentResolver(nil) +var _ Resolver = (*CachingResolver)(nil) -// LookupNetIP implements the [Resolver] interface for ConsequentResolver. -func (resolvers ConsequentResolver) LookupNetIP( +// LookupNetIP implements the [Resolver] interface for *CachingResolver. +func (r *CachingResolver) LookupNetIP( ctx context.Context, - network string, + network bootstrap.Network, host string, ) (addrs []netip.Addr, err error) { - if len(resolvers) == 0 { - return nil, bootstrap.ErrNoResolvers + now := time.Now() + host = dns.Fqdn(strings.ToLower(host)) + + addrs = r.findCached(host, now) + if addrs != nil { + return addrs, nil } - var errs []error - for _, r := range resolvers { - addrs, err = r.LookupNetIP(ctx, network, host) - if err == nil && len(addrs) > 0 { - return addrs, nil - } + newRes, err := r.resolver.resolveIP(ctx, network, host) + if err != nil { + return []netip.Addr{}, err + } - errs = append(errs, err) + addrs = filterExpired(newRes, now) + if len(addrs) == 0 { + return []netip.Addr{}, nil } - return nil, errors.Join(errs...) + r.mu.Lock() + defer r.mu.Unlock() + + r.cached[host] = newRes + + return addrs, nil } -// ParallelResolver is an alias for the internal [bootstrap.ParallelResolver] to -// allow it's usage outside of the module. -type ParallelResolver = bootstrap.ParallelResolver +// findCached returns the cached addresses for host if it's not expired yet, and +// the corresponding cached result, if any. +func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) { + r.mu.RLock() + defer r.mu.RUnlock() + + res, ok := r.cached[host] + if !ok { + return nil + } + + return filterExpired(res, now) +} diff --git a/upstream/upstream.go b/upstream/upstream.go index 5b147ea76..9cd345d87 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -14,7 +14,6 @@ import ( "os" "strconv" "strings" - "sync/atomic" "time" "github.com/AdguardTeam/dnsproxy/internal/bootstrap" @@ -323,9 +322,7 @@ func isTimeout(err error) (ok bool) { } } -// DialerInitializer returns the handler that it creates. All the subsequent -// calls to it, except the first one, will return the same handler so that -// resolving will be performed only once. +// DialerInitializer returns the handler that it creates. type DialerInitializer func() (handler bootstrap.DialHandler, err error) // newDialerInitializer creates an initializer of the dialer that will dial the @@ -335,7 +332,9 @@ func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) { // Don't resolve the address of the server since it's already an IP. handler := bootstrap.NewDialContext(opts.Timeout, u.Host) - return func() (bootstrap.DialHandler, error) { return handler, nil } + return func() (h bootstrap.DialHandler, dialerErr error) { + return handler, nil + } } boot := opts.Bootstrap @@ -344,27 +343,7 @@ func newDialerInitializer(u *url.URL, opts *Options) (di DialerInitializer) { boot = net.DefaultResolver } - var dialHandler atomic.Pointer[bootstrap.DialHandler] - return func() (h bootstrap.DialHandler, err error) { - // Check if the dial handler has already been created. - if hPtr := dialHandler.Load(); hPtr != nil { - return *hPtr, nil - } - - // TODO(e.burkov): It may appear that several exchanges will try to - // resolve the upstream hostname at the same time. Currently, the last - // successful value will be stored in dialHandler, but ideally we should - // resolve only once. - h, err = bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6) - if err != nil { - return nil, fmt.Errorf("creating dial handler: %w", err) - } - - if !dialHandler.CompareAndSwap(nil, &h) { - return *dialHandler.Load(), nil - } - - return h, nil + return bootstrap.ResolveDialContext(u, opts.Timeout, boot, opts.PreferIPv6) } } diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index f56569253..0ad53b660 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -55,7 +55,7 @@ func TestUpstream_bootstrapTimeout(t *testing.T) { // Create an upstream that uses this faulty bootstrap. u, err := AddressToUpstream("tls://random-domain-name", &Options{ - Bootstrap: rslv, + Bootstrap: NewCachingResolver(rslv), Timeout: timeout, }) require.NoError(t, err) @@ -114,17 +114,20 @@ func TestUpstreams(t *testing.T) { }) require.NoError(t, err) + googleBoot := NewCachingResolver(googleRslv) + cloudflareBoot := NewCachingResolver(cloudflareRslv) + upstreams := []struct { bootstrap Resolver address string }{{ - bootstrap: googleRslv, + bootstrap: googleBoot, address: "8.8.8.8:53", }, { bootstrap: nil, address: "1.1.1.1", }, { - bootstrap: cloudflareRslv, + bootstrap: cloudflareBoot, address: "1.1.1.1", }, { bootstrap: nil, @@ -139,19 +142,19 @@ func TestUpstreams(t *testing.T) { bootstrap: nil, address: "tls://9.9.9.9:853", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com:853", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "tls://dns.adguard.com:853", }, { bootstrap: nil, address: "tls://one.one.one.one", }, { - bootstrap: googleRslv, + bootstrap: googleBoot, address: "https://1dot1dot1dot1.cloudflare-dns.com/dns-query", }, { bootstrap: nil, @@ -165,11 +168,11 @@ func TestUpstreams(t *testing.T) { address: "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", }, { // AdGuard Family (DNSCrypt) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AQIAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMjo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ", }, { // Cloudflare DNS (DNS-over-HTTPS) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AgcAAAAAAAAABzEuMC4wLjGgENk8mGSlIfMGXMOlIlCcKvq7AVgcrZxtjon911-ep0cg63Ul-I8NlFj4GplQGb_TTLiczclX57DvMV8Q-JdjgRgSZG5zLmNsb3VkZmxhcmUuY29tCi9kbnMtcXVlcnk", }, { // Google (Plain) @@ -177,11 +180,11 @@ func TestUpstreams(t *testing.T) { address: "sdns://AAcAAAAAAAAABzguOC44Ljg", }, { // AdGuard DNS (DNS-over-TLS) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://AwAAAAAAAAAAAAAPZG5zLmFkZ3VhcmQuY29t", }, { // AdGuard DNS (DNS-over-QUIC) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "sdns://BAcAAAAAAAAAAAAXZG5zLmFkZ3VhcmQtZG5zLmNvbTo3ODQ", }, { // Cloudflare DNS (DNS-over-HTTPS) @@ -189,7 +192,7 @@ func TestUpstreams(t *testing.T) { address: "https://1.1.1.1/dns-query", }, { // AdGuard DNS (DNS-over-QUIC) - bootstrap: googleRslv, + bootstrap: googleBoot, address: "quic://dns.adguard-dns.com", }, { // Google DNS (HTTP3) @@ -215,7 +218,7 @@ func TestAddressToUpstream(t *testing.T) { cloudflareRslv, err := NewUpstreamResolver("1.1.1.1", nil) require.NoError(t, err) - opt := &Options{Bootstrap: cloudflareRslv} + opt := &Options{Bootstrap: NewCachingResolver(cloudflareRslv)} testCases := []struct { addr string @@ -314,7 +317,7 @@ func TestUpstreamDoTBootstrap(t *testing.T) { require.NoError(t, err) u, err := AddressToUpstream(tc.address, &Options{ - Bootstrap: rslv, + Bootstrap: NewCachingResolver(rslv), Timeout: timeout, }) require.NoErrorf(t, err, "failed to generate upstream from address %s", tc.address) @@ -361,7 +364,7 @@ func TestUpstreamsInvalidBootstrap(t *testing.T) { }) require.NoError(t, err) - rslv = append(rslv, r) + rslv = append(rslv, NewCachingResolver(r)) } u, err := AddressToUpstream(tc.address, &Options{