Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: TRR2 resolver #846

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/measurexlite/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS,
return r.r.LookupNS(netxlite.ContextWithTrace(ctx, r.tx), domain)
}

// NewStdlibResolver returns a trace-aware stdlib resolver
func (tx *Trace) NewStdlibResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
return tx.newParallelResolverTrace(func() model.Resolver {
return netxlite.NewStdlibResolver(logger)
})
}

// NewParallelUDPResolver returns a trace-ware parallel UDP resolver
func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver {
return tx.newParallelResolverTrace(func() model.Resolver {
Expand All @@ -78,6 +85,41 @@ func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string
})
}

// newSimpleResolverTrace is equivalent to returning a simple resolver
// except that it returns a model.SimpleResolver that uses this trace.
func (tx *Trace) newSimpleResolverTrace(newResolver func() model.SimpleResolver) model.SimpleResolver {
return &simpleResolverTrace{
r: tx.newSimpleResolver(newResolver),
tx: tx,
}
}

// simpleResolverTrace is a trace-aware simple resolver
type simpleResolverTrace struct {
r model.SimpleResolver
tx *Trace
}

var _ model.SimpleResolver = &simpleResolverTrace{}

// Network implements model.SimpleResolver.Network
func (r *simpleResolverTrace) Network() string {
return r.r.Network()
}

// LookupHost implements model.SimpleResolver.LookupHost
func (r *simpleResolverTrace) LookupHost(ctx context.Context, hostname string) ([]string, error) {
return r.r.LookupHost(netxlite.ContextWithTrace(ctx, r.tx), hostname)
}

// NewTrustedRecursiveResolver2 returns a trace-aware TRR2 resolver
func (tx *Trace) NewTrustedRecursiveResolver2(logger model.Logger, address string,
timeout int) model.SimpleResolver {
return tx.newSimpleResolverTrace(func() model.SimpleResolver {
return NewTrustedRecursiveResolver2(logger, address, timeout)
})
}

// OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost
func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery,
response model.DNSResponse, addrs []string, err error, finished time.Time) {
Expand Down
276 changes: 274 additions & 2 deletions internal/measurexlite/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package measurexlite

import (
"context"
"net"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/miekg/dns"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/testingx"
)

func TestNewUnwrappedParallelResolver(t *testing.T) {
t.Run("NewUnwrappedParallelResolver creates an UnwrappedParallelResolver with Trace", func(t *testing.T) {
func TestNewParallelResolver(t *testing.T) {
t.Run("NewParallelResolverTrace creates an ParallelResolver with Trace", func(t *testing.T) {
underlying := &mocks.Resolver{}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
Expand Down Expand Up @@ -44,6 +46,19 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
MockNetwork: func() string {
return "udp"
},
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1"}, nil
},
MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) {
return &model.HTTPSSvc{
IPv4: []string{"1.1.1.1"},
}, nil
},
MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) {
return []*net.NS{{
Host: "1.1.1.1",
}}, nil
},
MockCloseIdleConnections: func() {
called = true
},
Expand All @@ -65,6 +80,46 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
}
})

t.Run("LookupHost is correctly forwarded", func(t *testing.T) {
want := []string{"1.1.1.1"}
ctx := context.Background()
got, err := resolver.LookupHost(ctx, "example.com")
if err != nil {
t.Fatal("expected nil error")
}
if diff := cmp.Diff(want, got); diff != "" {
t.Fatal(diff)
}
})

t.Run("LookupHTTPS is correctly forwarded", func(t *testing.T) {
want := &model.HTTPSSvc{
IPv4: []string{"1.1.1.1"},
}
ctx := context.Background()
got, err := resolver.LookupHTTPS(ctx, "example.com")
if err != nil {
t.Fatal("expected nil error")
}
if diff := cmp.Diff(want, got); diff != "" {
t.Fatal(diff)
}
})

t.Run("LookupHost is correctly forwarded", func(t *testing.T) {
want := []*net.NS{{
Host: "1.1.1.1",
}}
ctx := context.Background()
got, err := resolver.LookupNS(ctx, "example.com")
if err != nil {
t.Fatal("expected nil error")
}
if diff := cmp.Diff(want, got); diff != "" {
t.Fatal(diff)
}
})

t.Run("CloseIdleConnections is correctly forwarded", func(t *testing.T) {
resolver.CloseIdleConnections()
if !called {
Expand Down Expand Up @@ -221,6 +276,223 @@ func TestNewUnwrappedParallelResolver(t *testing.T) {
})
}

func TestNewSimpleResolver(t *testing.T) {
t.Run("NewSimpleResolverTrace creates a SimpleResolver with Trace", func(t *testing.T) {
underlying := &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{}, nil
},
}
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
trace.NewSimpleResolverFn = func() model.SimpleResolver {
return underlying
}
resolver := trace.newSimpleResolverTrace(func() model.SimpleResolver {
return nil
})
resolvert := resolver.(*simpleResolverTrace)
if resolvert.r != underlying {
t.Fatal("invalid simple resolver")
}
if resolvert.tx != trace {
t.Fatal("invalid trace")
}
})

t.Run("Trace-aware resolver forwards underlying functions", func(t *testing.T) {
zeroTime := time.Now()
trace := NewTrace(0, zeroTime)
newMockResolver := func() model.SimpleResolver {
return &mocks.Resolver{
MockNetwork: func() string {
return "udp"
},
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"1.1.1.1"}, nil
},
}
}
resolver := trace.newSimpleResolver(newMockResolver)

t.Run("Network is correctly forwarded", func(t *testing.T) {
got := resolver.Network()
if got != "udp" {
t.Fatal("Network not called")
}
})

t.Run("LookupHost is correctly forwarded", func(t *testing.T) {
want := []string{"1.1.1.1"}
ctx := context.Background()
got, err := resolver.LookupHost(ctx, "example.com")
if err != nil {
t.Fatal("expected nil error")
}
if diff := cmp.Diff(want, got); diff != "" {
t.Fatal(diff)
}
})
})

t.Run("LookupHost saves into trace", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return []string{"fe80::a00:20ff:feb9:4c54"}, nil
}
return []string{"1.1.1.1"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
},
MockNetwork: func() string {
return ""
},
MockAddress: func() string {
return "dns.google"
},
}
newSimpleResolver := func() model.SimpleResolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
reso := netxlite.NewUnwrappedParallelResolver(txp)
return reso.LookupHost(ctx, domain)
},
}
}
resolver := trace.newSimpleResolverTrace(newSimpleResolver)
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil {
t.Fatal("unexpected err", err)
}
if len(addrs) != 2 {
t.Fatal("unexpected array output", addrs)
}
if addrs[0] != "1.1.1.1" && addrs[1] != "1.1.1.1" {
t.Fatal("unexpected array output", addrs)
}
if addrs[0] != "fe80::a00:20ff:feb9:4c54" && addrs[1] != "fe80::a00:20ff:feb9:4c54" {
t.Fatal("unexpected array output", addrs)
}

t.Run("DNSLookups QueryType A", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
if len(events) != 1 {
t.Fatal("expected to see single DNSLookup event")
}
lookup := events[0]
answers := lookup.Answers
if lookup.Failure != nil {
t.Fatal("unexpected err", *(lookup.Failure))
}
if lookup.ResolverAddress != "dns.google" {
t.Fatal("unexpected address field")
}
if len(answers) != 1 {
t.Fatal("expected 1 DNS answer, got", len(answers))
}
if answers[0].AnswerType != "A" || answers[0].IPv4 != "1.1.1.1" {
t.Fatal("unexpected DNS answer", answers)
}
})

t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
if len(events) != 1 {
t.Fatal("expected to see single DNSLookup event")
}
lookup := events[0]
answers := lookup.Answers
if lookup.Failure != nil {
t.Fatal("unexpected err", *(lookup.Failure))
}
if lookup.ResolverAddress != "dns.google" {
t.Fatal("unexpected address field")
}
if len(answers) != 1 {
t.Fatal("expected 1 DNS answer, got", len(answers))
}
if answers[0].AnswerType != "AAAA" || answers[0].IPv6 != "fe80::a00:20ff:feb9:4c54" {
t.Fatal("unexpected DNS answer", answers)
}
})
})

t.Run("LookupHost discards events when buffers are full", func(t *testing.T) {
zeroTime := time.Now()
td := testingx.NewTimeDeterministic(zeroTime)
trace := NewTrace(0, zeroTime)
trace.DNSLookup = map[uint16]chan *model.ArchivalDNSLookupResult{
dns.TypeA: make(chan *model.ArchivalDNSLookupResult), // no buffer
dns.TypeAAAA: make(chan *model.ArchivalDNSLookupResult), // no buffer
}
trace.TimeNowFn = td.Now
txp := &mocks.DNSTransport{
MockRoundTrip: func(ctx context.Context, query model.DNSQuery) (model.DNSResponse, error) {
response := &mocks.DNSResponse{
MockDecodeLookupHost: func() ([]string, error) {
if query.Type() != dns.TypeA {
return []string{"fe80::a00:20ff:feb9:4c54"}, nil
}
return []string{"1.1.1.1"}, nil
},
}
return response, nil
},
MockRequiresPadding: func() bool {
return true
},
MockNetwork: func() string {
return ""
},
MockAddress: func() string {
return "dns.google"
},
}
newSimpleResolver := func() model.SimpleResolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
reso := netxlite.NewUnwrappedParallelResolver(txp)
return reso.LookupHost(ctx, domain)
},
}
}
resolver := trace.newSimpleResolverTrace(newSimpleResolver)
ctx := context.Background()
addrs, err := resolver.LookupHost(ctx, "example.com")
if err != nil {
t.Fatal("unexpected err", err)
}
if len(addrs) != 2 {
t.Fatal("unexpected array output", addrs)
}

t.Run("DNSLookups QueryType A", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeA)
if len(events) != 0 {
t.Fatal("expected to see no DNSLookup")
}
})
t.Run("DNSLookups QueryType AAAA", func(t *testing.T) {
events := trace.DNSLookupsFromRoundTrip(dns.TypeAAAA)
if len(events) != 0 {
t.Fatal("expected to see no DNSLookup")
}
})
})
}

func TestAnswersFromAddrs(t *testing.T) {
tests := []struct {
name string
Expand Down
13 changes: 13 additions & 0 deletions internal/measurexlite/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ type Trace struct {
// calls to the netxlite.NewParallelResolver factory.
NewParallelResolverFn func() model.Resolver

// NewSimpleResolverFn is OPTIONAL and can be used to override
// calls to the model.SimpleResolver factory functions.
NewSimpleResolverFn func() model.SimpleResolver

// NewDialerWithoutResolverFn is OPTIONAL and can be used to override
// calls to the netxlite.NewDialerWithoutResolver factory.
NewDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer
Expand Down Expand Up @@ -164,6 +168,15 @@ func (tx *Trace) newParallelResolver(newResolver func() model.Resolver) model.Re
return newResolver()
}

// newSimpleResolver indirectly calls the passed simple resolver
// thus allowing us to mock this function for testing
func (tx *Trace) newSimpleResolver(newResolver func() model.SimpleResolver) model.SimpleResolver {
if tx.NewSimpleResolverFn != nil {
return tx.NewSimpleResolverFn()
}
return newResolver()
}

// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib
// thus allowing us to mock this func for testing.
func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker {
Expand Down
Loading