diff --git a/go.mod b/go.mod index ba34fd10e..011b011cf 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/dnsproxy go 1.20 require ( - github.com/AdguardTeam/golibs v0.17.0 + github.com/AdguardTeam/golibs v0.18.0 github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnsstamps v1.0.3 github.com/beefsack/go-rate v0.0.0-20220214233405-116f4ca011a0 diff --git a/go.sum b/go.sum index 7934a1fc3..ed7f0b364 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ -github.com/AdguardTeam/golibs v0.17.0 h1:oPp2+2kV41qH45AIFbAlHFTPQOQ6JbF+JemjeECFn1g= -github.com/AdguardTeam/golibs v0.17.0/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= +github.com/AdguardTeam/golibs v0.17.3 h1:V3XWPh2OirWuz0lgvcrpq7Hz3qcb0j6gq5+oPSxe2/Y= +github.com/AdguardTeam/golibs v0.17.3/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= +github.com/AdguardTeam/golibs v0.18.0 h1:ckS2YK7t2Ub6UkXl0fnreVaM15Zb07Hh1gmFqttjpWg= +github.com/AdguardTeam/golibs v0.18.0/go.mod h1:DKhCIXHcUYtBhU8ibTLKh1paUL96n5zhQBlx763sj+U= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmHS9iAKVt9AyzRSqNU1qabPih5BY= github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 h1:52m0LGchQBBVqJRyYYufQuIbVqRawmubW3OFGqK1ekw= diff --git a/internal/netutil/hosts.go b/internal/netutil/hosts.go deleted file mode 100644 index e071c5f70..000000000 --- a/internal/netutil/hosts.go +++ /dev/null @@ -1,155 +0,0 @@ -package netutil - -import ( - "fmt" - "io" - "net/netip" - "strings" - - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/hostsfile" - "github.com/AdguardTeam/golibs/log" - "golang.org/x/exp/slices" -) - -// unit is a convenient alias for empty struct. -type unit = struct{} - -// set is a helper type that removes duplicates. -type set[K string | netip.Addr] map[K]unit - -// orderedSet is a helper type for storing values in original adding order and -// dealing with duplicates. -type orderedSet[K string | netip.Addr] struct { - set set[K] - vals []K -} - -// add adds val to os if it's not already there. -func (os *orderedSet[K]) add(key, val K) { - if _, ok := os.set[key]; !ok { - os.set[key] = unit{} - os.vals = append(os.vals, val) - } -} - -// Convenience aliases for [orderedSet]. -type ( - namesSet = orderedSet[string] - addrsSet = orderedSet[netip.Addr] -) - -// Hosts is a [hostsfile.HandleSet] that removes duplicates. -// -// It must be initialized with [NewHosts]. -// -// TODO(e.burkov): Think of storing only slices. -// -// TODO(e.burkov): Move to netutil/hostsfile in module golibs as a default -// implementation of some storage interface. -type Hosts struct { - // names maps each address to its names in original case and in original - // adding order without duplicates. - names map[netip.Addr]*namesSet - - // addrs maps each host to its addresses in original adding order without - // duplicates. - addrs map[string]*addrsSet -} - -// NewHosts parses hosts files from r and returns a new Hosts set. readers are -// optional, the error is only returned in case of parsing error. -func NewHosts(readers ...io.Reader) (h *Hosts, err error) { - h = &Hosts{ - names: map[netip.Addr]*namesSet{}, - addrs: map[string]*addrsSet{}, - } - - for i, r := range readers { - if err = hostsfile.Parse(h, r, nil); err != nil { - return nil, fmt.Errorf("reader at index %d: %w", i, err) - } - } - - return h, nil -} - -// type check -var _ hostsfile.HandleSet = (*Hosts)(nil) - -// Add implements the [hostsfile.Set] interface for *Hosts. -func (h *Hosts) Add(rec *hostsfile.Record) { - names := h.names[rec.Addr] - if names == nil { - names = &namesSet{set: set[string]{}} - h.names[rec.Addr] = names - } - - for _, name := range rec.Names { - lowered := strings.ToLower(name) - names.add(lowered, name) - - addrs := h.addrs[lowered] - if addrs == nil { - addrs = &addrsSet{ - vals: []netip.Addr{}, - set: set[netip.Addr]{}, - } - h.addrs[lowered] = addrs - } - addrs.add(rec.Addr, rec.Addr) - } -} - -// HandleInvalid implements the [hostsfile.HandleSet] interface for *Hosts. -func (h *Hosts) HandleInvalid(srcName string, _ []byte, err error) { - lineErr := &hostsfile.LineError{} - if !errors.As(err, &lineErr) { - log.Debug("hostset: unexpected error from hostsfile: %s", err) - - return - } - - if errors.Is(err, hostsfile.ErrEmptyLine) { - // Ignore empty lines and comments. - return - } - - log.Debug("hostset: source %q: %s", srcName, lineErr) -} - -// ByAddr returns each host for addr in original case, in original adding order -// without duplicates. It returns nil if h doesn't contain the addr. -func (h *Hosts) ByAddr(addr netip.Addr) (hosts []string) { - if hostsSet, ok := h.names[addr]; ok { - return hostsSet.vals - } - - return nil -} - -// ByName returns each address for host in original adding order without -// duplicates. It returns nil if h doesn't contain the host. -func (h *Hosts) ByName(host string) (addrs []netip.Addr) { - if addrsSet, ok := h.addrs[strings.ToLower(host)]; ok { - return addrsSet.vals - } - - return nil -} - -// Mappings returns a deep clone of the internal mappings. -func (h *Hosts) Mappings() (names map[netip.Addr][]string, addrs map[string][]netip.Addr) { - names = make(map[netip.Addr][]string, len(h.names)) - addrs = make(map[string][]netip.Addr, len(h.addrs)) - - for addr, namesSet := range h.names { - names[addr] = slices.Clone(namesSet.vals) - } - - for name, addrsSet := range h.addrs { - addrs[name] = slices.Clone(addrsSet.vals) - } - - return names, addrs -} diff --git a/internal/netutil/hosts_test.go b/internal/netutil/hosts_test.go deleted file mode 100644 index 6f1694ced..000000000 --- a/internal/netutil/hosts_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package netutil_test - -import ( - "io/fs" - "net/netip" - "os" - "path" - "testing" - - "github.com/AdguardTeam/dnsproxy/internal/netutil" - "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" -) - -// testdata is an [fs.FS] containing data for tests. -var testdata = os.DirFS("./testdata") - -func TestHosts(t *testing.T) { - t.Parallel() - - var h *netutil.Hosts - var err error - t.Run("good_file", func(t *testing.T) { - var f fs.File - f, err = testdata.Open(path.Join(t.Name(), "hosts")) - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, f.Close) - - h, err = netutil.NewHosts(f) - }) - require.NoError(t, err) - - // Variables mirroring the testdata/TestHosts/hosts file. - var ( - v4Addr1 = netip.MustParseAddr("0.0.0.1") - v4Addr2 = netip.MustParseAddr("0.0.0.2") - - mappedAddr1 = netip.MustParseAddr("::ffff:0.0.0.1") - mappedAddr2 = netip.MustParseAddr("::ffff:0.0.0.2") - - v6Addr1 = netip.MustParseAddr("::1") - v6Addr2 = netip.MustParseAddr("::2") - - wantHosts = map[string][]netip.Addr{ - "host.one": {v4Addr1, mappedAddr1, v6Addr1}, - "host.two": {v4Addr2, mappedAddr2, v6Addr2}, - "host.new": {v4Addr2, v4Addr1, mappedAddr2, mappedAddr1, v6Addr2, v6Addr1}, - "again.host.two": {v4Addr2, mappedAddr2, v6Addr2}, - } - - wantAddrs = map[netip.Addr][]string{ - v4Addr1: {"Host.One", "host.new"}, - v4Addr2: {"Host.Two", "Host.New", "Again.Host.Two"}, - mappedAddr1: {"Host.One", "host.new"}, - mappedAddr2: {"Host.Two", "Host.New", "Again.Host.Two"}, - v6Addr1: {"Host.One", "host.new"}, - v6Addr2: {"Host.Two", "Host.New", "Again.Host.Two"}, - } - ) - - t.Run("Mappings", func(t *testing.T) { - names, addrs := h.Mappings() - assert.Equal(t, wantAddrs, names) - assert.Equal(t, wantHosts, addrs) - }) - - t.Run("ByAddr", func(t *testing.T) { - t.Parallel() - - // Sort keys to make the test deterministic. - addrs := maps.Keys(wantAddrs) - slices.SortFunc(addrs, netip.Addr.Compare) - - for _, addr := range addrs { - addr := addr - t.Run(addr.String(), func(t *testing.T) { - t.Parallel() - - assert.Equal(t, wantAddrs[addr], h.ByAddr(addr)) - }) - } - }) - - t.Run("ByHost", func(t *testing.T) { - t.Parallel() - - // Sort keys to make the test deterministic. - hosts := maps.Keys(wantHosts) - slices.Sort(hosts) - - for _, host := range hosts { - host := host - t.Run(host, func(t *testing.T) { - t.Parallel() - - assert.Equal(t, wantHosts[host], h.ByName(host)) - }) - } - }) - - t.Run("bad_file", func(t *testing.T) { - var f fs.File - f, err = testdata.Open(path.Join(t.Name(), "hosts")) - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, f.Close) - - _, err = netutil.NewHosts(f) - require.NoError(t, err) - }) - - t.Run("non-line_error", func(t *testing.T) { - assert.NotPanics(t, func() { - (&netutil.Hosts{}).HandleInvalid("test", nil, assert.AnError) - }) - }) -} diff --git a/internal/osutil/osutil.go b/internal/osutil/osutil.go deleted file mode 100644 index ab25908ab..000000000 --- a/internal/osutil/osutil.go +++ /dev/null @@ -1,14 +0,0 @@ -// Package osutil contains utilities for functions requiring system calls and -// other OS-specific APIs, except for network-related ones. -package osutil - -import "io/fs" - -// RootDirFS returns the fs.FS rooted at the operating system's root. On -// Windows it returns the fs.FS rooted at the volume of the system directory -// (usually, C:). -// -// TODO(e.burkov): Move to golibs. -func RootDirFS() (fsys fs.FS) { - return rootDirFS() -} diff --git a/internal/osutil/osutil_unix.go b/internal/osutil/osutil_unix.go deleted file mode 100644 index f15505a16..000000000 --- a/internal/osutil/osutil_unix.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !windows - -package osutil - -import ( - "io/fs" - "os" -) - -func rootDirFS() (fsys fs.FS) { - return os.DirFS("/") -} diff --git a/internal/osutil/osutil_windows.go b/internal/osutil/osutil_windows.go deleted file mode 100644 index 7d5459e29..000000000 --- a/internal/osutil/osutil_windows.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build windows - -package osutil - -import ( - "io/fs" - "os" - "path/filepath" - - "github.com/AdguardTeam/golibs/log" - "golang.org/x/sys/windows" -) - -func rootDirFS() (fsys fs.FS) { - // TODO(a.garipov): Use a better way if golang/go#44279 is ever resolved. - sysDir, err := windows.GetSystemDirectory() - if err != nil { - log.Error("aghos: getting root filesystem: %s; using C:", err) - - // Assume that C: is the safe default. - return os.DirFS("C:") - } - - return os.DirFS(filepath.VolumeName(sysDir)) -} diff --git a/main.go b/main.go index fbad255bb..8e5b731e2 100644 --- a/main.go +++ b/main.go @@ -15,14 +15,13 @@ import ( "syscall" "time" - "github.com/AdguardTeam/dnsproxy/internal/bootstrap" proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil" - "github.com/AdguardTeam/dnsproxy/internal/osutil" "github.com/AdguardTeam/dnsproxy/internal/version" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/mathutil" + "github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/ameshkov/dnscrypt/v2" goFlags "github.com/jessevdk/go-flags" @@ -477,7 +476,7 @@ func initBootstrap(bootstraps []string, opts *upstream.Options) (r upstream.Reso switch len(resolvers) { case 0: - etcHosts, hostsErr := bootstrap.NewDefaultHostsResolver(osutil.RootDirFS()) + etcHosts, hostsErr := upstream.NewDefaultHostsResolver(osutil.RootDirFS()) if hostsErr != nil { log.Error("creating default hosts resolver: %s", hostsErr) diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 4fc5393de..69d899041 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -173,7 +173,6 @@ run_linter govulncheck ./... run_linter gocyclo --over 10\ ./internal/bootstrap/\ ./internal/netutil/\ - ./internal/osutil/\ ./internal/version/\ ./proxyutil/\ ./upstream/\ @@ -186,7 +185,6 @@ run_linter gocyclo --over 15 ./proxy/ # TODO(a.garipov): Enable for all. run_linter gocognit --over 10\ ./internal/bootstrap/\ - ./internal/osutil/\ ./internal/version/\ ./proxyutil/\ ./upstream/\ diff --git a/internal/bootstrap/hostsresolver.go b/upstream/hostsresolver.go similarity index 59% rename from internal/bootstrap/hostsresolver.go rename to upstream/hostsresolver.go index 13a2b527a..75a2236ff 100644 --- a/internal/bootstrap/hostsresolver.go +++ b/upstream/hostsresolver.go @@ -1,4 +1,4 @@ -package bootstrap +package upstream import ( "context" @@ -6,55 +6,53 @@ import ( "io/fs" "net/netip" - "github.com/AdguardTeam/dnsproxy/internal/netutil" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "golang.org/x/exp/slices" ) -// HostsResolver is a [Resolver] that uses [netutil.Hosts] as a source of IP. +// HostsResolver is a [Resolver] that looks into system hosts files, see +// [hostsfile]. type HostsResolver struct { - // addrs is an actual source of IP addresses. - addrs map[string][]netip.Addr + // strg contains all the hosts file data needed for lookups. + strg hostsfile.Storage } // NewHostsResolver is the resolver based on system hosts files. -func NewHostsResolver(hosts *netutil.Hosts) (hr *HostsResolver) { - hr = &HostsResolver{} - _, hr.addrs = hosts.Mappings() - - return hr +func NewHostsResolver(hosts hostsfile.Storage) (hr *HostsResolver) { + return &HostsResolver{ + strg: hosts, + } } // NewDefaultHostsResolver returns a resolver based on system hosts files // provided by the [hostsfile.DefaultHostsPaths] and read from rootFSys. -// -// TODO(e.burkov): Use. func NewDefaultHostsResolver(rootFSys fs.FS) (hr *HostsResolver, err error) { paths, err := hostsfile.DefaultHostsPaths() if err != nil { return nil, fmt.Errorf("getting default hosts paths: %w", err) } - hosts, _ := netutil.NewHosts() - for _, name := range paths { - err = parseHostsFile(rootFSys, hosts, name) + // The error is always nil here since no readers passed. + strg, _ := hostsfile.NewDefaultStorage() + for _, filename := range paths { + err = parseHostsFile(rootFSys, strg, filename) if err != nil { // Don't wrap the error since it's already informative enough as is. return nil, err } } - return NewHostsResolver(hosts), nil + return NewHostsResolver(strg), nil } // parseHostsFile reads a single hosts file from fsys and parses it into hosts. -func parseHostsFile(fsys fs.FS, hosts *netutil.Hosts, name string) (err error) { - f, err := fsys.Open(name) +func parseHostsFile(fsys fs.FS, hosts hostsfile.Set, filename string) (err error) { + f, err := fsys.Open(filename) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("hosts file %q doesn't exist", name) + log.Debug("hosts file %q doesn't exist", filename) return nil } @@ -63,8 +61,6 @@ func parseHostsFile(fsys fs.FS, hosts *netutil.Hosts, name string) (err error) { return err } - // TODO(e.burkov): Use [errors.Join] when it will be supported by all - // dependencies. defer func() { err = errors.WithDeferred(err, f.Close()) }() return hostsfile.Parse(hosts, f, nil) @@ -79,17 +75,23 @@ func (hr *HostsResolver) LookupNetIP( network string, host string, ) (addrs []netip.Addr, err error) { - var checkIP func(netip.Addr) (ok bool) + var ipMatches func(netip.Addr) (ok bool) switch network { case "ip4": - addrs, checkIP = slices.Clone(hr.addrs[host]), netip.Addr.Is6 + ipMatches = netip.Addr.Is4 case "ip6": - addrs, checkIP = slices.Clone(hr.addrs[host]), netip.Addr.Is4 + ipMatches = netip.Addr.Is6 case "ip": - return slices.Clone(hr.addrs[host]), nil + return slices.Clone(hr.strg.ByName(host)), nil default: return nil, fmt.Errorf("unsupported network %q", network) } - return slices.DeleteFunc(addrs, checkIP), nil + for _, addr := range hr.strg.ByName(host) { + if ipMatches(addr) { + addrs = append(addrs, addr) + } + } + + return addrs, nil } diff --git a/internal/bootstrap/hostsresolver_test.go b/upstream/hostsresolver_test.go similarity index 83% rename from internal/bootstrap/hostsresolver_test.go rename to upstream/hostsresolver_test.go index a50a3ca53..fe2ca6256 100644 --- a/internal/bootstrap/hostsresolver_test.go +++ b/upstream/hostsresolver_test.go @@ -1,13 +1,13 @@ -package bootstrap_test +package upstream_test import ( "context" "net/netip" - "strings" "testing" + "testing/fstest" - "github.com/AdguardTeam/dnsproxy/internal/bootstrap" - "github.com/AdguardTeam/dnsproxy/internal/netutil" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,10 +24,18 @@ func TestHostsResolver_LookupNetIP(t *testing.T) { v6Addr = netip.MustParseAddr("::1") ) - hosts, err := netutil.NewHosts(strings.NewReader(hostsData)) + paths, err := hostsfile.DefaultHostsPaths() require.NoError(t, err) + require.NotEmpty(t, paths) - hr := bootstrap.NewHostsResolver(hosts) + fsys := fstest.MapFS{ + paths[0]: { + Data: []byte(hostsData), + }, + } + + hr, err := upstream.NewDefaultHostsResolver(fsys) + require.NoError(t, err) testCases := []struct { name string @@ -73,12 +81,12 @@ func TestHostsResolver_LookupNetIP(t *testing.T) { name: "family_mismatch_v4", host: "ipv6.only", net: "ip4", - wantAddrs: []netip.Addr{}, + wantAddrs: nil, }, { name: "family_mismatch_v6", host: "ipv4.only", net: "ip6", - wantAddrs: []netip.Addr{}, + wantAddrs: nil, }} for _, tc := range testCases {