From a313be2161b65435a864f66d41d61ad33b467a77 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 3 Mar 2020 15:24:48 +0100 Subject: [PATCH] Vendor ooni/netx@4f8d645bce and use it The vendoring has omitted files we don't care about. The source code is not changed. The README.md and DESIGN.md file have been updated. Part of https://github.com/ooni/probe-engine/issues/302. --- experiment/sniblocking/sniblocking.go | 2 +- experiment/telegram/telegram_test.go | 2 +- go.mod | 4 +- go.sum | 6 - internal/httpx/httpx.go | 2 +- internal/netxlogger/netxlogger.go | 2 +- internal/netxlogger/netxlogger_test.go | 2 +- internal/oonidatamodel/oonidatamodel.go | 2 +- internal/oonidatamodel/oonidatamodel_test.go | 2 +- internal/oonitemplates/oonitemplates.go | 8 +- netx/DESIGN.md | 404 +++++++++ netx/README.md | 23 + netx/handlers/handlers.go | 29 + netx/handlers/handlers_test.go | 13 + netx/httpx/httpx.go | 155 ++++ netx/httpx/httpx_test.go | 144 ++++ netx/internal/connid/connid.go | 31 + netx/internal/connid/connid_test.go | 80 ++ netx/internal/dialer/connx/connx.go | 83 ++ netx/internal/dialer/connx/connx_test.go | 61 ++ netx/internal/dialer/dialer.go | 21 + netx/internal/dialer/dialer_test.go | 33 + netx/internal/dialer/dialerbase/dialerbase.go | 97 +++ .../dialer/dialerbase/dialerbase_test.go | 43 + netx/internal/dialer/dnsdialer/dnsdialer.go | 103 +++ .../dialer/dnsdialer/dnsdialer_test.go | 135 +++ netx/internal/dialer/tlsdialer/tlsdialer.go | 103 +++ .../dialer/tlsdialer/tlsdialer_test.go | 93 ++ netx/internal/dialid/dialid.go | 23 + netx/internal/dialid/dialid_test.go | 24 + netx/internal/errwrapper/errwrapper.go | 127 +++ netx/internal/errwrapper/errwrapper_test.go | 167 ++++ .../httptransport/bodytracer/bodytracer.go | 84 ++ .../bodytracer/bodytracer_test.go | 39 + netx/internal/httptransport/httptransport.go | 47 + .../httptransport/httptransport_test.go | 39 + .../tracetripper/tracetripper.go | 274 ++++++ .../tracetripper/tracetripper_test.go | 272 ++++++ .../transactioner/transactioner.go | 39 + .../transactioner/transactioner_test.go | 57 ++ netx/internal/internal.go | 355 ++++++++ netx/internal/internal_test.go | 403 +++++++++ .../resolver/bogondetector/bogondetector.go | 53 ++ .../bogondetector/bogondetector_test.go | 18 + .../resolver/brokenresolver/brokenresolver.go | 44 + .../brokenresolver/brokenresolver_test.go | 61 ++ .../resolver/chainresolver/chainresolver.go | 68 ++ .../chainresolver/chainresolver_test.go | 64 ++ .../dnstransport/dnsoverhttps/dnsoverhttps.go | 69 ++ .../dnsoverhttps/dnsoverhttps_test.go | 114 +++ .../dnstransport/dnsovertcp/dnsovertcp.go | 136 +++ .../dnsovertcp/dnsovertcp_test.go | 192 +++++ .../dnstransport/dnsoverudp/dnsoverudp.go | 64 ++ .../dnsoverudp/dnsoverudp_test.go | 164 ++++ .../resolver/ooniresolver/ooniresolver.go | 233 +++++ .../ooniresolver/ooniresolver_test.go | 278 ++++++ .../resolver/parentresolver/parentresolver.go | 138 +++ .../parentresolver/parentresolver_test.go | 153 ++++ netx/internal/resolver/resolver.go | 50 ++ netx/internal/resolver/resolver_test.go | 103 +++ .../resolver/systemresolver/systemresolver.go | 70 ++ .../systemresolver/systemresolver_test.go | 99 +++ netx/internal/transactionid/transactionid.go | 24 + .../transactionid/transactionid_test.go | 24 + netx/modelx/modelx.go | 807 ++++++++++++++++++ netx/modelx/modelx_test.go | 84 ++ netx/netx.go | 147 ++++ netx/netx_test.go | 155 ++++ netx/testdata/cacert.pem | 54 ++ 69 files changed, 7080 insertions(+), 19 deletions(-) create mode 100644 netx/DESIGN.md create mode 100644 netx/README.md create mode 100644 netx/handlers/handlers.go create mode 100644 netx/handlers/handlers_test.go create mode 100644 netx/httpx/httpx.go create mode 100644 netx/httpx/httpx_test.go create mode 100644 netx/internal/connid/connid.go create mode 100644 netx/internal/connid/connid_test.go create mode 100644 netx/internal/dialer/connx/connx.go create mode 100644 netx/internal/dialer/connx/connx_test.go create mode 100644 netx/internal/dialer/dialer.go create mode 100644 netx/internal/dialer/dialer_test.go create mode 100644 netx/internal/dialer/dialerbase/dialerbase.go create mode 100644 netx/internal/dialer/dialerbase/dialerbase_test.go create mode 100644 netx/internal/dialer/dnsdialer/dnsdialer.go create mode 100644 netx/internal/dialer/dnsdialer/dnsdialer_test.go create mode 100644 netx/internal/dialer/tlsdialer/tlsdialer.go create mode 100644 netx/internal/dialer/tlsdialer/tlsdialer_test.go create mode 100644 netx/internal/dialid/dialid.go create mode 100644 netx/internal/dialid/dialid_test.go create mode 100644 netx/internal/errwrapper/errwrapper.go create mode 100644 netx/internal/errwrapper/errwrapper_test.go create mode 100644 netx/internal/httptransport/bodytracer/bodytracer.go create mode 100644 netx/internal/httptransport/bodytracer/bodytracer_test.go create mode 100644 netx/internal/httptransport/httptransport.go create mode 100644 netx/internal/httptransport/httptransport_test.go create mode 100644 netx/internal/httptransport/tracetripper/tracetripper.go create mode 100644 netx/internal/httptransport/tracetripper/tracetripper_test.go create mode 100644 netx/internal/httptransport/transactioner/transactioner.go create mode 100644 netx/internal/httptransport/transactioner/transactioner_test.go create mode 100644 netx/internal/internal.go create mode 100644 netx/internal/internal_test.go create mode 100644 netx/internal/resolver/bogondetector/bogondetector.go create mode 100644 netx/internal/resolver/bogondetector/bogondetector_test.go create mode 100644 netx/internal/resolver/brokenresolver/brokenresolver.go create mode 100644 netx/internal/resolver/brokenresolver/brokenresolver_test.go create mode 100644 netx/internal/resolver/chainresolver/chainresolver.go create mode 100644 netx/internal/resolver/chainresolver/chainresolver_test.go create mode 100644 netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps.go create mode 100644 netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps_test.go create mode 100644 netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp.go create mode 100644 netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp_test.go create mode 100644 netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp.go create mode 100644 netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp_test.go create mode 100644 netx/internal/resolver/ooniresolver/ooniresolver.go create mode 100644 netx/internal/resolver/ooniresolver/ooniresolver_test.go create mode 100644 netx/internal/resolver/parentresolver/parentresolver.go create mode 100644 netx/internal/resolver/parentresolver/parentresolver_test.go create mode 100644 netx/internal/resolver/resolver.go create mode 100644 netx/internal/resolver/resolver_test.go create mode 100644 netx/internal/resolver/systemresolver/systemresolver.go create mode 100644 netx/internal/resolver/systemresolver/systemresolver_test.go create mode 100644 netx/internal/transactionid/transactionid.go create mode 100644 netx/internal/transactionid/transactionid_test.go create mode 100644 netx/modelx/modelx.go create mode 100644 netx/modelx/modelx_test.go create mode 100644 netx/netx.go create mode 100644 netx/netx_test.go create mode 100644 netx/testdata/cacert.pem diff --git a/experiment/sniblocking/sniblocking.go b/experiment/sniblocking/sniblocking.go index d8040b99..23ee2905 100644 --- a/experiment/sniblocking/sniblocking.go +++ b/experiment/sniblocking/sniblocking.go @@ -12,13 +12,13 @@ import ( "net/url" "time" - "github.com/ooni/netx/modelx" "github.com/ooni/probe-engine/experiment" "github.com/ooni/probe-engine/experiment/handler" "github.com/ooni/probe-engine/internal/netxlogger" "github.com/ooni/probe-engine/internal/oonidatamodel" "github.com/ooni/probe-engine/internal/oonitemplates" "github.com/ooni/probe-engine/model" + "github.com/ooni/probe-engine/netx/modelx" "github.com/ooni/probe-engine/session" ) diff --git a/experiment/telegram/telegram_test.go b/experiment/telegram/telegram_test.go index 512fbc07..8d84f0ac 100644 --- a/experiment/telegram/telegram_test.go +++ b/experiment/telegram/telegram_test.go @@ -7,11 +7,11 @@ import ( "testing" "github.com/apex/log" - "github.com/ooni/netx/modelx" "github.com/ooni/probe-engine/experiment/handler" "github.com/ooni/probe-engine/internal/kvstore" "github.com/ooni/probe-engine/internal/oonitemplates" "github.com/ooni/probe-engine/model" + "github.com/ooni/probe-engine/netx/modelx" "github.com/ooni/probe-engine/session" ) diff --git a/go.mod b/go.mod index a738d1ad..75ed9191 100644 --- a/go.mod +++ b/go.mod @@ -37,9 +37,9 @@ require ( github.com/m-lab/go v1.2.2 github.com/m-lab/ndt7-client-go v0.2.0 github.com/marusama/semaphore v0.0.0-20171214154724-565ffd8e868a // indirect + github.com/miekg/dns v1.1.27 github.com/montanaflynn/stats v0.6.3 github.com/neubot/dash v0.4.1 - github.com/ooni/netx v0.0.0-20200211124352-4f8d645bce64 github.com/oschwald/geoip2-golang v1.4.0 github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3 @@ -57,6 +57,6 @@ require ( go.uber.org/multierr v1.1.1-0.20180122172545-ddea229ff1df // indirect go.uber.org/zap v1.9.2-0.20180814183419-67bc79d13d15 // indirect golang.org/x/crypto v0.0.0-20200214034016-1d94cc7ab1c6 // indirect - golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa // indirect + golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa golang.org/x/sys v0.0.0-20200219091948-cb0a6d8edb6c // indirect ) diff --git a/go.sum b/go.sum index c9fbcfc7..3817410b 100644 --- a/go.sum +++ b/go.sum @@ -259,8 +259,6 @@ github.com/onsi/gomega v1.5.0 h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo= github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/ooni/netx v0.0.0-20200211124352-4f8d645bce64 h1:rLXhIycJB+yGJPMRY1tP4KzYqd/EvhrP8x7CjbeKmJc= -github.com/ooni/netx v0.0.0-20200211124352-4f8d645bce64/go.mod h1:vTJ7nYH2j51lX8yvhdLaHscvwz1GvWoiB26jYKEmqyA= github.com/openconfig/gnmi v0.0.0-20190823184014-89b2bf29312c/go.mod h1:t+O9It+LKzfOAhKTT5O0ehDix+MTqbtT0T9t+7zzOvc= github.com/openconfig/reference v0.0.0-20190727015836-8dfd928c9696/go.mod h1:ym2A+zigScwkSEb/cVQB0/ZMpU3rqiH6X7WRRsxgOGw= github.com/oschwald/geoip2-golang v1.4.0 h1:5RlrjCgRyIGDz/mBmPfnAF4h8k0IAcRv9PvrpOfz+Ug= @@ -369,8 +367,6 @@ golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876 h1:sKJQZMuxjOAR/Uo2LBfU90onWEf1dF4C+0hPJCc9Mpc= -golang.org/x/crypto v0.0.0-20191227163750-53104e6ec876/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200214034016-1d94cc7ab1c6 h1:Sy5bstxEqwwbYs6n0/pBuxKENqOeZUgD45Gp3Q3pqLg= golang.org/x/crypto v0.0.0-20200214034016-1d94cc7ab1c6/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -456,8 +452,6 @@ golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200107162124-548cf772de50 h1:YvQ10rzcqWXLlJZ3XCUoO25savxmscf4+SC+ZqiCHhA= -golang.org/x/sys v0.0.0-20200107162124-548cf772de50/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200219091948-cb0a6d8edb6c h1:jceGD5YNJGgGMkJz79agzOln1K9TaZUjv5ird16qniQ= golang.org/x/sys v0.0.0-20200219091948-cb0a6d8edb6c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= diff --git a/internal/httpx/httpx.go b/internal/httpx/httpx.go index 4f37bd78..d121cf94 100644 --- a/internal/httpx/httpx.go +++ b/internal/httpx/httpx.go @@ -6,9 +6,9 @@ import ( "crypto/tls" "crypto/x509" "io/ioutil" + "net" "net/http" "net/url" - "net" "time" "github.com/ooni/probe-engine/internal/httplog" diff --git a/internal/netxlogger/netxlogger.go b/internal/netxlogger/netxlogger.go index d9ab697d..1ebef308 100644 --- a/internal/netxlogger/netxlogger.go +++ b/internal/netxlogger/netxlogger.go @@ -8,8 +8,8 @@ import ( "net/http" "strings" - "github.com/ooni/netx/modelx" "github.com/ooni/probe-engine/internal/tlsx" + "github.com/ooni/probe-engine/netx/modelx" ) // Logger is the interface we expect from a logger diff --git a/internal/netxlogger/netxlogger_test.go b/internal/netxlogger/netxlogger_test.go index c001854e..0f1d7d95 100644 --- a/internal/netxlogger/netxlogger_test.go +++ b/internal/netxlogger/netxlogger_test.go @@ -6,7 +6,7 @@ import ( "github.com/apex/log" "github.com/apex/log/handlers/discard" - "github.com/ooni/netx/httpx" + "github.com/ooni/probe-engine/netx/httpx" ) func TestIntegration(t *testing.T) { diff --git a/internal/oonidatamodel/oonidatamodel.go b/internal/oonidatamodel/oonidatamodel.go index 5b8248d4..17c2490c 100644 --- a/internal/oonidatamodel/oonidatamodel.go +++ b/internal/oonidatamodel/oonidatamodel.go @@ -14,9 +14,9 @@ import ( "strings" "unicode/utf8" - "github.com/ooni/netx/modelx" "github.com/ooni/probe-engine/internal/oonitemplates" "github.com/ooni/probe-engine/internal/tlsx" + "github.com/ooni/probe-engine/netx/modelx" ) // TCPConnectStatus contains the TCP connect status. diff --git a/internal/oonidatamodel/oonidatamodel_test.go b/internal/oonidatamodel/oonidatamodel_test.go index adb76b77..787684a2 100644 --- a/internal/oonidatamodel/oonidatamodel_test.go +++ b/internal/oonidatamodel/oonidatamodel_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/ooni/netx/modelx" "github.com/ooni/probe-engine/internal/oonitemplates" + "github.com/ooni/probe-engine/netx/modelx" ) func TestUnitNewTCPConnectListEmpty(t *testing.T) { diff --git a/internal/oonitemplates/oonitemplates.go b/internal/oonitemplates/oonitemplates.go index f14bfedb..afd65035 100644 --- a/internal/oonitemplates/oonitemplates.go +++ b/internal/oonitemplates/oonitemplates.go @@ -21,10 +21,10 @@ import ( goptlib "git.torproject.org/pluggable-transports/goptlib.git" "github.com/m-lab/go/rtx" - "github.com/ooni/netx" - "github.com/ooni/netx/handlers" - "github.com/ooni/netx/httpx" - "github.com/ooni/netx/modelx" + "github.com/ooni/probe-engine/netx" + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/httpx" + "github.com/ooni/probe-engine/netx/modelx" "gitlab.com/yawning/obfs4.git/transports" obfs4base "gitlab.com/yawning/obfs4.git/transports/base" ) diff --git a/netx/DESIGN.md b/netx/DESIGN.md new file mode 100644 index 00000000..1ebcffec --- /dev/null +++ b/netx/DESIGN.md @@ -0,0 +1,404 @@ +# OONI Network Extensions + +| Author | Simone Basso | +|--------------|--------------| +| Last-Updated | 2020-03-03 | +| Status | approved | + +## Introduction + +OONI experiments send and/or receive network traffic to +determine if there is blocking. We want the implementation +of OONI experiments to be as simple as possible. We also +_want to attribute errors to the major network or protocol +operation that caused them_. + +At the same time, _we want an experiment to collect as much +low-level data as possible_. For example, we want to know +whether and when the TLS handshake completed; what certificates +were provided by the server; what TLS version was selected; +and so forth. These bits of information are very useful +to analyze a measurement and better classify it. + +We also want to _automatically or manually run follow-up +measurements where we change some configuration properties +and repeat the measurement_. For example, we may want to +configure DNS over HTTPS (DoH) and then attempt to +fetch again an URL. Or we may want to detect whether +there is SNI bases blocking. This package allows us to +do that in other parts of probe-engine. + +## Rationale + +As we observed [ooni/probe-engine#13]( +https://github.com/ooni/probe-engine/issues/13), every +experiment consists of two separate phases: + +1. measurement gathering + +2. measurement analysis + +During measurement gathering, we perform specific actions +that cause network data to be sent and/or received. During +measurement analysis, we process the measurement on the +device. For some experiments (e.g., Web Connectivity), this +second phase also entails contacting OONI backend services +that provide data useful to complete the analysis. + +This package implements measurement gathering. The analysis +is performed by other packages in probe-engine. The core +design idea is to provide OONI-measurements-aware replacements +for Go standard library interfaces, e.g., the +`http.RoundTripper`. On top of that, we'll create all the +required interfaces to achive the measurement goals mentioned above. + +We are of course writing test templates in `probe-engine` +anyway, because we need additional abstraction, but we can +take advantage of the fact that the API exposed by this package +is stable by definition, because it mimics the stdlib. Also, +for many experiments we can collect information pertaining +to TCP, DNS, TLS, and HTTP with a single call to `netx`. + +This code used to live at `github.com/ooni/netx`. On 2020-03-02 +we merged github.com/ooni/netx@4f8d645bce6466bb into `probe-engine` +because it was more practical and enabled easier refactoring. + +## Definitions + +Consistently with Go's terminology, we define +_HTTP round trip_ the process where we get a request +to send; we find a suitable connection for sending +it, or we create one; we send headers and +possibly body; and we receive response headers. + +We also define _HTTP transaction_ the process starting +with an HTTP round trip and terminating by reading +the full response body. + +We define _netx replacement_ a Go struct of interface that +has the same interface of a Go standard library object +but additionally performs measurements. + +## Enhanced error handling + +This library MUST wrap `error` such that: + +1. we can classify all errors we care about; and + +2. we can map them to major operations. + +The `github.com/ooni/netx/modelx` MUST contain a wrapper for +Go `error` named `ErrWrapper` that is at least like: + +```Go +type ErrWrapper struct { + Failure string // error classification + Operation string // operation that caused error + WrappedErr error // the original error +} + +func (e *ErrWrapper) Error() string { + return e.Failure +} +``` + +Where `Failure` is one of the errors we care about, i.e.: + +- `connection_refused`: ECONNREFUSED +- `connection_reset`: ECONNRESET +- `dns_bogon_error`: detected bogon in DNS reply +- `dns_nxdomain_error`: NXDOMAIN in DNS reply +- `eof_error`: unexpected EOF on connection +- `generic_timeout_error`: some timer has expired +- `ssl_invalid_hostname`: certificate not valid for SNI +- `ssl_unknown_autority`: cannot find CA validating certificate +- `ssl_invalid_certificate`: e.g. certificate expired +- `unknown_failure `: any other error + +Note that we care about bogons in DNS replies because they are +often used to censor specific websites. + +And where `Operation` is one of: + +- `resolve`: domain name resolution +- `connect`: TCP connect +- `tls_handshake`: TLS handshake +- `http_round_trip`: reading/writing HTTP + +The code in this library MUST wrap returned errors such +that we can cast back to `ErrWrapper` during the analysis +phase, using Go 1.13 `errors` library as follows: + +```Go +var wrapper *modelx.ErrWrapper +if errors.As(err, &wrapper) == true { + // Do something with the error +} +``` + +## Netx replacements + +We want to provide netx replacements for the following +interfaces in the Go standard library: + +1. `http.RoundTripper` + +2. `http.Client` + +3. `net.Dialer` + +4. `net.Resolver` + +Accordingly, we'll define the following interfaces in +the `github.com/ooni/probe-engine/netx/modelx` package: + +```Go +type DNSResolver interface { + LookupAddr(ctx context.Context, addr string) ([]string, error) + LookupCNAME(ctx context.Context, host string) (string, error) + LookupHost(ctx context.Context, hostname string) ([]string, error) + LookupMX(ctx context.Context, name string) ([]*net.MX, error) + LookupNS(ctx context.Context, name string) ([]*net.NS, error) +} + +type Dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type TLSDialer interface { + DialTLS(network, address string) (net.Conn, error) + DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) +} +``` + +We won't need an interface for `http.RoundTripper` +because it is already an interface, so we'll just use it. + +Our replacements will implement these interfaces. + +Using an API compatible with Go's standard libary makes +it possible to use, say, our `net.Dialer` replacement with +other libraries. Both `http.Transport` and +`gorilla/websocket`'s `websocket.Dialer` have +functions like `Dial` and `DialContext` that can be +overriden. By overriding such function pointers, +we could use our replacements instead of the standard +libary, thus we could collect measurements while +using third party code to implement specific protocols. + +Also, using interfaces allows us to combine code +quite easily. For example, a resolver that detects +bogons is easily implemented as a wrapper around +another resolve that performs the real resolution. + +## Dispatching events + +The `github.com/ooni/netx/modelx` package will define +an handler for low level events as: + +```Go +type Handler interface { + OnMeasurement(Measurement) +} +``` + +We will provide a mechanism to bind a specific +handler to a `context.Context` such that the handler +will receive all the measurements caused by code +using such context. This mechanism is like: + +```Go +type MeasurementRoot struct { + Beginning time.Time // the "zero" time + Handler Handler // the handler to use +} +``` + +You will be able to assign a `MeasurementRoot` to +a context by using the following function: + +```Go +func WithMeasurementRoot( + ctx context.Context, root *MeasurementRoot) context.Context +``` + +which will return a clone of the original context +that uses the `MeasurementRoot`. Pass this context to +any method of our replacements to get measurements. + +Given such context, or a subcontext, you can get +back the original `MeasurementRoot` using: + +```Go +func ContextMeasurementRoot(ctx context.Context) *MeasurementRoot +``` + +which will return the context `MeasurementRoot` or +`nil` if none is set into the context. This is how our +internal code gets access to the `MeasurementRoot`. + +## Constructing and configuring replacements + +The `github.com/ooni/probe-engine/netx` package MUST provide an API such +that you can construct and configure a `net.Resolver` replacement +as follows: + +```Go +r, err := netx.NewResolverWithoutHandler(dnsNetwork, dnsAddress) +if err != nil { + log.Fatal("cannot configure specifc resolver") +} +var resolver modelx.DNSResolver = r +// now use resolver ... +``` + +where `DNSNetwork` and `DNSAddress` configure the type +of the resolver as follows: + +- when `DNSNetwork` is `""` or `"system"`, `DNSAddress` does +not matter and we use the system resolver + +- when `DNSNetwork` is `"udp"`, `DNSAddress` is the address +or domain name, with optional port, of the DNS server +(e.g., `8.8.8.8:53`) + +- when `DNSNetwork` is `"tcp"`, `DNSAddress` is the address +or domain name, with optional port, of the DNS server +(e.g., `8.8.8.8:53`) + +- when `DNSNetwork` is `"dot"`, `DNSAddress` is the address +or domain name, with optional port, of the DNS server +(e.g., `8.8.8.8:853`) + +- when `DNSNetwork` is `"doh"`, `DNSAddress` is the URL +of the DNS server (e.g. `https://cloudflare-dns.com/dns-query`) + +When the resolve is not the system one, we'll also be able +to emit events when performing resolution. Otherwise, we'll +just emit the `DNSResolveDone` event defined below. + +Any resolver returned by this function may be configured to return the +`dns_bogon_error` if any `LookupHost` lookup returns a bogon IP. + +The package will also contain this function: + +```Go +func ChainResolvers( + primary, secondary modelx.DNSResolver) modelx.DNSResolver +``` + +where you can create a new resolver where `secondary` will be +invoked whenever `primary` fails. This functionality allows +us to be more resilient and bypass automatically certain types +of censorship, e.g., a resolver returning a bogon. + +The `github.com/ooni/probe-engine/netx` package MUST also provide an API such +that you can construct and configure a `net.Dialer` replacement +as follows: + +```Go +d := netx.NewDialerWithoutHandler() +d.SetResolver(resolver) +d.ForceSpecificSNI("www.kernel.org") +d.SetCABundle("/etc/ssl/cert.pem") +d.ForceSkipVerify() +var dialer modelx.Dialer = d +// now use dialer +``` + +where `SetResolver` allows you to change the resolver, +`ForceSpecificSNI` forces the TLS dials to use such SNI +instead of using the provided domain, `SetCABundle` +allows to set a specific CA bundle, and `ForceSkipVerify` +allows to disable certificate verification. All these funcs +MUST NOT be invoked once you're using the dialer. + +The `github.com/ooni/probe-engine/netx/httpx` package MUST contain +code so that we can do: + +```Go +t := httpx.NewHTTPTransportWithProxyFunc( + http.ProxyFromEnvironment, +) +t.SetResolver(resolver) +t.ForceSpecificSNI("www.kernel.org") +t.SetCABundle("/etc/ssl/cert.pem") +t.ForceSkipVerify() +var transport http.RoundTripper = t +// now use transport +``` + +where the functions have the same semantics as the +namesake functions described before and the same caveats. + +We also have syntactic sugar on top of that and legacy +methods, but this fully describes the design. + +## Structure of events + +The `github.com/ooni/probe-engine/netx/modelx` will contain the +definition of low-level events. We are interested in +knowing the following: + +1. the timing and result of each I/O operation. + +2. the timing of HTTP events occurring during the +lifecycle of an HTTP request. + +3. the timing and result of the TLS handshake including +the negotiated TLS version and other details such as +what certificates the server has provided. + +4. DNS events, e.g. queries and replies, generated +as part of using DoT and DoH. + +We will represent time as a `time.Duration` since the +beginning configured either in the context or when +constructing an object. The `modelx` package will also +define the `Measurement` event as follows: + +```Go +type Measurement struct { + Connect *ConnectEvent + HTTPConnectionReady *HTTPConnectionReadyEvent + HTTPRoundTripDone *HTTPRoundTripDoneEvent + ResolveDone *ResolveDoneEvent + TLSHandshakeDone *TLSHandshakeDoneEvent +} +``` + +The events above MUST always be present, but more +events will likely be available. The structure +will contain a pointer for every event that +we support. The events processing code will check +what pointer or pointers are not `nil` to known +which event or events have occurred. + +To simplify joining events together the following holds: + +1. when we're establishing a new connection there is a nonzero +`DialID` shared by `Connect` and `ResolveDone` + +2. a new connection has a nonzero `ConnID` that is emitted +as part of a successful `Connect` event + +3. during an HTTP transaction there is a nonzero `TransactionID` +shared by `HTTPConnectionReady` and `HTTPRoundTripDone` + +4. if the TLS handshake is invoked by HTTP code it will have a +nonzero `TrasactionID` otherwise a nonzero `ConnID` + +5. the `HTTPConnectionReady` will also see the `ConnID` + +6. when a transaction starts dialing, it will pass its +`TransactionID` to `ResolveDone` and `Connect` + +7. when we're dialing a connection for DoH, we pass the `DialID` +to the `HTTPConnectionReady` event as well + +Because of the following rules, it should always be possible +to bind together events. Also, we define more events than the +above, but they are ancillary to the above events. Also, the +main reason why `HTTPConnectionReady` is here is because it is +the event allowing to bind `ConnID` and `TransactionID`. diff --git a/netx/README.md b/netx/README.md new file mode 100644 index 00000000..5d06de64 --- /dev/null +++ b/netx/README.md @@ -0,0 +1,23 @@ +# github.com/ooni/probe-engine/netx + +OONI extensions to the `net` and `net/http` packages. This code is +used by `ooni/probe-engine` as a low level library to collect +network, DNS, and HTTP events occurring during OONI measurements. + +This library contains replacements for commonly used standard library +interfaces that facilitate seamless network measurements. By using +such replacements, as opposed to standard library interfaces, we can: + +* save the timing of HTTP events (e.g. received response headers) +* save the timing and result of every Connect, Read, Write, Close operation +* save the timing and result of the TLS handshake (including certificates) + +By default, this library uses the system resolver. In addition, it +is possible to configure alternative DNS transports and remote +servers. We support DNS over UDP, DNS over TCP, DNS over TLS (DoT), +and DNS over HTTPS (DoH). When using an alternative transport, we +are also able to intercept and save DNS messages, as well as any +other interaction with the remote server (e.g., the result of the +TLS handshake for DoT and DoH). + +This package is a fork of [github.com/ooni/netx](https://github.com/ooni/netx). diff --git a/netx/handlers/handlers.go b/netx/handlers/handlers.go new file mode 100644 index 00000000..a36fcb30 --- /dev/null +++ b/netx/handlers/handlers.go @@ -0,0 +1,29 @@ +// Package handlers contains default modelx.Handler handlers. +package handlers + +import ( + "encoding/json" + "fmt" + + "github.com/m-lab/go/rtx" + "github.com/ooni/probe-engine/netx/modelx" +) + +type stdoutHandler struct{} + +func (stdoutHandler) OnMeasurement(m modelx.Measurement) { + data, err := json.Marshal(m) + rtx.Must(err, "unexpected json.Marshal failure") + fmt.Printf("%s\n", string(data)) +} + +// StdoutHandler is a Handler that logs on stdout. +var StdoutHandler stdoutHandler + +type noHandler struct{} + +func (noHandler) OnMeasurement(m modelx.Measurement) { +} + +// NoHandler is a Handler that does not print anything +var NoHandler noHandler diff --git a/netx/handlers/handlers_test.go b/netx/handlers/handlers_test.go new file mode 100644 index 00000000..3fa2f7a9 --- /dev/null +++ b/netx/handlers/handlers_test.go @@ -0,0 +1,13 @@ +package handlers_test + +import ( + "testing" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegration(t *testing.T) { + handlers.NoHandler.OnMeasurement(modelx.Measurement{}) + handlers.StdoutHandler.OnMeasurement(modelx.Measurement{}) +} diff --git a/netx/httpx/httpx.go b/netx/httpx/httpx.go new file mode 100644 index 00000000..148d5c19 --- /dev/null +++ b/netx/httpx/httpx.go @@ -0,0 +1,155 @@ +// Package httpx contains OONI's net/http extensions. It defines the Client and +// the Transport replacements that we should use in OONI. They emit measurements +// collected at network and HTTP level using a specific handler. +package httpx + +import ( + "net/http" + "net/url" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Transport performs measurements during HTTP round trips. +type Transport struct { + dialer *internal.Dialer + transport *internal.HTTPTransport +} + +func newTransport( + beginning time.Time, handler modelx.Handler, + proxyFunc func(*http.Request) (*url.URL, error), +) *Transport { + t := new(Transport) + t.dialer = internal.NewDialer(beginning, handler) + t.transport = internal.NewHTTPTransport( + beginning, + handler, + t.dialer, + false, // DisableKeepAlives + proxyFunc, + ) + return t +} + +// NewTransportWithProxyFunc creates a transport without any +// handler attached using the specified proxy func. +func NewTransportWithProxyFunc( + proxyFunc func(*http.Request) (*url.URL, error), +) *Transport { + return newTransport(time.Now(), handlers.NoHandler, proxyFunc) +} + +// NewTransport creates a new Transport. The beginning argument is +// the time to use as zero for computing the elapsed time. +func NewTransport(beginning time.Time, handler modelx.Handler) *Transport { + return newTransport(beginning, handler, http.ProxyFromEnvironment) +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.transport.RoundTrip(req) +} + +// CloseIdleConnections closes any connections which were previously connected +// from previous requests but are now sitting idle in a "keep-alive" state. It +// does not interrupt any connections currently in use. +func (t *Transport) CloseIdleConnections() { + t.transport.CloseIdleConnections() +} + +// ConfigureDNS is exactly like netx.Dialer.ConfigureDNS. +func (t *Transport) ConfigureDNS(network, address string) error { + return t.dialer.ConfigureDNS(network, address) +} + +// SetResolver is exactly like netx.Dialer.SetResolver. +func (t *Transport) SetResolver(r modelx.DNSResolver) { + t.dialer.SetResolver(r) +} + +// SetCABundle internally calls netx.Dialer.SetCABundle and +// therefore it has the same caveats and limitations. +func (t *Transport) SetCABundle(path string) error { + return t.dialer.SetCABundle(path) +} + +// ForceSpecificSNI forces using a specific SNI. +func (t *Transport) ForceSpecificSNI(sni string) error { + return t.dialer.ForceSpecificSNI(sni) +} + +// ForceSkipVerify forces to skip certificate verification +func (t *Transport) ForceSkipVerify() error { + return t.dialer.ForceSkipVerify() +} + +// Client is a replacement for http.Client. +type Client struct { + // HTTPClient is the underlying client. Pass this client to existing code + // that expects an *http.HTTPClient. For this reason we can't embed it. + HTTPClient *http.Client + + // Transport is the transport configured by NewClient to be used + // by the HTTPClient field. + Transport *Transport +} + +// NewClientWithProxyFunc creates a new client using the +// specified proxyFunc for handling proxying. +func NewClientWithProxyFunc( + handler modelx.Handler, + proxyFunc func(*http.Request) (*url.URL, error), +) *Client { + transport := newTransport(time.Now(), handler, proxyFunc) + return &Client{ + HTTPClient: &http.Client{ + Transport: transport, + }, + Transport: transport, + } +} + +// NewClient creates a new client instance. +func NewClient(handler modelx.Handler) *Client { + return NewClientWithProxyFunc(handler, http.ProxyFromEnvironment) +} + +// NewClientWithoutProxy creates a client without any +// configured proxy attached to it. This is suitable +// for measurements where you don't want a proxy to be +// in the middle and alter the measurements. +func NewClientWithoutProxy(handler modelx.Handler) *Client { + return NewClientWithProxyFunc(handler, nil) +} + +// ConfigureDNS internally calls netx.Dialer.ConfigureDNS and +// therefore it has the same caveats and limitations. +func (c *Client) ConfigureDNS(network, address string) error { + return c.Transport.ConfigureDNS(network, address) +} + +// SetResolver internally calls netx.Dialer.SetResolver +func (c *Client) SetResolver(r modelx.DNSResolver) { + c.Transport.SetResolver(r) +} + +// SetCABundle internally calls netx.Dialer.SetCABundle and +// therefore it has the same caveats and limitations. +func (c *Client) SetCABundle(path string) error { + return c.Transport.SetCABundle(path) +} + +// ForceSpecificSNI forces using a specific SNI. +func (c *Client) ForceSpecificSNI(sni string) error { + return c.Transport.ForceSpecificSNI(sni) +} + +// ForceSkipVerify forces to skip certificate verification +func (c *Client) ForceSkipVerify() error { + return c.Transport.ForceSkipVerify() +} diff --git a/netx/httpx/httpx_test.go b/netx/httpx/httpx_test.go new file mode 100644 index 00000000..37046e3e --- /dev/null +++ b/netx/httpx/httpx_test.go @@ -0,0 +1,144 @@ +package httpx_test + +import ( + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/httpx" +) + +func TestIntegration(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + defer client.Transport.CloseIdleConnections() + err := client.ConfigureDNS("udp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + resp, err := client.HTTPClient.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } +} + +func TestIntegrationSetResolver(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + defer client.Transport.CloseIdleConnections() + client.SetResolver(new(net.Resolver)) + resp, err := client.HTTPClient.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } +} +func TestSetCABundle(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + err := client.SetCABundle("../testdata/cacert.pem") + if err != nil { + t.Fatal(err) + } +} + +func TestForceSpecificSNI(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + err := client.ForceSpecificSNI("www.facebook.com") + if err != nil { + t.Fatal(err) + } + resp, err := client.HTTPClient.Get("https://www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + // TODO(bassosimone): how to unwrap the error in Go < 1.13? Anyway we are + // already testing we're getting the right error in netx_test.go. + t.Log(err) + if resp != nil { + t.Fatal("expected a nil response here") + } +} + +func TestForceSkipVerify(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + client.ForceSkipVerify() + resp, err := client.HTTPClient.Get("https://self-signed.badssl.com/") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected non nil response here") + } +} + +func TestNewClientWithoutProxy(t *testing.T) { + client := httpx.NewClientWithoutProxy(handlers.NoHandler) + proxyTestMain(t, client.HTTPClient, 200) +} + +func TestNewClientHonoursProxy(t *testing.T) { + client := httpx.NewClient(handlers.NoHandler) + proxyTestMain(t, client.HTTPClient, 451) +} + +func TestNewTransportHonoursProxy(t *testing.T) { + transport := httpx.NewTransport( + time.Now(), handlers.NoHandler, + ) + client := &http.Client{Transport: transport} + proxyTestMain(t, client, 451) +} + +func TestNewTransportWithoutAnyProxy(t *testing.T) { + transport := httpx.NewTransportWithProxyFunc(nil) + client := &http.Client{Transport: transport} + proxyTestMain(t, client, 200) +} + +func proxyTestMain(t *testing.T, client *http.Client, expect int) { + req, err := http.NewRequest("GET", "http://www.google.com", nil) + if err != nil { + t.Fatal(err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != expect { + t.Fatal("unexpected status code") + } +} + +var ( + proxyServer *httptest.Server + proxyCount int64 +) + +func TestMain(m *testing.M) { + server := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&proxyCount, 1) + w.WriteHeader(451) + })) + defer server.Close() + os.Setenv("HTTP_PROXY", server.URL) + os.Exit(m.Run()) +} diff --git a/netx/internal/connid/connid.go b/netx/internal/connid/connid.go new file mode 100644 index 00000000..5d37ba70 --- /dev/null +++ b/netx/internal/connid/connid.go @@ -0,0 +1,31 @@ +// Package connid contains code to generate the connectionID +package connid + +import ( + "net" + "strconv" + "strings" +) + +// Compute computes the connectionID from the local socket address. The zero +// value is conventionally returned to mean "unknown". +func Compute(network, address string) int64 { + _, portstring, err := net.SplitHostPort(address) + if err != nil { + return 0 + } + portnum, err := strconv.Atoi(portstring) + if err != nil { + return 0 + } + if portnum < 0 || portnum > 65535 { + return 0 + } + result := int64(portnum) + if strings.Contains(network, "udp") { + result *= -1 + } else if !strings.Contains(network, "tcp") { + result = 0 + } + return result +} diff --git a/netx/internal/connid/connid_test.go b/netx/internal/connid/connid_test.go new file mode 100644 index 00000000..e4f32017 --- /dev/null +++ b/netx/internal/connid/connid_test.go @@ -0,0 +1,80 @@ +package connid + +import "testing" + +func TestIntegrationTCP(t *testing.T) { + num := Compute("tcp", "1.2.3.4:6789") + if num != 6789 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationTCP4(t *testing.T) { + num := Compute("tcp4", "130.192.91.211:34566") + if num != 34566 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationTCP6(t *testing.T) { + num := Compute("tcp4", "[::1]:4444") + if num != 4444 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationUDP(t *testing.T) { + num := Compute("udp", "1.2.3.4:6789") + if num != -6789 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationUDP4(t *testing.T) { + num := Compute("udp4", "130.192.91.211:34566") + if num != -34566 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationUDP6(t *testing.T) { + num := Compute("udp6", "[::1]:4444") + if num != -4444 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationInvalidAddress(t *testing.T) { + num := Compute("udp6", "[::1]") + if num != 0 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationInvalidPort(t *testing.T) { + num := Compute("udp6", "[::1]:antani") + if num != 0 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationNegativePort(t *testing.T) { + num := Compute("udp6", "[::1]:-1") + if num != 0 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationLargePort(t *testing.T) { + num := Compute("udp6", "[::1]:65536") + if num != 0 { + t.Fatal("unexpected result") + } +} + +func TestIntegrationInvalidNetwork(t *testing.T) { + num := Compute("unix", "[::1]:65531") + if num != 0 { + t.Fatal("unexpected result") + } +} diff --git a/netx/internal/dialer/connx/connx.go b/netx/internal/dialer/connx/connx.go new file mode 100644 index 00000000..d70ffea1 --- /dev/null +++ b/netx/internal/dialer/connx/connx.go @@ -0,0 +1,83 @@ +// Package connx contains net.Conn extensions +package connx + +import ( + "net" + "time" + + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/modelx" +) + +// MeasuringConn is a net.Conn used to perform measurements +type MeasuringConn struct { + net.Conn + Beginning time.Time + Handler modelx.Handler + ID int64 +} + +// Read reads data from the connection. +func (c *MeasuringConn) Read(b []byte) (n int, err error) { + start := time.Now() + n, err = c.Conn.Read(b) + err = errwrapper.SafeErrWrapperBuilder{ + ConnID: c.ID, + Error: err, + Operation: "read", + }.MaybeBuild() + stop := time.Now() + c.Handler.OnMeasurement(modelx.Measurement{ + Read: &modelx.ReadEvent{ + ConnID: c.ID, + DurationSinceBeginning: stop.Sub(c.Beginning), + Error: err, + NumBytes: int64(n), + SyscallDuration: stop.Sub(start), + }, + }) + return +} + +// Write writes data to the connection +func (c *MeasuringConn) Write(b []byte) (n int, err error) { + start := time.Now() + n, err = c.Conn.Write(b) + err = errwrapper.SafeErrWrapperBuilder{ + ConnID: c.ID, + Error: err, + Operation: "write", + }.MaybeBuild() + stop := time.Now() + c.Handler.OnMeasurement(modelx.Measurement{ + Write: &modelx.WriteEvent{ + ConnID: c.ID, + DurationSinceBeginning: stop.Sub(c.Beginning), + Error: err, + NumBytes: int64(n), + SyscallDuration: stop.Sub(start), + }, + }) + return +} + +// Close closes the connection +func (c *MeasuringConn) Close() (err error) { + start := time.Now() + err = c.Conn.Close() + err = errwrapper.SafeErrWrapperBuilder{ + ConnID: c.ID, + Error: err, + Operation: "close", + }.MaybeBuild() + stop := time.Now() + c.Handler.OnMeasurement(modelx.Measurement{ + Close: &modelx.CloseEvent{ + ConnID: c.ID, + DurationSinceBeginning: stop.Sub(c.Beginning), + Error: err, + SyscallDuration: stop.Sub(start), + }, + }) + return +} diff --git a/netx/internal/dialer/connx/connx_test.go b/netx/internal/dialer/connx/connx_test.go new file mode 100644 index 00000000..c816c860 --- /dev/null +++ b/netx/internal/dialer/connx/connx_test.go @@ -0,0 +1,61 @@ +package connx + +import ( + "net" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" +) + +func TestIntegrationMeasuringConn(t *testing.T) { + conn := net.Conn(&MeasuringConn{ + Conn: fakeconn{}, + Handler: handlers.NoHandler, + }) + defer conn.Close() + data := make([]byte, 1<<17) + n, err := conn.Read(data) + if err != nil { + t.Fatal(err) + } + if n != len(data) { + t.Fatal("invalid number of bytes read") + } + n, err = conn.Write(data) + if err != nil { + t.Fatal(err) + } + if n != len(data) { + t.Fatal("invalid number of bytes written") + } +} + +type fakeconn struct{} + +func (fakeconn) Read(b []byte) (n int, err error) { + n = len(b) + return +} +func (fakeconn) Write(b []byte) (n int, err error) { + n = len(b) + return +} +func (fakeconn) Close() (err error) { + return +} +func (fakeconn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} +func (fakeconn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} +func (fakeconn) SetDeadline(t time.Time) (err error) { + return +} +func (fakeconn) SetReadDeadline(t time.Time) (err error) { + return +} +func (fakeconn) SetWriteDeadline(t time.Time) (err error) { + return +} diff --git a/netx/internal/dialer/dialer.go b/netx/internal/dialer/dialer.go new file mode 100644 index 00000000..317b577c --- /dev/null +++ b/netx/internal/dialer/dialer.go @@ -0,0 +1,21 @@ +// Package dialer contains the dialer's API. The dialer defined +// in here implements basic DNS, but that is overridable. +package dialer + +import ( + "crypto/tls" + + "github.com/ooni/probe-engine/netx/internal/dialer/dnsdialer" + "github.com/ooni/probe-engine/netx/internal/dialer/tlsdialer" + "github.com/ooni/probe-engine/netx/modelx" +) + +// New creates a new modelx.Dialer +func New(resolver modelx.DNSResolver, dialer modelx.Dialer) *dnsdialer.Dialer { + return dnsdialer.New(resolver, dialer) +} + +// NewTLS creates a new modelx.TLSDialer +func NewTLS(dialer modelx.Dialer, config *tls.Config) *tlsdialer.TLSDialer { + return tlsdialer.New(dialer, config) +} diff --git a/netx/internal/dialer/dialer_test.go b/netx/internal/dialer/dialer_test.go new file mode 100644 index 00000000..dac9cc84 --- /dev/null +++ b/netx/internal/dialer/dialer_test.go @@ -0,0 +1,33 @@ +package dialer + +import ( + "crypto/tls" + "net" + "testing" + + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegrationNew(t *testing.T) { + var dialer modelx.Dialer = New(new(net.Resolver), new(net.Dialer)) + conn, err := dialer.Dial("tcp", "www.kernel.org:80") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + conn.Close() +} + +func TestIntegrationNewTLS(t *testing.T) { + var dialer modelx.TLSDialer = NewTLS(new(net.Dialer), new(tls.Config)) + conn, err := dialer.DialTLS("tcp", "www.kernel.org:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn") + } + conn.Close() +} diff --git a/netx/internal/dialer/dialerbase/dialerbase.go b/netx/internal/dialer/dialerbase/dialerbase.go new file mode 100644 index 00000000..2f550d1e --- /dev/null +++ b/netx/internal/dialer/dialerbase/dialerbase.go @@ -0,0 +1,97 @@ +// Package dialerbase contains the base dialer functionality. We connect +// to a remote endpoint, but we don't support DNS. +package dialerbase + +import ( + "context" + "net" + "time" + + "github.com/ooni/probe-engine/netx/internal/connid" + "github.com/ooni/probe-engine/netx/internal/dialer/connx" + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/internal/transactionid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Dialer is a net.Dialer that is only able to connect to +// remote TCP/UDP endpoints. DNS is not supported. +type Dialer struct { + dialer modelx.Dialer + beginning time.Time + handler modelx.Handler + dialID int64 +} + +// New creates a new dialer +func New( + beginning time.Time, + handler modelx.Handler, + dialer modelx.Dialer, + dialID int64, +) *Dialer { + return &Dialer{ + dialer: dialer, + beginning: beginning, + handler: handler, + dialID: dialID, + } +} + +// Dial creates a TCP or UDP connection. See net.Dial docs. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// DialContext dials a new connection with context. +func (d *Dialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + // this is the same timeout used by Go's net/http.DefaultTransport + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + start := time.Now() + conn, err := d.dialer.DialContext(ctx, network, address) + stop := time.Now() + err = errwrapper.SafeErrWrapperBuilder{ + // ConnID does not make any sense if we've failed and the error + // does not make any sense (and is nil) if we succeded. + DialID: d.dialID, + Error: err, + Operation: "connect", + }.MaybeBuild() + connID := safeConnID(network, conn) + txID := transactionid.ContextTransactionID(ctx) + d.handler.OnMeasurement(modelx.Measurement{ + Connect: &modelx.ConnectEvent{ + ConnID: connID, + DialID: d.dialID, + DurationSinceBeginning: stop.Sub(d.beginning), + Error: err, + Network: network, + RemoteAddress: address, + SyscallDuration: stop.Sub(start), + TransactionID: txID, + }, + }) + if err != nil { + return nil, err + } + return &connx.MeasuringConn{ + Conn: conn, + Beginning: d.beginning, + Handler: d.handler, + ID: connID, + }, nil +} + +func safeLocalAddress(conn net.Conn) (s string) { + if conn != nil && conn.LocalAddr() != nil { + s = conn.LocalAddr().String() + } + return +} + +func safeConnID(network string, conn net.Conn) int64 { + return connid.Compute(network, safeLocalAddress(conn)) +} diff --git a/netx/internal/dialer/dialerbase/dialerbase_test.go b/netx/internal/dialer/dialerbase/dialerbase_test.go new file mode 100644 index 00000000..c8675e86 --- /dev/null +++ b/netx/internal/dialer/dialerbase/dialerbase_test.go @@ -0,0 +1,43 @@ +package dialerbase + +import ( + "context" + "net" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegrationSuccess(t *testing.T) { + dialer := newdialer() + conn, err := dialer.Dial("tcp", "8.8.8.8:53") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationErrorNoConnect(t *testing.T) { + dialer := newdialer() + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + conn, err := dialer.DialContext(ctx, "tcp", "8.8.8.8:53") + if err == nil { + t.Fatal("expected an error here") + } + if ctx.Err() == nil { + t.Fatal("expected context to be expired here") + } + if conn != nil { + t.Fatal("expected nil conn here") + } +} + +// see whether we implement the interface +func newdialer() modelx.Dialer { + return New( + time.Now(), handlers.NoHandler, new(net.Dialer), 17, + ) +} diff --git a/netx/internal/dialer/dnsdialer/dnsdialer.go b/netx/internal/dialer/dnsdialer/dnsdialer.go new file mode 100644 index 00000000..625e8846 --- /dev/null +++ b/netx/internal/dialer/dnsdialer/dnsdialer.go @@ -0,0 +1,103 @@ +// Package dnsdialer contains a dialer with DNS lookups. +package dnsdialer + +import ( + "context" + "errors" + "net" + "strings" + + "github.com/ooni/probe-engine/netx/internal/dialer/dialerbase" + "github.com/ooni/probe-engine/netx/internal/dialid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Dialer defines the dialer API. We implement the most basic form +// of DNS, but more advanced resolutions are possible. +type Dialer struct { + dialer modelx.Dialer + resolver modelx.DNSResolver +} + +// New creates a new Dialer. +func New(resolver modelx.DNSResolver, dialer modelx.Dialer) (d *Dialer) { + return &Dialer{ + dialer: dialer, + resolver: resolver, + } +} + +// Dial creates a TCP or UDP connection. See net.Dial docs. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// DialContext is like Dial but the context allows to interrupt a +// pending connection attempt at any time. +func (d *Dialer) DialContext( + ctx context.Context, network, address string, +) (conn net.Conn, err error) { + root := modelx.ContextMeasurementRootOrDefault(ctx) + onlyhost, onlyport, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ctx = dialid.WithDialID(ctx) // important to create before lookupHost + dialID := dialid.ContextDialID(ctx) + var addrs []string + addrs, err = d.lookupHost(ctx, onlyhost) + if err != nil { + return + } + var errorslist []error + for _, addr := range addrs { + dialer := dialerbase.New( + root.Beginning, root.Handler, d.dialer, dialID, + ) + target := net.JoinHostPort(addr, onlyport) + conn, err = dialer.DialContext(ctx, network, target) + if err == nil { + return + } + errorslist = append(errorslist, err) + } + err = reduceErrors(errorslist) + return +} + +func reduceErrors(errorslist []error) error { + if len(errorslist) == 0 { + return nil + } + // If we have a know error, let's consider this the real error + // since it's probably most relevant. Otherwise let's return the + // first considering that (1) local resolvers likely will give + // us IPv4 first and (2) also our resolver does that. So, in case + // the user has no IPv6 connectivity, an IPv6 error is going to + // appear later in the list of errors. + for _, err := range errorslist { + var wrapper *modelx.ErrWrapper + if errors.As(err, &wrapper) && !strings.HasPrefix( + err.Error(), "unknown_error", + ) { + return err + } + } + // TODO(bassosimone): handle this case in a better way + return errorslist[0] +} + +func (d *Dialer) lookupHost( + ctx context.Context, hostname string, +) ([]string, error) { + if net.ParseIP(hostname) != nil { + return []string{hostname}, nil + } + root := modelx.ContextMeasurementRootOrDefault(ctx) + lookupHost := root.LookupHost + if root.LookupHost == nil { + lookupHost = d.resolver.LookupHost + } + addrs, err := lookupHost(ctx, hostname) + return addrs, err +} diff --git a/netx/internal/dialer/dnsdialer/dnsdialer_test.go b/netx/internal/dialer/dnsdialer/dnsdialer_test.go new file mode 100644 index 00000000..d507932f --- /dev/null +++ b/netx/internal/dialer/dnsdialer/dnsdialer_test.go @@ -0,0 +1,135 @@ +package dnsdialer + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegrationDial(t *testing.T) { + dialer := newdialer() + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialAddress(t *testing.T) { + dialer := newdialer() + conn, err := dialer.Dial("tcp", "8.8.8.8:853") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationNoPort(t *testing.T) { + dialer := newdialer() + conn, err := dialer.Dial("tcp", "antani.ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestIntegrationLookupFailure(t *testing.T) { + dialer := newdialer() + conn, err := dialer.Dial("tcp", "antani.ooni.io:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestIntegrationDialTCPFailure(t *testing.T) { + dialer := newdialer() + // The port is unreachable and filtered. The timeout is here + // to make sure that we don't run for too much time. + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + conn, err := dialer.DialContext(ctx, "tcp", "ooni.io:12345") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func newdialer() modelx.Dialer { + return New(new(net.Resolver), new(net.Dialer)) +} + +func TestReduceErrors(t *testing.T) { + t.Run("no errors", func(t *testing.T) { + result := reduceErrors(nil) + if result != nil { + t.Fatal("wrong result") + } + }) + + t.Run("single error", func(t *testing.T) { + err := errors.New("mocked error") + result := reduceErrors([]error{err}) + if result != err { + t.Fatal("wrong result") + } + }) + + t.Run("multiple errors", func(t *testing.T) { + err1 := errors.New("mocked error #1") + err2 := errors.New("mocked error #2") + result := reduceErrors([]error{err1, err2}) + if result.Error() != "mocked error #1" { + t.Fatal("wrong result") + } + }) + + t.Run("multiple errors with meaningful ones", func(t *testing.T) { + err1 := errors.New("mocked error #1") + err2 := &modelx.ErrWrapper{ + Failure: "unknown_error: antani", + } + err3 := &modelx.ErrWrapper{ + Failure: "connection_refused", + } + err4 := errors.New("mocked error #3") + result := reduceErrors([]error{err1, err2, err3, err4}) + if result.Error() != "connection_refused" { + t.Fatal("wrong result") + } + }) +} + +func TestIntegrationDivertLookupHost(t *testing.T) { + dialer := newdialer() + failure := errors.New("mocked error") + root := &modelx.MeasurementRoot{ + Beginning: time.Now(), + Handler: handlers.NoHandler, + LookupHost: func(ctx context.Context, hostname string) ([]string, error) { + return nil, failure + }, + } + ctx := modelx.WithMeasurementRoot(context.Background(), root) + conn, err := dialer.DialContext(ctx, "tcp", "google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if !errors.Is(err, failure) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} diff --git a/netx/internal/dialer/tlsdialer/tlsdialer.go b/netx/internal/dialer/tlsdialer/tlsdialer.go new file mode 100644 index 00000000..f4c79731 --- /dev/null +++ b/netx/internal/dialer/tlsdialer/tlsdialer.go @@ -0,0 +1,103 @@ +// Package tlsdialer contains the TLS dialer +package tlsdialer + +import ( + "context" + "crypto/tls" + "net" + "time" + + "github.com/ooni/probe-engine/netx/internal/dialer/connx" + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/modelx" +) + +// TLSDialer is the TLS dialer +type TLSDialer struct { + ConnectTimeout time.Duration // default: 30 second + TLSHandshakeTimeout time.Duration // default: 10 second + config *tls.Config + dialer modelx.Dialer + setDeadline func(net.Conn, time.Time) error +} + +// New creates a new TLS dialer +func New(dialer modelx.Dialer, config *tls.Config) *TLSDialer { + return &TLSDialer{ + ConnectTimeout: 30 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + config: config, + dialer: dialer, + setDeadline: func(conn net.Conn, t time.Time) error { + return conn.SetDeadline(t) + }, + } +} + +// DialTLS dials a new TLS connection +func (d *TLSDialer) DialTLS(network, address string) (net.Conn, error) { + ctx := context.Background() + return d.DialTLSContext(ctx, network, address) +} + +// DialTLSContext is like DialTLS, but with context +func (d *TLSDialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(ctx, d.ConnectTimeout) + defer cancel() + conn, err := d.dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + config := d.config.Clone() // avoid polluting original config + if config.ServerName == "" { + config.ServerName = host + } + err = d.setDeadline(conn, time.Now().Add(d.TLSHandshakeTimeout)) + if err != nil { + conn.Close() + return nil, err + } + tlsconn := tls.Client(conn, config) + var connID int64 + if mconn, ok := conn.(*connx.MeasuringConn); ok { + connID = mconn.ID + } + root := modelx.ContextMeasurementRootOrDefault(ctx) + // Implementation note: when DialTLS is not set, the code in + // net/http will perform the handshake. Otherwise, if DialTLS + // is set, we will end up here. This code is still used when + // performing non-HTTP TLS-enabled dial operations. + root.Handler.OnMeasurement(modelx.Measurement{ + TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{ + ConnID: connID, + DurationSinceBeginning: time.Now().Sub(root.Beginning), + SNI: config.ServerName, + }, + }) + err = tlsconn.Handshake() + err = errwrapper.SafeErrWrapperBuilder{ + ConnID: connID, + Error: err, + Operation: "tls_handshake", + }.MaybeBuild() + root.Handler.OnMeasurement(modelx.Measurement{ + TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{ + ConnID: connID, + ConnectionState: modelx.NewTLSConnectionState(tlsconn.ConnectionState()), + Error: err, + DurationSinceBeginning: time.Now().Sub(root.Beginning), + }, + }) + conn.SetDeadline(time.Time{}) // clear deadline + if err != nil { + conn.Close() + return nil, err + } + return tlsconn, err +} diff --git a/netx/internal/dialer/tlsdialer/tlsdialer_test.go b/netx/internal/dialer/tlsdialer/tlsdialer_test.go new file mode 100644 index 00000000..11f44b86 --- /dev/null +++ b/netx/internal/dialer/tlsdialer/tlsdialer_test.go @@ -0,0 +1,93 @@ +package tlsdialer + +import ( + "crypto/tls" + "errors" + "net" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal/dialer/dialerbase" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegrationSuccess(t *testing.T) { + dialer := newdialer() + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() +} + +func TestIntegrationSuccessWithMeasuringConn(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).dialer = dialerbase.New( + time.Now(), handlers.NoHandler, new(net.Dialer), 17, + ) + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("connection is nil") + } + conn.Close() +} + +func TestIntegrationFailureSplitHostPort(t *testing.T) { + dialer := newdialer() + conn, err := dialer.DialTLS("tcp", "www.google.com") // missing port + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureConnectTimeout(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).ConnectTimeout = 10 * time.Microsecond + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureTLSHandshakeTimeout(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).TLSHandshakeTimeout = 10 * time.Microsecond + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func TestIntegrationFailureSetDeadline(t *testing.T) { + dialer := newdialer() + dialer.(*TLSDialer).setDeadline = func(conn net.Conn, t time.Time) error { + return errors.New("mocked error") + } + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("connection is not nil") + } +} + +func newdialer() modelx.TLSDialer { + return New(new(net.Dialer), new(tls.Config)) +} diff --git a/netx/internal/dialid/dialid.go b/netx/internal/dialid/dialid.go new file mode 100644 index 00000000..2f398785 --- /dev/null +++ b/netx/internal/dialid/dialid.go @@ -0,0 +1,23 @@ +package dialid + +import ( + "context" + "sync/atomic" +) + +type contextkey struct{} + +var id int64 + +// WithDialID returns a copy of ctx with DialID +func WithDialID(ctx context.Context) context.Context { + return context.WithValue( + ctx, contextkey{}, atomic.AddInt64(&id, 1), + ) +} + +// ContextDialID returns the DialID of the context, or zero +func ContextDialID(ctx context.Context) int64 { + id, _ := ctx.Value(contextkey{}).(int64) + return id +} diff --git a/netx/internal/dialid/dialid_test.go b/netx/internal/dialid/dialid_test.go new file mode 100644 index 00000000..c67ce139 --- /dev/null +++ b/netx/internal/dialid/dialid_test.go @@ -0,0 +1,24 @@ +package dialid + +import ( + "context" + "testing" +) + +func TestIntegration(t *testing.T) { + ctx := context.Background() + id := ContextDialID(ctx) + if id != 0 { + t.Fatal("unexpected ID for empty context") + } + ctx = WithDialID(ctx) + id = ContextDialID(ctx) + if id != 1 { + t.Fatal("expected ID equal to 1") + } + ctx = WithDialID(ctx) + id = ContextDialID(ctx) + if id != 2 { + t.Fatal("expected ID equal to 2") + } +} diff --git a/netx/internal/errwrapper/errwrapper.go b/netx/internal/errwrapper/errwrapper.go new file mode 100644 index 00000000..8cd7f9a5 --- /dev/null +++ b/netx/internal/errwrapper/errwrapper.go @@ -0,0 +1,127 @@ +// Package errwrapper contains our error wrapper +package errwrapper + +import ( + "crypto/x509" + "errors" + "fmt" + "strings" + + "github.com/ooni/probe-engine/netx/modelx" +) + +// SafeErrWrapperBuilder contains a builder for modelx.ErrWrapper that +// is safe, i.e., behaves correctly when the error is nil. +type SafeErrWrapperBuilder struct { + // ConnID is the connection ID, if any + ConnID int64 + + // DialID is the dial ID, if any + DialID int64 + + // Error is the error, if any + Error error + + // Operation is the operation that failed + Operation string + + // TransactionID is the transaction ID, if any + TransactionID int64 +} + +// MaybeBuild builds a new modelx.ErrWrapper, if b.Error is not nil, and returns +// a nil error value, instead, if b.Error is nil. +func (b SafeErrWrapperBuilder) MaybeBuild() (err error) { + if b.Error != nil { + err = &modelx.ErrWrapper{ + ConnID: b.ConnID, + DialID: b.DialID, + Failure: toFailureString(b.Error), + Operation: toOperationString(b.Error, b.Operation), + TransactionID: b.TransactionID, + WrappedErr: b.Error, + } + } + return +} + +func toFailureString(err error) string { + // The list returned here matches the values used by MK unless + // explicitly noted otherwise with a comment. + + var errwrapper *modelx.ErrWrapper + if errors.As(err, &errwrapper) { + return errwrapper.Error() // we've already wrapped it + } + + if errors.Is(err, modelx.ErrDNSBogon) { + return "dns_bogon_error" // not in MK + } + + var x509HostnameError x509.HostnameError + if errors.As(err, &x509HostnameError) { + // Test case: https://wrong.host.badssl.com/ + return "ssl_invalid_hostname" + } + var x509UnknownAuthorityError x509.UnknownAuthorityError + if errors.As(err, &x509UnknownAuthorityError) { + // Test case: https://self-signed.badssl.com/. This error has + // never been among the ones returned by MK. + return "ssl_unknown_authority" + } + var x509CertificateInvalidError x509.CertificateInvalidError + if errors.As(err, &x509CertificateInvalidError) { + // Test case: https://expired.badssl.com/ + return "ssl_invalid_certificate" + } + + s := err.Error() + if strings.HasSuffix(s, "EOF") { + return "eof_error" + } + if strings.HasSuffix(s, "connection refused") { + return "connection_refused" + } + if strings.HasSuffix(s, "connection reset by peer") { + return "connection_reset" + } + if strings.HasSuffix(s, "context deadline exceeded") { + return "generic_timeout_error" + } + if strings.HasSuffix(s, "i/o timeout") { + return "generic_timeout_error" + } + if strings.HasSuffix(s, "TLS handshake timeout") { + return "generic_timeout_error" + } + if strings.HasSuffix(s, "no such host") { + // This is dns_lookup_error in MK but such error is used as a + // generic "hey, the lookup failed" error. Instead, this error + // that we return here is significantly more specific. + return "dns_nxdomain_error" + } + + return fmt.Sprintf("unknown_failure: %s", s) +} + +func toOperationString(err error, operation string) string { + var errwrapper *modelx.ErrWrapper + if errors.As(err, &errwrapper) { + // Basically, as explained in modelx.ErrWrapper docs, let's + // keep the child major operation, if any. + if errwrapper.Operation == "connect" { + return errwrapper.Operation + } + if errwrapper.Operation == "http_round_trip" { + return errwrapper.Operation + } + if errwrapper.Operation == "resolve" { + return errwrapper.Operation + } + if errwrapper.Operation == "tls_handshake" { + return errwrapper.Operation + } + // FALLTHROUGH + } + return operation +} diff --git a/netx/internal/errwrapper/errwrapper_test.go b/netx/internal/errwrapper/errwrapper_test.go new file mode 100644 index 00000000..2a17e73e --- /dev/null +++ b/netx/internal/errwrapper/errwrapper_test.go @@ -0,0 +1,167 @@ +package errwrapper + +import ( + "context" + "crypto/x509" + "errors" + "io" + "net" + "syscall" + "testing" + + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestMaybeBuildFactory(t *testing.T) { + err := SafeErrWrapperBuilder{ + ConnID: 1, + DialID: 10, + Error: errors.New("mocked error"), + TransactionID: 100, + }.MaybeBuild() + var target *modelx.ErrWrapper + if errors.As(err, &target) == false { + t.Fatal("not the expected error type") + } + if target.ConnID != 1 { + t.Fatal("wrong ConnID") + } + if target.DialID != 10 { + t.Fatal("wrong DialID") + } + if target.Failure != "unknown_failure: mocked error" { + t.Fatal("the failure string is wrong") + } + if target.TransactionID != 100 { + t.Fatal("the transactionID is wrong") + } + if target.WrappedErr.Error() != "mocked error" { + t.Fatal("the wrapped error is wrong") + } +} + +func TestToFailureString(t *testing.T) { + t.Run("for already wrapped error", func(t *testing.T) { + err := SafeErrWrapperBuilder{Error: io.EOF}.MaybeBuild() + if toFailureString(err) != "eof_error" { + t.Fatal("unexpected result") + } + }) + t.Run("for modelx.ErrDNSBogon", func(t *testing.T) { + if toFailureString(modelx.ErrDNSBogon) != "dns_bogon_error" { + t.Fatal("unexpected result") + } + }) + t.Run("for x509.HostnameError", func(t *testing.T) { + var err x509.HostnameError + if toFailureString(err) != "ssl_invalid_hostname" { + t.Fatal("unexpected result") + } + }) + t.Run("for x509.UnknownAuthorityError", func(t *testing.T) { + var err x509.UnknownAuthorityError + if toFailureString(err) != "ssl_unknown_authority" { + t.Fatal("unexpected result") + } + }) + t.Run("for x509.CertificateInvalidError", func(t *testing.T) { + var err x509.CertificateInvalidError + if toFailureString(err) != "ssl_invalid_certificate" { + t.Fatal("unexpected result") + } + }) + t.Run("for EOF", func(t *testing.T) { + if toFailureString(io.EOF) != "eof_error" { + t.Fatal("unexpected results") + } + }) + t.Run("for connection_refused", func(t *testing.T) { + if toFailureString(syscall.ECONNREFUSED) != "connection_refused" { + t.Fatal("unexpected results") + } + }) + t.Run("for connection_reset", func(t *testing.T) { + if toFailureString(syscall.ECONNRESET) != "connection_reset" { + t.Fatal("unexpected results") + } + }) + t.Run("for context deadline expired", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + <-ctx.Done() + if toFailureString(ctx.Err()) != "generic_timeout_error" { + t.Fatal("unexpected results") + } + }) + t.Run("for i/o error", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1) + defer cancel() + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", "www.google.com:80") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected nil connection here") + } + if toFailureString(err) != "generic_timeout_error" { + t.Fatal("unexpected results") + } + }) + t.Run("for TLS handshake timeout error", func(t *testing.T) { + err := errors.New("net/http: TLS handshake timeout") + if toFailureString(err) != "generic_timeout_error" { + t.Fatal("unexpected results") + } + }) + t.Run("for no such host", func(t *testing.T) { + if toFailureString(&net.DNSError{ + Err: "no such host", + }) != "dns_nxdomain_error" { + t.Fatal("unexpected results") + } + }) +} + +func TestUnitToOperationString(t *testing.T) { + t.Run("for connect", func(t *testing.T) { + // You're doing HTTP and connect fails. You want to know + // that connect failed not that HTTP failed. + err := &modelx.ErrWrapper{Operation: "connect"} + if toOperationString(err, "http_round_trip") != "connect" { + t.Fatal("unexpected result") + } + }) + t.Run("for http_round_trip", func(t *testing.T) { + // You're doing DoH and something fails inside HTTP. You want + // to know about the internal HTTP error, not resolve. + err := &modelx.ErrWrapper{Operation: "http_round_trip"} + if toOperationString(err, "resolve") != "http_round_trip" { + t.Fatal("unexpected result") + } + }) + t.Run("for resolve", func(t *testing.T) { + // You're doing HTTP and the DNS fails. You want to + // know that resolve failed. + err := &modelx.ErrWrapper{Operation: "resolve"} + if toOperationString(err, "http_round_trip") != "resolve" { + t.Fatal("unexpected result") + } + }) + t.Run("for tls_handshake", func(t *testing.T) { + // You're doing HTTP and the TLS handshake fails. You want + // to know about a TLS handshake error. + err := &modelx.ErrWrapper{Operation: "tls_handshake"} + if toOperationString(err, "http_round_trip") != "tls_handshake" { + t.Fatal("unexpected result") + } + }) + t.Run("for minor operation", func(t *testing.T) { + // You just noticed that TLS handshake failed and you + // have a child error telling you that read failed. Here + // you want to know about a TLS handshake error. + err := &modelx.ErrWrapper{Operation: "read"} + if toOperationString(err, "tls_handshake") != "tls_handshake" { + t.Fatal("unexpected result") + } + }) +} diff --git a/netx/internal/httptransport/bodytracer/bodytracer.go b/netx/internal/httptransport/bodytracer/bodytracer.go new file mode 100644 index 00000000..0d35bc17 --- /dev/null +++ b/netx/internal/httptransport/bodytracer/bodytracer.go @@ -0,0 +1,84 @@ +// Package bodytracer contains the HTTP body tracer. The purpose +// of tracing is to emit events while we read response bodies. +package bodytracer + +import ( + "io" + "net/http" + "time" + + "github.com/ooni/probe-engine/netx/internal/transactionid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Transport performs single HTTP transactions and emits +// measurement events as they happen. +type Transport struct { + roundTripper http.RoundTripper +} + +// New creates a new Transport. +func New(roundTripper http.RoundTripper) *Transport { + return &Transport{roundTripper: roundTripper} +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + resp, err = t.roundTripper.RoundTrip(req) + if err != nil { + return + } + // "The http Client and Transport guarantee that Body is always + // non-nil, even on responses without a body or responses with + // a zero-length body." (from the docs) + resp.Body = &bodyWrapper{ + ReadCloser: resp.Body, + root: modelx.ContextMeasurementRootOrDefault(req.Context()), + tid: transactionid.ContextTransactionID(req.Context()), + } + return +} + +// CloseIdleConnections closes the idle connections. +func (t *Transport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} + +type bodyWrapper struct { + io.ReadCloser + root *modelx.MeasurementRoot + tid int64 +} + +func (bw *bodyWrapper) Read(b []byte) (n int, err error) { + n, err = bw.ReadCloser.Read(b) + bw.root.Handler.OnMeasurement(modelx.Measurement{ + HTTPResponseBodyPart: &modelx.HTTPResponseBodyPartEvent{ + // "Read reads up to len(p) bytes into p. It returns the number of + // bytes read (0 <= n <= len(p)) and any error encountered." + Data: b[:n], + Error: err, + DurationSinceBeginning: time.Now().Sub(bw.root.Beginning), + TransactionID: bw.tid, + }, + }) + return +} + +func (bw *bodyWrapper) Close() (err error) { + err = bw.ReadCloser.Close() + bw.root.Handler.OnMeasurement(modelx.Measurement{ + HTTPResponseDone: &modelx.HTTPResponseDoneEvent{ + DurationSinceBeginning: time.Now().Sub(bw.root.Beginning), + TransactionID: bw.tid, + }, + }) + return +} diff --git a/netx/internal/httptransport/bodytracer/bodytracer_test.go b/netx/internal/httptransport/bodytracer/bodytracer_test.go new file mode 100644 index 00000000..33b09876 --- /dev/null +++ b/netx/internal/httptransport/bodytracer/bodytracer_test.go @@ -0,0 +1,39 @@ +package bodytracer + +import ( + "io/ioutil" + "net/http" + "testing" +) + +func TestIntegration(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +} diff --git a/netx/internal/httptransport/httptransport.go b/netx/internal/httptransport/httptransport.go new file mode 100644 index 00000000..2a0eb639 --- /dev/null +++ b/netx/internal/httptransport/httptransport.go @@ -0,0 +1,47 @@ +// Package httptransport contains HTTP transport extensions. Here we +// define a http.Transport that emits events. +package httptransport + +import ( + "net/http" + + "github.com/ooni/probe-engine/netx/internal/httptransport/bodytracer" + "github.com/ooni/probe-engine/netx/internal/httptransport/tracetripper" + "github.com/ooni/probe-engine/netx/internal/httptransport/transactioner" +) + +// Transport performs single HTTP transactions and emits +// measurement events as they happen. +type Transport struct { + roundTripper http.RoundTripper +} + +// New creates a new Transport. +func New(roundTripper http.RoundTripper) *Transport { + return &Transport{ + roundTripper: transactioner.New(bodytracer.New( + tracetripper.New(roundTripper))), + } +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + // Make sure we're not sending Go's default User-Agent + // if the user has configured no user agent + if req.Header.Get("User-Agent") == "" { + req.Header["User-Agent"] = nil + } + return t.roundTripper.RoundTrip(req) +} + +// CloseIdleConnections closes the idle connections. +func (t *Transport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} diff --git a/netx/internal/httptransport/httptransport_test.go b/netx/internal/httptransport/httptransport_test.go new file mode 100644 index 00000000..8ea3e95f --- /dev/null +++ b/netx/internal/httptransport/httptransport_test.go @@ -0,0 +1,39 @@ +package httptransport + +import ( + "io/ioutil" + "net/http" + "testing" +) + +func TestIntegration(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +} diff --git a/netx/internal/httptransport/tracetripper/tracetripper.go b/netx/internal/httptransport/tracetripper/tracetripper.go new file mode 100644 index 00000000..3a0c169c --- /dev/null +++ b/netx/internal/httptransport/tracetripper/tracetripper.go @@ -0,0 +1,274 @@ +// Package tracetripper contains the tracing round tripper +package tracetripper + +import ( + "bytes" + "crypto/tls" + "io" + "io/ioutil" + "net/http" + "net/http/httptrace" + "sync" + "sync/atomic" + "time" + + "github.com/ooni/probe-engine/netx/internal/connid" + "github.com/ooni/probe-engine/netx/internal/dialid" + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/internal/transactionid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Transport performs single HTTP transactions. +type Transport struct { + readAllErrs int64 + readAll func(r io.Reader) ([]byte, error) + roundTripper http.RoundTripper +} + +// New creates a new Transport. +func New(roundTripper http.RoundTripper) *Transport { + return &Transport{ + readAll: ioutil.ReadAll, + roundTripper: roundTripper, + } +} + +type readCloseWrapper struct { + closer io.Closer + reader io.Reader +} + +func newReadCloseWrapper( + reader io.Reader, closer io.ReadCloser, +) *readCloseWrapper { + return &readCloseWrapper{ + closer: closer, + reader: reader, + } +} + +func (c *readCloseWrapper) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func (c *readCloseWrapper) Close() error { + return c.closer.Close() +} + +func readSnap( + source *io.ReadCloser, limit int64, + readAll func(r io.Reader) ([]byte, error), +) (data []byte, err error) { + data, err = readAll(io.LimitReader(*source, limit)) + if err == nil { + *source = newReadCloseWrapper( + io.MultiReader(bytes.NewReader(data), *source), + *source, + ) + } + return +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + root := modelx.ContextMeasurementRootOrDefault(req.Context()) + + tid := transactionid.ContextTransactionID(req.Context()) + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPRoundTripStart: &modelx.HTTPRoundTripStartEvent{ + DialID: dialid.ContextDialID(req.Context()), + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Method: req.Method, + TransactionID: tid, + URL: req.URL.String(), + }, + }) + + var ( + err error + majorOp = "http_round_trip" + majorOpMu sync.Mutex + requestBody []byte + requestHeaders = http.Header{} + requestHeadersMu sync.Mutex + snapSize = modelx.ComputeBodySnapSize(root.MaxBodySnapSize) + ) + + // Save a snapshot of the request body + if req.Body != nil { + requestBody, err = readSnap(&req.Body, snapSize, t.readAll) + if err != nil { + return nil, err + } + } + + // Prepare a tracer for delivering events + tracer := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { + majorOpMu.Lock() + majorOp = "tls_handshake" + majorOpMu.Unlock() + // Event emitted by net/http when DialTLS is not + // configured in the http.Transport + root.Handler.OnMeasurement(modelx.Measurement{ + TLSHandshakeStart: &modelx.TLSHandshakeStartEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + TransactionID: tid, + }, + }) + }, + TLSHandshakeDone: func(state tls.ConnectionState, err error) { + // Wrapping the error even if we're not returning it because it may + // less confusing to users to see the wrapped name + err = errwrapper.SafeErrWrapperBuilder{ + Error: err, + Operation: "tls_handshake", + TransactionID: tid, + }.MaybeBuild() + durationSinceBeginning := time.Now().Sub(root.Beginning) + // Event emitted by net/http when DialTLS is not + // configured in the http.Transport + root.Handler.OnMeasurement(modelx.Measurement{ + TLSHandshakeDone: &modelx.TLSHandshakeDoneEvent{ + ConnectionState: modelx.NewTLSConnectionState(state), + Error: err, + DurationSinceBeginning: durationSinceBeginning, + TransactionID: tid, + }, + }) + }, + GotConn: func(info httptrace.GotConnInfo) { + majorOpMu.Lock() + majorOp = "http_round_trip" + majorOpMu.Unlock() + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPConnectionReady: &modelx.HTTPConnectionReadyEvent{ + ConnID: connid.Compute( + info.Conn.LocalAddr().Network(), + info.Conn.LocalAddr().String(), + ), + DurationSinceBeginning: time.Now().Sub(root.Beginning), + TransactionID: tid, + }, + }) + }, + WroteHeaderField: func(key string, values []string) { + requestHeadersMu.Lock() + // Important: do not set directly into the headers map using + // the [] operator because net/http expects to be able to + // perform normalization of header names! + for _, value := range values { + requestHeaders.Add(key, value) + } + requestHeadersMu.Unlock() + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPRequestHeader: &modelx.HTTPRequestHeaderEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Key: key, + TransactionID: tid, + Value: values, + }, + }) + }, + WroteHeaders: func() { + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPRequestHeadersDone: &modelx.HTTPRequestHeadersDoneEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Headers: requestHeaders, // [*] + Method: req.Method, // [*] + TransactionID: tid, + URL: req.URL, // [*] + }, + }) + }, + WroteRequest: func(info httptrace.WroteRequestInfo) { + // Wrapping the error even if we're not returning it because it may + // less confusing to users to see the wrapped name + err := errwrapper.SafeErrWrapperBuilder{ + Error: info.Err, + Operation: "http_round_trip", + TransactionID: tid, + }.MaybeBuild() + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPRequestDone: &modelx.HTTPRequestDoneEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Error: err, + TransactionID: tid, + }, + }) + }, + GotFirstResponseByte: func() { + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPResponseStart: &modelx.HTTPResponseStartEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + TransactionID: tid, + }, + }) + }, + } + + // If we don't have already a tracer this is a toplevel request, so just + // set the tracer. Otherwise, we're doing DoH. We cannot set anothert trace + // because they'd be merged. Instead, replace the existing trace content + // with the new trace and then remember to reset it. + origtracer := httptrace.ContextClientTrace(req.Context()) + if origtracer != nil { + bkp := *origtracer + *origtracer = *tracer + defer func() { + *origtracer = bkp + }() + } else { + req = req.WithContext(httptrace.WithClientTrace(req.Context(), tracer)) + } + + resp, err := t.roundTripper.RoundTrip(req) + err = errwrapper.SafeErrWrapperBuilder{ + Error: err, + Operation: majorOp, + TransactionID: tid, + }.MaybeBuild() + // [*] Require less event joining work by providing info that + // makes this event alone actionable for OONI + event := &modelx.HTTPRoundTripDoneEvent{ + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Error: err, + RequestBodySnap: requestBody, + RequestHeaders: requestHeaders, // [*] + RequestMethod: req.Method, // [*] + RequestURL: req.URL.String(), // [*] + MaxBodySnapSize: snapSize, + TransactionID: tid, + } + if resp != nil { + event.ResponseHeaders = resp.Header + event.ResponseStatusCode = int64(resp.StatusCode) + event.ResponseProto = resp.Proto + // Save a snapshot of the response body + var data []byte + data, err = readSnap(&resp.Body, snapSize, t.readAll) + if err != nil { + atomic.AddInt64(&t.readAllErrs, 1) + resp = nil // this is how net/http likes it + } else { + event.ResponseBodySnap = data + } + } + root.Handler.OnMeasurement(modelx.Measurement{ + HTTPRoundTripDone: event, + }) + return resp, err +} + +// CloseIdleConnections closes the idle connections. +func (t *Transport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} diff --git a/netx/internal/httptransport/tracetripper/tracetripper_test.go b/netx/internal/httptransport/tracetripper/tracetripper_test.go new file mode 100644 index 00000000..c410b197 --- /dev/null +++ b/netx/internal/httptransport/tracetripper/tracetripper_test.go @@ -0,0 +1,272 @@ +package tracetripper + +import ( + "bytes" + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "net/http/httptrace" + "sync" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestIntegration(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +type roundTripHandler struct { + roundTrips []*modelx.HTTPRoundTripDoneEvent + mu sync.Mutex +} + +func (h *roundTripHandler) OnMeasurement(m modelx.Measurement) { + if m.HTTPRoundTripDone != nil { + h.mu.Lock() + defer h.mu.Unlock() + h.roundTrips = append(h.roundTrips, m.HTTPRoundTripDone) + } +} + +func TestIntegrationReadAllFailure(t *testing.T) { + transport := New(http.DefaultTransport) + transport.readAll = func(r io.Reader) ([]byte, error) { + return nil, io.EOF + } + client := &http.Client{Transport: transport} + resp, err := client.Get("https://google.com") + if err == nil { + t.Fatal("expected an error here") + } + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + if transport.readAllErrs <= 0 { + t.Fatal("not the error we expected") + } + client.CloseIdleConnections() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +} + +func TestIntegrationWithClientTrace(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + req, err := http.NewRequest("GET", "https://www.kernel.org/", nil) + if err != nil { + t.Fatal(err) + } + req = req.WithContext( + httptrace.WithClientTrace(req.Context(), new(httptrace.ClientTrace)), + ) + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a good response here") + } + resp.Body.Close() + client.CloseIdleConnections() +} + +func TestIntegrationWithCorrectSnaps(t *testing.T) { + // Prepare a DNS query for dns.google.com A, for which we + // know the answer in terms of well know IP addresses + query := new(dns.Msg) + query.Id = dns.Id() + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = dns.Question{ + Name: dns.Fqdn("dns.google.com"), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + queryData, err := query.Pack() + if err != nil { + t.Fatal(err) + } + + // Prepare a new transport with limited snapshot size and + // use such transport to configure an ordinary client + transport := New(http.DefaultTransport) + const snapSize = 15 + client := &http.Client{Transport: transport} + + // Prepare a new request for Cloudflare DNS, register + // a handler, issue the request, fetch the response. + req, err := http.NewRequest( + "POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData), + ) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/dns-message") + handler := &roundTripHandler{} + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + Handler: handler, + MaxBodySnapSize: snapSize, + }, + ) + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Fatal("HTTP request failed") + } + + // Read the whole response body, parse it as valid DNS + // reply and verify we obtained what we expected + replyData, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + reply := new(dns.Msg) + err = reply.Unpack(replyData) + if err != nil { + t.Fatal(err) + } + if reply.Rcode != 0 { + t.Fatal("unexpected Rcode") + } + if len(reply.Answer) < 1 { + t.Fatal("no answers?!") + } + found8888, found8844, foundother := false, false, false + for _, answer := range reply.Answer { + if rra, ok := answer.(*dns.A); ok { + ip := rra.A.String() + if ip == "8.8.8.8" { + found8888 = true + } else if ip == "8.8.4.4" { + found8844 = true + } else { + foundother = true + } + } + } + if !found8888 || !found8844 || foundother { + t.Fatal("unexpected reply") + } + + // Finally, make sure we have captured the correct + // snapshots for the request and response bodies + if len(handler.roundTrips) != 1 { + t.Fatal("more round trips than expected") + } + roundTrip := handler.roundTrips[0] + if len(roundTrip.RequestBodySnap) != snapSize { + t.Fatal("unexpected request body snap length") + } + if len(roundTrip.ResponseBodySnap) != snapSize { + t.Fatal("unexpected response body snap length") + } + if !bytes.Equal(roundTrip.RequestBodySnap, queryData[:snapSize]) { + t.Fatal("the request body snap is wrong") + } + if !bytes.Equal(roundTrip.ResponseBodySnap, replyData[:snapSize]) { + t.Fatal("the response body snap is wrong") + } +} + +func TestIntegrationWithReadAllFailingForBody(t *testing.T) { + // Prepare a DNS query for dns.google.com A, for which we + // know the answer in terms of well know IP addresses + query := new(dns.Msg) + query.Id = dns.Id() + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = dns.Question{ + Name: dns.Fqdn("dns.google.com"), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + queryData, err := query.Pack() + if err != nil { + t.Fatal(err) + } + + // Prepare a new transport with limited snapshot size and + // use such transport to configure an ordinary client + transport := New(http.DefaultTransport) + errorMocked := errors.New("mocked error") + transport.readAll = func(r io.Reader) ([]byte, error) { + return nil, errorMocked + } + const snapSize = 15 + client := &http.Client{Transport: transport} + + // Prepare a new request for Cloudflare DNS, register + // a handler, issue the request, fetch the response. + req, err := http.NewRequest( + "POST", "https://cloudflare-dns.com/dns-query", bytes.NewReader(queryData), + ) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/dns-message") + handler := &roundTripHandler{} + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + Handler: handler, + MaxBodySnapSize: snapSize, + }, + ) + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err == nil { + t.Fatal("expected an error here") + } + if !errors.Is(err, errorMocked) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + + // Finally, make sure we got something that makes sense + if len(handler.roundTrips) != 0 { + t.Fatal("more round trips than expected") + } +} diff --git a/netx/internal/httptransport/transactioner/transactioner.go b/netx/internal/httptransport/transactioner/transactioner.go new file mode 100644 index 00000000..06f5532c --- /dev/null +++ b/netx/internal/httptransport/transactioner/transactioner.go @@ -0,0 +1,39 @@ +// Package transactioner contains the transaction assigning round tripper +package transactioner + +import ( + "net/http" + + "github.com/ooni/probe-engine/netx/internal/transactionid" +) + +// Transport performs single HTTP transactions. +type Transport struct { + roundTripper http.RoundTripper +} + +// New creates a new Transport. +func New(roundTripper http.RoundTripper) *Transport { + return &Transport{ + roundTripper: roundTripper, + } +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.roundTripper.RoundTrip(req.WithContext( + transactionid.WithTransactionID(req.Context()), + )) +} + +// CloseIdleConnections closes the idle connections. +func (t *Transport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} diff --git a/netx/internal/httptransport/transactioner/transactioner_test.go b/netx/internal/httptransport/transactioner/transactioner_test.go new file mode 100644 index 00000000..14a5517b --- /dev/null +++ b/netx/internal/httptransport/transactioner/transactioner_test.go @@ -0,0 +1,57 @@ +package transactioner + +import ( + "io/ioutil" + "net/http" + "testing" + + "github.com/ooni/probe-engine/netx/internal/transactionid" +) + +type transport struct { + roundTripper http.RoundTripper + t *testing.T +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + if id := transactionid.ContextTransactionID(ctx); id == 0 { + t.t.Fatal("transaction ID not set") + } + return t.roundTripper.RoundTrip(req) +} + +func TestIntegration(t *testing.T) { + client := &http.Client{ + Transport: New(&transport{ + roundTripper: http.DefaultTransport, + t: t, + }), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: New(http.DefaultTransport), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +} diff --git a/netx/internal/internal.go b/netx/internal/internal.go new file mode 100644 index 00000000..a26e218f --- /dev/null +++ b/netx/internal/internal.go @@ -0,0 +1,355 @@ +// Package internal contains internal code. +package internal + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal/dialer" + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/internal/httptransport" + "github.com/ooni/probe-engine/netx/internal/resolver" + "github.com/ooni/probe-engine/netx/internal/resolver/chainresolver" + "github.com/ooni/probe-engine/netx/modelx" + "golang.org/x/net/http2" +) + +// Dialer defines the dialer API. We implement the most basic form +// of DNS, but more advanced resolutions are possible. +type Dialer struct { + Beginning time.Time + Handler modelx.Handler + Resolver modelx.DNSResolver + TLSConfig *tls.Config +} + +// NewDialer creates a new Dialer. +func NewDialer( + beginning time.Time, handler modelx.Handler, +) (d *Dialer) { + return &Dialer{ + Beginning: beginning, + Handler: handler, + Resolver: resolver.NewResolverSystem(), + TLSConfig: new(tls.Config), + } +} + +// Dial creates a TCP or UDP connection. See net.Dial docs. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func maybeWithMeasurementRoot( + ctx context.Context, beginning time.Time, handler modelx.Handler, +) context.Context { + if modelx.ContextMeasurementRoot(ctx) != nil { + return ctx + } + return modelx.WithMeasurementRoot(ctx, &modelx.MeasurementRoot{ + Beginning: beginning, + Handler: handler, + }) +} + +// DialContext is like Dial but the context allows to interrupt a +// pending connection attempt at any time. +func (d *Dialer) DialContext( + ctx context.Context, network, address string, +) (conn net.Conn, err error) { + ctx = maybeWithMeasurementRoot(ctx, d.Beginning, d.Handler) + return dialer.New( + d.Resolver, new(net.Dialer), + ).DialContext(ctx, network, address) +} + +// DialTLS is like Dial, but creates TLS connections. +func (d *Dialer) DialTLS(network, address string) (net.Conn, error) { + ctx := context.Background() + return d.DialTLSContext(ctx, network, address) +} + +// DialTLSContext is like DialTLS, but with context +func (d *Dialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + ctx = maybeWithMeasurementRoot(ctx, d.Beginning, d.Handler) + return dialer.NewTLS( + dialer.New(d.Resolver, new(net.Dialer)), + d.TLSConfig, + ).DialTLSContext(ctx, network, address) +} + +// SetCABundle configures the dialer to use a specific CA bundle. +func (d *Dialer) SetCABundle(path string) error { + cert, err := ioutil.ReadFile(path) + if err != nil { + return err + } + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(cert) + d.TLSConfig.RootCAs = pool + return nil +} + +// ForceSpecificSNI forces using a specific SNI. +func (d *Dialer) ForceSpecificSNI(sni string) error { + d.TLSConfig.ServerName = sni + return nil +} + +// ForceSkipVerify forces to skip certificate verification +func (d *Dialer) ForceSkipVerify() error { + d.TLSConfig.InsecureSkipVerify = true + return nil +} + +// ConfigureDNS implements netx.Dialer.ConfigureDNS. +func (d *Dialer) ConfigureDNS(network, address string) error { + r, err := NewResolver(d.Beginning, d.Handler, network, address) + if err == nil { + d.Resolver = r + } + return err +} + +// SetResolver implements netx.Dialer.SetResolver. +func (d *Dialer) SetResolver(r modelx.DNSResolver) { + d.Resolver = r +} + +var ( + dohClientHandle *http.Client + dohClientOnce sync.Once +) + +func newHTTPClientForDoH(beginning time.Time, handler modelx.Handler) *http.Client { + if handler == handlers.NoHandler { + // A bit of extra complexity for a good reason: if the user is not + // interested into setting a default handler, then it is fine to + // always return the same *http.Client for DoH. This means that we + // don't need to care about closing the connections used by this + // *http.Client, therefore we don't leak resources because we fail + // to close the idle connections. + dohClientOnce.Do(func() { + transport := NewHTTPTransport( + time.Time{}, + handlers.NoHandler, + NewDialer(time.Time{}, handlers.NoHandler), + false, // DisableKeepAlives + http.ProxyFromEnvironment, + ) + dohClientHandle = &http.Client{Transport: transport} + }) + return dohClientHandle + } + // Otherwise, if the user wants to have a default handler, we + // return a transport that does not leak connections. + transport := NewHTTPTransport( + beginning, + handler, + NewDialer(beginning, handler), + true, // DisableKeepAlives + http.ProxyFromEnvironment, + ) + return &http.Client{Transport: transport} +} + +func withPort(address, port string) string { + // Handle the case where port was not specified. We have written in + // a bunch of places that we can just pass a domain in this case and + // so we need to gracefully ensure this is still possible. + _, _, err := net.SplitHostPort(address) + if err != nil && strings.Contains(err.Error(), "missing port in address") { + address = net.JoinHostPort(address, port) + } + return address +} + +type resolverWrapper struct { + beginning time.Time + handler modelx.Handler + resolver modelx.DNSResolver +} + +func newResolverWrapper( + beginning time.Time, handler modelx.Handler, + resolver modelx.DNSResolver, +) *resolverWrapper { + return &resolverWrapper{ + beginning: beginning, + handler: handler, + resolver: resolver, + } +} + +// LookupAddr returns the name of the provided IP address +func (r *resolverWrapper) LookupAddr(ctx context.Context, addr string) ([]string, error) { + ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler) + return r.resolver.LookupAddr(ctx, addr) +} + +// LookupCNAME returns the canonical name of a host +func (r *resolverWrapper) LookupCNAME(ctx context.Context, host string) (string, error) { + ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler) + return r.resolver.LookupCNAME(ctx, host) +} + +// LookupHost returns the IP addresses of a host +func (r *resolverWrapper) LookupHost(ctx context.Context, hostname string) ([]string, error) { + ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler) + return r.resolver.LookupHost(ctx, hostname) +} + +// LookupMX returns the MX records of a specific name +func (r *resolverWrapper) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler) + return r.resolver.LookupMX(ctx, name) +} + +// LookupNS returns the NS records of a specific name +func (r *resolverWrapper) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + ctx = maybeWithMeasurementRoot(ctx, r.beginning, r.handler) + return r.resolver.LookupNS(ctx, name) +} + +// NewResolver returns a new resolver +func NewResolver( + beginning time.Time, handler modelx.Handler, network, address string, +) (modelx.DNSResolver, error) { + // Implementation note: system need to be dealt with + // separately because it doesn't have any transport. + if network == "system" || network == "" { + return newResolverWrapper( + beginning, handler, resolver.NewResolverSystem()), nil + } + if network == "doh" { + return newResolverWrapper(beginning, handler, resolver.NewResolverHTTPS( + newHTTPClientForDoH(beginning, handler), address, + )), nil + } + if network == "dot" { + // We need a child dialer here to avoid an endless loop where the + // dialer will ask us to resolve, we'll tell the dialer to dial, it + // will ask us to resolve, ... + return newResolverWrapper(beginning, handler, resolver.NewResolverTLS( + NewDialer(beginning, handler), withPort(address, "853"), + )), nil + } + if network == "tcp" { + // Same rationale as above: avoid possible endless loop + return newResolverWrapper(beginning, handler, resolver.NewResolverTCP( + NewDialer(beginning, handler), withPort(address, "53"), + )), nil + } + if network == "udp" { + // Same rationale as above: avoid possible endless loop + return newResolverWrapper(beginning, handler, resolver.NewResolverUDP( + NewDialer(beginning, handler), withPort(address, "53"), + )), nil + } + return nil, errors.New("resolver.New: unsupported network value") +} + +// HTTPTransport performs single HTTP transactions and emits +// measurement events as they happen. +type HTTPTransport struct { + Transport *http.Transport + Handler modelx.Handler + Beginning time.Time + roundTripper http.RoundTripper +} + +// NewHTTPTransport creates a new Transport. +func NewHTTPTransport( + beginning time.Time, + handler modelx.Handler, + dialer *Dialer, + disableKeepAlives bool, + proxyFunc func(*http.Request) (*url.URL, error), +) *HTTPTransport { + baseTransport := &http.Transport{ + // The following values are copied from Go 1.12 docs and match + // what should be used by the default transport + ExpectContinueTimeout: 1 * time.Second, + IdleConnTimeout: 90 * time.Second, + MaxIdleConns: 100, + Proxy: proxyFunc, + TLSHandshakeTimeout: 10 * time.Second, + DisableKeepAlives: disableKeepAlives, + } + ooniTransport := httptransport.New(baseTransport) + // Configure h2 and make sure that the custom TLSConfig we use for dialing + // is actually compatible with upgrading to h2. (This mainly means we + // need to make sure we include "h2" in the NextProtos array.) Because + // http2.ConfigureTransport only returns error when we have already + // configured http2, it is safe to ignore the return value. + http2.ConfigureTransport(baseTransport) + // Since we're not going to use our dialer for TLS, the main purpose of + // the following line is to make sure ForseSpecificSNI has impact on the + // config we are going to use when doing TLS. The code is as such since + // we used to force net/http through using dialer.DialTLS. + dialer.TLSConfig = baseTransport.TLSClientConfig + // Arrange the configuration such that we always use `dialer` for dialing + // cleartext connections. The net/http code will dial TLS connections. + baseTransport.DialContext = dialer.DialContext + // Better for Cloudflare DNS and also better because we have less + // noisy events and we can better understand what happened. + baseTransport.MaxConnsPerHost = 1 + // The following (1) reduces the number of headers that Go will + // automatically send for us and (2) ensures that we always receive + // back the true headers, such as Content-Length. This change is + // functional to OONI's goal of observing the network. + baseTransport.DisableCompression = true + return &HTTPTransport{ + Transport: baseTransport, + Handler: handler, + Beginning: beginning, + roundTripper: ooniTransport, + } +} + +// RoundTrip executes a single HTTP transaction, returning +// a Response for the provided Request. +func (t *HTTPTransport) RoundTrip( + req *http.Request, +) (resp *http.Response, err error) { + ctx := maybeWithMeasurementRoot(req.Context(), t.Beginning, t.Handler) + req = req.WithContext(ctx) + resp, err = t.roundTripper.RoundTrip(req) + // For safety wrap the error as "http_round_trip" but this + // will only be used if the error chain does not contain any + // other major operation failure. See modelx.ErrWrapper. + err = errwrapper.SafeErrWrapperBuilder{ + Error: err, + Operation: "http_round_trip", + }.MaybeBuild() + return resp, err +} + +// CloseIdleConnections closes the idle connections. +func (t *HTTPTransport) CloseIdleConnections() { + // Adapted from net/http code + type closeIdler interface { + CloseIdleConnections() + } + if tr, ok := t.roundTripper.(closeIdler); ok { + tr.CloseIdleConnections() + } +} + +// ChainResolvers chains a primary and a secondary resolver such that +// we can fallback to the secondary if primary is broken. +func ChainResolvers(primary, secondary modelx.DNSResolver) modelx.DNSResolver { + return chainresolver.New(primary, secondary) +} diff --git a/netx/internal/internal_test.go b/netx/internal/internal_test.go new file mode 100644 index 00000000..49c0ca72 --- /dev/null +++ b/netx/internal/internal_test.go @@ -0,0 +1,403 @@ +package internal + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal/resolver/brokenresolver" + "github.com/ooni/probe-engine/netx/internal/resolver/systemresolver" +) + +func TestIntegrationDial(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialWithCustomResolver(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + dialer.SetResolver(new(net.Resolver)) + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialTLS(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialTLSForceSkipVerify(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + dialer.ForceSkipVerify() + conn, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialInvalidAddress(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + conn, err := dialer.Dial("tcp", "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestIntegrationDialInvalidAddressTLS(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + conn, err := dialer.DialTLS("tcp", "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestIntegrationDialInvalidSNI(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + dialer.TLSConfig = &tls.Config{ + ServerName: "www.google.com", + } + conn, err := dialer.DialTLS("tcp", "ooni.io:443") + if err == nil { + t.Fatal("expected an error here") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestDialerSetCABundleExisting(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + err := dialer.SetCABundle("../testdata/cacert.pem") + if err != nil { + t.Fatal(err) + } +} + +func TestDialerSetCABundleNonexisting(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + err := dialer.SetCABundle("../testdata/cacert-nonexistent.pem") + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestDialerSetCABundleWAI(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + err := dialer.SetCABundle("../testdata/cacert.pem") + if err != nil { + t.Fatal(err) + } + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + var target x509.UnknownAuthorityError + if errors.As(err, &target) == false { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected a nil conn here") + } +} + +func TestDialerForceSpecificSNI(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + err := dialer.ForceSpecificSNI("www.facebook.com") + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + var target x509.HostnameError + if errors.As(err, &target) == false { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected a nil connection here") + } +} + +func testresolverquick(t *testing.T, network, address string) { + resolver, err := NewResolver(time.Now(), handlers.NoHandler, network, address) + if err != nil { + t.Fatal(err) + } + if resolver == nil { + t.Fatal("expected non-nil resolver here") + } + addrs, err := resolver.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil addrs here") + } + var foundquad8 bool + for _, addr := range addrs { + if addr == "8.8.8.8" { + foundquad8 = true + } + } + if !foundquad8 { + t.Fatal("did not find 8.8.8.8 in ouput") + } +} + +func TestIntegrationNewResolverUDPAddress(t *testing.T) { + testresolverquick(t, "udp", "8.8.8.8:53") +} + +func TestIntegrationNewResolverUDPAddressNoPort(t *testing.T) { + testresolverquick(t, "udp", "8.8.8.8") +} + +func TestIntegrationNewResolverUDPDomain(t *testing.T) { + testresolverquick(t, "udp", "dns.google.com:53") +} + +func TestIntegrationNewResolverUDPDomainNoPort(t *testing.T) { + testresolverquick(t, "udp", "dns.google.com") +} + +func TestIntegrationNewResolverSystem(t *testing.T) { + testresolverquick(t, "system", "") +} + +func TestIntegrationNewResolverTCPAddress(t *testing.T) { + testresolverquick(t, "tcp", "8.8.8.8:53") +} + +func TestIntegrationNewResolverTCPAddressNoPort(t *testing.T) { + testresolverquick(t, "tcp", "8.8.8.8") +} + +func TestIntegrationNewResolverTCPDomain(t *testing.T) { + testresolverquick(t, "tcp", "dns.google.com:53") +} + +func TestIntegrationNewResolverTCPDomainNoPort(t *testing.T) { + testresolverquick(t, "tcp", "dns.google.com") +} + +func TestIntegrationNewResolverDoTAddress(t *testing.T) { + testresolverquick(t, "dot", "9.9.9.9:853") +} + +func TestIntegrationNewResolverDoTAddressNoPort(t *testing.T) { + testresolverquick(t, "dot", "9.9.9.9") +} + +func TestIntegrationNewResolverDoTDomain(t *testing.T) { + testresolverquick(t, "dot", "dns.quad9.net:853") +} + +func TestIntegrationNewResolverDoTDomainNoPort(t *testing.T) { + testresolverquick(t, "dot", "dns.quad9.net") +} + +func TestIntegrationNewResolverDoH(t *testing.T) { + testresolverquick(t, "doh", "https://cloudflare-dns.com/dns-query") +} + +func TestIntegrationNewResolverInvalid(t *testing.T) { + resolver, err := NewResolver( + time.Now(), handlers.StdoutHandler, + "antani", "https://cloudflare-dns.com/dns-query", + ) + if err == nil { + t.Fatal("expected an error here") + } + if resolver != nil { + t.Fatal("expected a nil resolver here") + } +} + +func testconfigurednsquick(t *testing.T, network, address string) { + d := NewDialer(time.Now(), handlers.NoHandler) + err := d.ConfigureDNS(network, address) + if err != nil { + t.Fatal(err) + } + conn, err := d.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + if conn == nil { + t.Fatal("expected non-nil conn here") + } + conn.Close() +} + +func TestIntegrationConfigureSystemDNS(t *testing.T) { + testconfigurednsquick(t, "system", "") +} + +func TestIntegrationHTTPTransport(t *testing.T) { + client := &http.Client{ + Transport: NewHTTPTransport( + time.Now(), handlers.NoHandler, + NewDialer(time.Now(), handlers.NoHandler), + false, + http.ProxyFromEnvironment, + ), + } + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + client.CloseIdleConnections() +} + +func TestIntegrationHTTPTransportTimeout(t *testing.T) { + client := &http.Client{ + Transport: NewHTTPTransport( + time.Now(), handlers.NoHandler, + NewDialer(time.Now(), handlers.NoHandler), + false, + http.ProxyFromEnvironment, + ), + } + req, err := http.NewRequest("GET", "https://www.google.com", nil) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err == nil { + t.Fatal("expected an error here") + } + if !strings.HasSuffix(err.Error(), "generic_timeout_error") { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil resp here") + } +} +func TestIntegrationChainResolvers(t *testing.T) { + dialer := NewDialer(time.Now(), handlers.NoHandler) + resolver := ChainResolvers( + brokenresolver.New(), + systemresolver.New(new(net.Resolver)), + ) + dialer.SetResolver(resolver) + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + defer conn.Close() +} + +func TestIntegrationFailure(t *testing.T) { + client := &http.Client{ + Transport: NewHTTPTransport( + time.Now(), handlers.NoHandler, + NewDialer(time.Now(), handlers.NoHandler), + false, + http.ProxyFromEnvironment, + ), + } + // This fails the request because we attempt to speak cleartext HTTP with + // a server that instead is expecting TLS. + resp, err := client.Get("http://www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + if resp != nil { + t.Fatal("expected a nil response here") + } + client.CloseIdleConnections() +} + +func TestLookupAddrWrapper(t *testing.T) { + resolver := newResolverWrapper(time.Now(), handlers.NoHandler, new(net.Resolver)) + names, err := resolver.LookupAddr(context.Background(), "8.8.8.8") + if err != nil { + t.Fatal(err) + } + if len(names) < 1 { + t.Fatal("unexpected result") + } +} + +func TestLookupCNAMEWrapper(t *testing.T) { + resolver := newResolverWrapper(time.Now(), handlers.NoHandler, new(net.Resolver)) + name, err := resolver.LookupCNAME(context.Background(), "www.google.com") + if err != nil { + t.Fatal(err) + } + if name == "" { + t.Fatal("unexpected result") + } +} + +func TestLookupMXWrapper(t *testing.T) { + resolver := newResolverWrapper(time.Now(), handlers.NoHandler, new(net.Resolver)) + entries, err := resolver.LookupMX(context.Background(), "google.com") + if err != nil { + t.Fatal(err) + } + if len(entries) < 1 { + t.Fatal("unexpected result") + } +} + +func TestLookupNSWrapper(t *testing.T) { + resolver := newResolverWrapper(time.Now(), handlers.NoHandler, new(net.Resolver)) + entries, err := resolver.LookupNS(context.Background(), "google.com") + if err != nil { + t.Fatal(err) + } + if len(entries) < 1 { + t.Fatal("unexpected result") + } +} + +func TestUnitNewHTTPClientForDoH(t *testing.T) { + first := newHTTPClientForDoH( + time.Now(), handlers.NoHandler, + ) + second := newHTTPClientForDoH( + time.Now(), handlers.NoHandler, + ) + if first != second { + t.Fatal("expected to see same client here") + } + third := newHTTPClientForDoH( + time.Now(), handlers.StdoutHandler, + ) + if first == third { + t.Fatal("expected to see different client here") + } +} diff --git a/netx/internal/resolver/bogondetector/bogondetector.go b/netx/internal/resolver/bogondetector/bogondetector.go new file mode 100644 index 00000000..f68a4e97 --- /dev/null +++ b/netx/internal/resolver/bogondetector/bogondetector.go @@ -0,0 +1,53 @@ +// Package bogondetector contains code to determine if an IP is private/bogon. The +// code was adapted from https://stackoverflow.com/a/50825191/4354461. +// +// See https://badpackets.net/hunting-for-bogons-and-the-isps-that-announce-them/ +// from which I have drawn the full list of private/bogons. +package bogondetector + +import ( + "net" + + "github.com/m-lab/go/rtx" +) + +var privateIPBlocks []*net.IPNet + +func init() { + for _, cidr := range []string{ + "0.0.0.0/8", // "This" network (however, Linux...) + "10.0.0.0/8", // RFC1918 + "100.64.0.0/10", // Carrier grade NAT + "127.0.0.0/8", // IPv4 loopback + "169.254.0.0/16", // RFC3927 link-local + "172.16.0.0/12", // RFC1918 + "192.168.0.0/16", // RFC1918 + "224.0.0.0/4", // Multicast + "::1/128", // IPv6 loopback + "fe80::/10", // IPv6 link-local + "fc00::/7", // IPv6 unique local addr + } { + _, block, err := net.ParseCIDR(cidr) + rtx.PanicOnError(err, "net.ParseCIDR failed") + privateIPBlocks = append(privateIPBlocks, block) + } +} + +func isPrivate(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + for _, block := range privateIPBlocks { + if block.Contains(ip) { + return true + } + } + return false +} + +// Check returns whether if an IP address is bogon. Passing to this +// function a non-IP address causes it to return bogon. +func Check(address string) bool { + ip := net.ParseIP(address) + return ip == nil || isPrivate(ip) +} diff --git a/netx/internal/resolver/bogondetector/bogondetector_test.go b/netx/internal/resolver/bogondetector/bogondetector_test.go new file mode 100644 index 00000000..2d9ec9d3 --- /dev/null +++ b/netx/internal/resolver/bogondetector/bogondetector_test.go @@ -0,0 +1,18 @@ +package bogondetector + +import "testing" + +func TestIntegration(t *testing.T) { + if Check("antani") != true { + t.Fatal("unexpected result") + } + if Check("127.0.0.1") != true { + t.Fatal("unexpected result") + } + if Check("1.1.1.1") != false { + t.Fatal("unexpected result") + } + if Check("10.0.1.1") != true { + t.Fatal("unexpected result") + } +} diff --git a/netx/internal/resolver/brokenresolver/brokenresolver.go b/netx/internal/resolver/brokenresolver/brokenresolver.go new file mode 100644 index 00000000..92c8ac03 --- /dev/null +++ b/netx/internal/resolver/brokenresolver/brokenresolver.go @@ -0,0 +1,44 @@ +// Package brokenresolver is a broken resolver +package brokenresolver + +import ( + "context" + "net" +) + +// Resolver is a broken resolver. +type Resolver struct{} + +// New creates a new broken Resolver instance. +func New() *Resolver { + return &Resolver{} +} + +var errNotFound = &net.DNSError{ + Err: "no such host", +} + +// LookupAddr returns the name of the provided IP address +func (c *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) { + return nil, errNotFound +} + +// LookupCNAME returns the canonical name of a host +func (c *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) { + return "", errNotFound +} + +// LookupHost returns the IP addresses of a host +func (c *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + return nil, errNotFound +} + +// LookupMX returns the MX records of a specific name +func (c *Resolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + return nil, errNotFound +} + +// LookupNS returns the NS records of a specific name +func (c *Resolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + return nil, errNotFound +} diff --git a/netx/internal/resolver/brokenresolver/brokenresolver_test.go b/netx/internal/resolver/brokenresolver/brokenresolver_test.go new file mode 100644 index 00000000..9a7c935e --- /dev/null +++ b/netx/internal/resolver/brokenresolver/brokenresolver_test.go @@ -0,0 +1,61 @@ +package brokenresolver + +import ( + "context" + "testing" +) + +func TestLookupAddr(t *testing.T) { + client := New() + names, err := client.LookupAddr(context.Background(), "8.8.8.8") + if err == nil { + t.Fatal("expected an error here") + } + if names != nil { + t.Fatal("expected nil here") + } +} + +func TestLookupCNAME(t *testing.T) { + client := New() + cname, err := client.LookupCNAME(context.Background(), "www.ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if cname != "" { + t.Fatal("expected empty string here") + } +} + +func TestLookupHost(t *testing.T) { + client := New() + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + if addrs != nil { + t.Fatal("expected nil here") + } +} + +func TestLookupMX(t *testing.T) { + client := New() + records, err := client.LookupMX(context.Background(), "ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if records != nil { + t.Fatal("expected nil here") + } +} + +func TestLookupNS(t *testing.T) { + client := New() + records, err := client.LookupNS(context.Background(), "ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if records != nil { + t.Fatal("expected nil here") + } +} diff --git a/netx/internal/resolver/chainresolver/chainresolver.go b/netx/internal/resolver/chainresolver/chainresolver.go new file mode 100644 index 00000000..96fe9386 --- /dev/null +++ b/netx/internal/resolver/chainresolver/chainresolver.go @@ -0,0 +1,68 @@ +// Package chainresolver allows to chain two resolvers +package chainresolver + +import ( + "context" + "net" + + "github.com/ooni/probe-engine/netx/modelx" +) + +// Resolver is a chain resolver. +type Resolver struct { + primary modelx.DNSResolver + secondary modelx.DNSResolver +} + +// New creates a new chain Resolver instance. +func New(primary, secondary modelx.DNSResolver) *Resolver { + return &Resolver{ + primary: primary, + secondary: secondary, + } +} + +// LookupAddr returns the name of the provided IP address +func (c *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) { + names, err := c.primary.LookupAddr(ctx, addr) + if err != nil { + names, err = c.secondary.LookupAddr(ctx, addr) + } + return names, err +} + +// LookupCNAME returns the canonical name of a host +func (c *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) { + cname, err := c.primary.LookupCNAME(ctx, host) + if err != nil { + cname, err = c.secondary.LookupCNAME(ctx, host) + } + return cname, err +} + +// LookupHost returns the IP addresses of a host +func (c *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + addrs, err := c.primary.LookupHost(ctx, hostname) + if err != nil { + addrs, err = c.secondary.LookupHost(ctx, hostname) + } + return addrs, err +} + +// LookupMX returns the MX records of a specific name +func (c *Resolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + records, err := c.primary.LookupMX(ctx, name) + if err != nil { + records, err = c.secondary.LookupMX(ctx, name) + } + return records, err +} + +// LookupNS returns the NS records of a specific name +func (c *Resolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + records, err := c.primary.LookupNS(ctx, name) + if err != nil { + records, err = c.secondary.LookupNS(ctx, name) + } + return records, err +} diff --git a/netx/internal/resolver/chainresolver/chainresolver_test.go b/netx/internal/resolver/chainresolver/chainresolver_test.go new file mode 100644 index 00000000..31b69074 --- /dev/null +++ b/netx/internal/resolver/chainresolver/chainresolver_test.go @@ -0,0 +1,64 @@ +package chainresolver + +import ( + "context" + "net" + "testing" + + "github.com/ooni/probe-engine/netx/internal/resolver/brokenresolver" +) + +func TestLookupAddr(t *testing.T) { + client := New(brokenresolver.New(), new(net.Resolver)) + names, err := client.LookupAddr(context.Background(), "8.8.8.8") + if err != nil { + t.Fatal(err) + } + if names == nil { + t.Fatal("expect non nil return value here") + } +} + +func TestLookupCNAME(t *testing.T) { + client := New(brokenresolver.New(), new(net.Resolver)) + cname, err := client.LookupCNAME(context.Background(), "www.ooni.io") + if err != nil { + t.Fatal(err) + } + if cname == "" { + t.Fatal("expect non empty return value here") + } +} + +func TestLookupHost(t *testing.T) { + client := New(brokenresolver.New(), new(net.Resolver)) + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expect non nil return value here") + } +} + +func TestLookupMX(t *testing.T) { + client := New(brokenresolver.New(), new(net.Resolver)) + records, err := client.LookupMX(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expect non nil return value here") + } +} + +func TestLookupNS(t *testing.T) { + client := New(brokenresolver.New(), new(net.Resolver)) + records, err := client.LookupNS(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expect non nil return value here") + } +} diff --git a/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps.go b/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps.go new file mode 100644 index 00000000..3d5642a7 --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps.go @@ -0,0 +1,69 @@ +// Package dnsoverhttps implements DNS over HTTPS. +package dnsoverhttps + +import ( + "bytes" + "context" + "errors" + "io/ioutil" + "net/http" +) + +// Transport is a DNS over HTTPS modelx.DNSRoundTripper. +// +// As a known bug, this implementation does not cache the domain +// name in the URL for reuse, but this should be easy to fix. +type Transport struct { + clientDo func(req *http.Request) (*http.Response, error) + url string +} + +// NewTransport creates a new Transport +func NewTransport(client *http.Client, URL string) *Transport { + return &Transport{ + clientDo: client.Do, + url: URL, + } +} + +// RoundTrip sends a request and receives a response. +func (t *Transport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) { + req, err := http.NewRequest("POST", t.url, bytes.NewReader(query)) + if err != nil { + return nil, err + } + req.Header.Set("content-type", "application/dns-message") + var resp *http.Response + resp, err = t.clientDo(req.WithContext(ctx)) + if err != nil { + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + // TODO(bassosimone): we should map the status code to a + // proper Error in the DNS context. + err = errors.New("doh: server returned error") + return + } + if resp.Header.Get("content-type") != "application/dns-message" { + err = errors.New("doh: invalid content-type") + return + } + reply, err = ioutil.ReadAll(resp.Body) + return +} + +// RequiresPadding returns true for DoH according to RFC8467 +func (t *Transport) RequiresPadding() bool { + return true +} + +// Network returns the transport network (e.g., doh, dot) +func (t *Transport) Network() string { + return "doh" +} + +// Address returns the upstream server address. +func (t *Transport) Address() string { + return t.url +} diff --git a/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps_test.go b/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps_test.go new file mode 100644 index 00000000..22b8984f --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsoverhttps/dnsoverhttps_test.go @@ -0,0 +1,114 @@ +package dnsoverhttps + +import ( + "context" + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/miekg/dns" +) + +func TestIntegrationSuccess(t *testing.T) { + const queryURL = "https://cloudflare-dns.com/dns-query" + transport := NewTransport( + http.DefaultClient, queryURL, + ) + if transport.Network() != "doh" { + t.Fatal("invalid network") + } + if transport.Address() != queryURL { + t.Fatal("invalid address") + } + err := threeRounds(transport) + if err != nil { + t.Fatal(err) + } +} + +func TestIntegrationNewRequestFailure(t *testing.T) { + transport := NewTransport( + http.DefaultClient, "\t", // invalid URL + ) + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationClientDoFailure(t *testing.T) { + transport := NewTransport( + http.DefaultClient, "https://cloudflare-dns.com/dns-query", + ) + transport.clientDo = func(*http.Request) (*http.Response, error) { + return nil, errors.New("mocked error") + } + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationHTTPFailure(t *testing.T) { + transport := NewTransport( + http.DefaultClient, "https://cloudflare-dns.com/dns-query", + ) + transport.clientDo = func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: ioutil.NopCloser(strings.NewReader("")), + }, nil + } + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationMissingHeader(t *testing.T) { + transport := NewTransport( + http.DefaultClient, "https://cloudflare-dns.com/dns-query", + ) + transport.clientDo = func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("")), + }, nil + } + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func threeRounds(transport *Transport) error { + err := roundTrip(transport, "ooni.io.") + if err != nil { + return err + } + err = roundTrip(transport, "slashdot.org.") + if err != nil { + return err + } + err = roundTrip(transport, "kernel.org.") + if err != nil { + return err + } + return nil +} + +func roundTrip(transport *Transport, domain string) error { + query := new(dns.Msg) + query.SetQuestion(domain, dns.TypeA) + data, err := query.Pack() + if err != nil { + return err + } + data, err = transport.RoundTrip(context.Background(), data) + if err != nil { + return err + } + return query.Unpack(data) +} diff --git a/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp.go b/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp.go new file mode 100644 index 00000000..1c7fe44c --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp.go @@ -0,0 +1,136 @@ +// Package dnsovertcp implements DNS over TCP. It is possible to +// use both plaintext TCP and TLS. +package dnsovertcp + +import ( + "bufio" + "context" + "io" + "net" + "time" + + "github.com/m-lab/go/rtx" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Transport is a DNS over TCP/TLS modelx.DNSRoundTripper. +// +// As a known bug, this implementation always creates a new connection +// for each incoming query, thus increasing the response delay. +type Transport struct { + dialer dialerAdapter + address string + requiresPadding bool +} + +type dialerAdapter interface { + modelx.Dialer + Network() string +} + +// NewTransportTCP creates a new TCP Transport +func NewTransportTCP(dialer modelx.Dialer, address string) *Transport { + return &Transport{ + dialer: newTCPDialerAdapter(dialer), + address: address, + requiresPadding: false, + } +} + +// NewTransportTLS creates a new TLS Transport +func NewTransportTLS(dialer modelx.TLSDialer, address string) *Transport { + return &Transport{ + dialer: newTLSDialerAdapter(dialer), + address: address, + requiresPadding: true, + } +} + +// RoundTrip sends a request and receives a response. +func (t *Transport) RoundTrip(ctx context.Context, query []byte) ([]byte, error) { + conn, err := t.dialer.DialContext(ctx, "tcp", t.address) + if err != nil { + return nil, err + } + defer conn.Close() + return t.doWithConn(conn, query) +} + +// RequiresPadding returns true for DoT and false for TCP +// according to RFC8467. +func (t *Transport) RequiresPadding() bool { + return t.requiresPadding +} + +func (t *Transport) doWithConn(conn net.Conn, query []byte) (reply []byte, err error) { + defer func() { + if r := recover(); r != nil { + reply = nil // we already got the error just clear the reply + } + }() + err = conn.SetDeadline(time.Now().Add(10 * time.Second)) + rtx.PanicOnError(err, "conn.SetDeadline failed") + // Write request + writer := bufio.NewWriter(conn) + err = writer.WriteByte(byte(len(query) >> 8)) + rtx.PanicOnError(err, "writer.WriteByte failed for first byte") + err = writer.WriteByte(byte(len(query))) + rtx.PanicOnError(err, "writer.WriteByte failed for second byte") + _, err = writer.Write(query) + rtx.PanicOnError(err, "writer.Write failed for query") + err = writer.Flush() + rtx.PanicOnError(err, "writer.Flush failed") + // Read response + header := make([]byte, 2) + _, err = io.ReadFull(conn, header) + rtx.PanicOnError(err, "io.ReadFull failed") + length := int(header[0])<<8 | int(header[1]) + reply = make([]byte, length) + _, err = io.ReadFull(conn, reply) + rtx.PanicOnError(err, "io.ReadFull failed") + return reply, nil +} + +type tlsDialerAdapter struct { + dialer modelx.TLSDialer +} + +func newTLSDialerAdapter(dialer modelx.TLSDialer) *tlsDialerAdapter { + return &tlsDialerAdapter{dialer: dialer} +} + +func (d *tlsDialerAdapter) Dial(network, address string) (net.Conn, error) { + return d.dialer.DialTLS(network, address) +} + +func (d *tlsDialerAdapter) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return d.dialer.DialTLSContext(ctx, network, address) +} + +func (d *tlsDialerAdapter) Network() string { + return "dot" +} + +type tcpDialerAdapter struct { + modelx.Dialer +} + +func newTCPDialerAdapter(dialer modelx.Dialer) *tcpDialerAdapter { + return &tcpDialerAdapter{Dialer: dialer} +} + +func (d *tcpDialerAdapter) Network() string { + return "tcp" +} + +// Network returns the transport network (e.g., doh, dot) +func (t *Transport) Network() string { + return t.dialer.Network() +} + +// Address returns the upstream server address. +func (t *Transport) Address() string { + return t.address +} diff --git a/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp_test.go b/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp_test.go new file mode 100644 index 00000000..d5b2d72c --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsovertcp/dnsovertcp_test.go @@ -0,0 +1,192 @@ +package dnsovertcp + +import ( + "context" + "crypto/tls" + "errors" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +type tlsdialer struct { + config *tls.Config +} + +func (d *tlsdialer) DialTLS(network, address string) (net.Conn, error) { + return d.DialTLSContext(context.Background(), network, address) +} + +func (d *tlsdialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return tls.Dial(network, address, d.config) +} + +func TestIntegrationSuccessTLS(t *testing.T) { + // "Dial interprets a nil configuration as equivalent to + // the zero configuration; see the documentation of Config + // for the defaults." + address := "dns.quad9.net:853" + transport := NewTransportTLS(&tlsdialer{}, address) + if transport.Network() != "dot" { + t.Fatal("unexpected network") + } + if transport.Address() != address { + t.Fatal("unexpected address") + } + if err := threeRounds(transport); err != nil { + t.Fatal(err) + } +} + +func TestIntegrationSuccessTCP(t *testing.T) { + address := "9.9.9.9:53" + transport := NewTransportTCP(&net.Dialer{}, address) + if transport.Network() != "tcp" { + t.Fatal("unexpected network") + } + if transport.Address() != address { + t.Fatal("unexpected address") + } + if err := threeRounds(transport); err != nil { + t.Fatal(err) + } +} + +func TestIntegrationLookupHostError(t *testing.T) { + transport := NewTransportTCP(&net.Dialer{}, "antani.local") + if err := roundTrip(transport, "ooni.io."); err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationCustomTLSConfig(t *testing.T) { + transport := NewTransportTLS(&tlsdialer{ + config: &tls.Config{ + MinVersion: tls.VersionTLS10, + }, + }, "dns.quad9.net:853") + if err := roundTrip(transport, "ooni.io."); err != nil { + t.Fatal(err) + } +} + +func TestUnitRoundTripWithConnFailure(t *testing.T) { + // fakeconn will fail in the SetDeadline, therefore we will have + // an immediate error and we expect all errors the be alike + transport := NewTransportTCP(&fakeconnDialer{}, "8.8.8.8:53") + query := make([]byte, 1<<10) + reply, err := transport.doWithConn(&fakeconn{}, query) + if err == nil { + t.Fatal("expected an error here") + } + if reply != nil { + t.Fatal("expected nil error here") + } +} + +func threeRounds(transport *Transport) error { + err := roundTrip(transport, "ooni.io.") + if err != nil { + return err + } + err = roundTrip(transport, "slashdot.org.") + if err != nil { + return err + } + err = roundTrip(transport, "kernel.org.") + if err != nil { + return err + } + return nil +} + +func roundTrip(transport *Transport, domain string) error { + query := new(dns.Msg) + query.SetQuestion(domain, dns.TypeA) + data, err := query.Pack() + if err != nil { + return err + } + data, err = transport.RoundTrip(context.Background(), data) + if err != nil { + return err + } + return query.Unpack(data) +} + +type fakeconnDialer struct { + fakeconn fakeconn +} + +func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *fakeconnDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return &d.fakeconn, nil +} + +type fakeconn struct{} + +func (fakeconn) Read(b []byte) (n int, err error) { + n = len(b) + return +} +func (fakeconn) Write(b []byte) (n int, err error) { + n = len(b) + return +} +func (fakeconn) Close() (err error) { + return +} +func (fakeconn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} +func (fakeconn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} +func (fakeconn) SetDeadline(t time.Time) (err error) { + return errors.New("cannot set deadline") +} +func (fakeconn) SetReadDeadline(t time.Time) (err error) { + return +} +func (fakeconn) SetWriteDeadline(t time.Time) (err error) { + return +} + +func TestTLSDialerAdapter(t *testing.T) { + fake := &fakeTLSDialer{} + adapter := newTLSDialerAdapter(fake) + adapter.Dial("tcp", "www.google.com:443") + if !fake.calledDialTLS { + t.Fatal("redirection to DialTLS not working") + } + adapter.DialContext(context.Background(), "tcp", "www.google.com:443") + if !fake.calledDialTLSContext { + t.Fatal("redirection to DialTLSContext not working") + } +} + +type fakeTLSDialer struct { + calledDialTLS bool + calledDialTLSContext bool +} + +func (d *fakeTLSDialer) DialTLS(network, address string) (net.Conn, error) { + d.calledDialTLS = true + return nil, errors.New("mocked error") +} + +func (d *fakeTLSDialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + d.calledDialTLSContext = true + return nil, errors.New("mocked error") +} diff --git a/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp.go b/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp.go new file mode 100644 index 00000000..ad7eab69 --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp.go @@ -0,0 +1,64 @@ +// Package dnsoverudp implements DNS over UDP. +package dnsoverudp + +import ( + "context" + "time" + + "github.com/ooni/probe-engine/netx/modelx" +) + +// Transport is a DNS over UDP modelx.DNSRoundTripper. +type Transport struct { + dialer modelx.Dialer + address string +} + +// NewTransport creates a new Transport +func NewTransport(dialer modelx.Dialer, address string) *Transport { + return &Transport{ + dialer: dialer, + address: address, + } +} + +// RoundTrip sends a request and receives a response. +func (t *Transport) RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) { + conn, err := t.dialer.DialContext(ctx, "udp", t.address) + if err != nil { + return + } + defer conn.Close() + // Use five seconds timeout like Bionic does. See + // https://labs.ripe.net/Members/baptiste_jonglez_1/persistent-dns-connections-for-reliability-and-performance + err = conn.SetDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return + } + _, err = conn.Write(query) + if err != nil { + return + } + reply = make([]byte, 1<<17) + var n int + n, err = conn.Read(reply) + if err == nil { + reply = reply[:n] + } + return +} + +// RequiresPadding returns false for UDP according to RFC8467 +func (t *Transport) RequiresPadding() bool { + return false +} + +// Network returns the transport network (e.g., doh, dot) +func (t *Transport) Network() string { + return "udp" +} + +// Address returns the upstream server address. +func (t *Transport) Address() string { + return t.address +} diff --git a/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp_test.go b/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp_test.go new file mode 100644 index 00000000..162f4f30 --- /dev/null +++ b/netx/internal/resolver/dnstransport/dnsoverudp/dnsoverudp_test.go @@ -0,0 +1,164 @@ +package dnsoverudp + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +func TestIntegrationSuccessWithAddress(t *testing.T) { + const address = "9.9.9.9:53" + transport := NewTransport( + &net.Dialer{}, address, + ) + if transport.Network() != "udp" { + t.Fatal("invalid network") + } + if transport.Address() != address { + t.Fatal("invalid address") + } + err := threeRounds(transport) + if err != nil { + t.Fatal(err) + } +} + +func TestIntegrationSuccessWithDomain(t *testing.T) { + transport := NewTransport( + &net.Dialer{}, "dns.quad9.net:53", + ) + err := threeRounds(transport) + if err != nil { + t.Fatal(err) + } +} + +func TestIntegrationDialFailure(t *testing.T) { + transport := NewTransport( + &failingDialer{}, "9.9.9.9:53", + ) + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationSetDeadlineError(t *testing.T) { + transport := NewTransport( + &fakeconnDialer{ + fakeconn: fakeconn{ + setDeadlineError: errors.New("mocked error"), + }, + }, "9.9.9.9:53", + ) + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func TestIntegrationWriteError(t *testing.T) { + transport := NewTransport( + &fakeconnDialer{ + fakeconn: fakeconn{ + writeError: errors.New("mocked error"), + }, + }, "9.9.9.9:53", + ) + err := threeRounds(transport) + if err == nil { + t.Fatal("expected an error here") + } +} + +func threeRounds(transport *Transport) error { + err := roundTrip(transport, "ooni.io.") + if err != nil { + return err + } + err = roundTrip(transport, "slashdot.org.") + if err != nil { + return err + } + err = roundTrip(transport, "kernel.org.") + if err != nil { + return err + } + return nil +} + +func roundTrip(transport *Transport, domain string) error { + query := new(dns.Msg) + query.SetQuestion(domain, dns.TypeA) + data, err := query.Pack() + if err != nil { + return err + } + data, err = transport.RoundTrip(context.Background(), data) + if err != nil { + return err + } + return query.Unpack(data) +} + +type failingDialer struct{} + +func (d *failingDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *failingDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return nil, errors.New("mocked error") +} + +type fakeconnDialer struct { + fakeconn fakeconn +} + +func (d *fakeconnDialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *fakeconnDialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return &d.fakeconn, nil +} + +type fakeconn struct { + setDeadlineError error + writeError error +} + +func (fakeconn) Read(b []byte) (n int, err error) { + n = len(b) + return +} +func (c fakeconn) Write(b []byte) (n int, err error) { + n, err = len(b), c.writeError + return +} +func (fakeconn) Close() (err error) { + return +} +func (fakeconn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} +func (fakeconn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} +func (c fakeconn) SetDeadline(t time.Time) error { + return c.setDeadlineError +} +func (c fakeconn) SetReadDeadline(t time.Time) error { + return c.SetDeadline(t) +} +func (c fakeconn) SetWriteDeadline(t time.Time) error { + return c.SetDeadline(t) +} diff --git a/netx/internal/resolver/ooniresolver/ooniresolver.go b/netx/internal/resolver/ooniresolver/ooniresolver.go new file mode 100644 index 00000000..927dfdbc --- /dev/null +++ b/netx/internal/resolver/ooniresolver/ooniresolver.go @@ -0,0 +1,233 @@ +// Package ooniresolver is OONI's DNS resolver. +package ooniresolver + +import ( + "context" + "errors" + "net" + "sync/atomic" + "time" + + "github.com/miekg/dns" + "github.com/ooni/probe-engine/netx/internal/dialid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Resolver is OONI's DNS client. It is a simplistic client where we +// manually create and submit queries. It can use all the transports +// for DNS supported by this library, however. +type Resolver struct { + ntimeouts int64 + transport modelx.DNSRoundTripper +} + +// New creates a new OONI Resolver instance. +func New(t modelx.DNSRoundTripper) *Resolver { + return &Resolver{transport: t} +} + +// Transport returns the transport being used. +func (c *Resolver) Transport() modelx.DNSRoundTripper { + return c.transport +} + +var errNotImpl = errors.New("Not implemented") + +// LookupAddr returns the name of the provided IP address +func (c *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) { + err = errNotImpl + return +} + +// LookupCNAME returns the canonical name of a host +func (c *Resolver) LookupCNAME(ctx context.Context, host string) (cname string, err error) { + err = errNotImpl + return +} + +// LookupHost returns the IP addresses of a host +func (c *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + var addrs []string + var reply *dns.Msg + reply, errA := c.roundTripWithRetry(ctx, hostname, dns.TypeA) + if errA == nil { + for _, answer := range reply.Answer { + if rra, ok := answer.(*dns.A); ok { + ip := rra.A + addrs = append(addrs, ip.String()) + } + } + } + reply, errAAAA := c.roundTripWithRetry(ctx, hostname, dns.TypeAAAA) + if errAAAA == nil { + for _, answer := range reply.Answer { + if rra, ok := answer.(*dns.AAAA); ok { + ip := rra.AAAA + addrs = append(addrs, ip.String()) + } + } + } + return lookupHostResult(addrs, errA, errAAAA) +} + +func lookupHostResult(addrs []string, errA, errAAAA error) ([]string, error) { + if len(addrs) > 0 { + return addrs, nil + } + if errA != nil { + return nil, errA + } + if errAAAA != nil { + return nil, errAAAA + } + return nil, errors.New("ooniresolver: no response returned") +} + +// LookupMX returns the MX records of a specific name +func (c *Resolver) LookupMX(ctx context.Context, name string) (mx []*net.MX, err error) { + err = errNotImpl + return +} + +// LookupNS returns the NS records of a specific name +func (c *Resolver) LookupNS(ctx context.Context, name string) (ns []*net.NS, err error) { + err = errNotImpl + return +} + +const ( + // desiredBlockSize is the size that the padded query should be multiple of + desiredBlockSize = 128 + + // maxResponseSize is the maximum response size for EDNS0 + maxResponseSize = 4096 + + // dnssecEnabled turns on support for DNSSEC when using EDNS0 + dnssecEnabled = true +) + +func (c *Resolver) newQueryWithQuestion(q dns.Question, needspadding bool) (query *dns.Msg) { + query = new(dns.Msg) + query.Id = dns.Id() + query.RecursionDesired = true + query.Question = make([]dns.Question, 1) + query.Question[0] = q + if needspadding { + query.SetEdns0(maxResponseSize, dnssecEnabled) + // Clients SHOULD pad queries to the closest multiple of + // 128 octets RFC8467#section-4.1. We inflate the query + // length by the size of the option (i.e. 4 octets). The + // cast to uint is necessary to make the modulus operation + // work as intended when the desiredBlockSize is smaller + // than (query.Len()+4) ¯\_(ツ)_/¯. + remainder := (desiredBlockSize - uint(query.Len()+4)) % desiredBlockSize + opt := new(dns.EDNS0_PADDING) + opt.Padding = make([]byte, remainder) + query.IsEdns0().Option = append(query.IsEdns0().Option, opt) + } + return +} + +func (c *Resolver) roundTripWithRetry( + ctx context.Context, hostname string, qtype uint16, +) (*dns.Msg, error) { + var errorslist []error + for i := 0; i < 3; i++ { + reply, err := c.roundTrip(ctx, c.newQueryWithQuestion(dns.Question{ + Name: dns.Fqdn(hostname), + Qtype: qtype, + Qclass: dns.ClassINET, + }, c.Transport().RequiresPadding())) + if err == nil { + return reply, nil + } + errorslist = append(errorslist, err) + var operr *net.OpError + if errors.As(err, &operr) == false || operr.Timeout() == false { + // The first error is the one that is most likely to be caused + // by the network. Subsequent errors are more likely to be caused + // by context deadlines. So, the first error is attached to an + // operation, while subsequent errors may possibly not be. If + // so, the resulting failing operation is not correct. + break + } + atomic.AddInt64(&c.ntimeouts, 1) + } + // bugfix: we MUST return one of the errors otherwise we confuse the + // mechanism in errwrap that classifies the root cause operation, since + // it would not be able to find a child with a major operation error + return nil, errorslist[0] +} + +func (c *Resolver) roundTrip(ctx context.Context, query *dns.Msg) (reply *dns.Msg, err error) { + return c.mockableRoundTrip( + ctx, query, func(msg *dns.Msg) ([]byte, error) { + return msg.Pack() + }, + func(t modelx.DNSRoundTripper, query []byte) (reply []byte, err error) { + // Pass ctx to round tripper for cancellation as well + // as to propagate context information + return t.RoundTrip(ctx, query) + }, + func(msg *dns.Msg, data []byte) (err error) { + return msg.Unpack(data) + }, + ) +} + +func (c *Resolver) mockableRoundTrip( + ctx context.Context, + query *dns.Msg, + pack func(msg *dns.Msg) ([]byte, error), + roundTrip func(t modelx.DNSRoundTripper, query []byte) (reply []byte, err error), + unpack func(msg *dns.Msg, data []byte) (err error), +) (reply *dns.Msg, err error) { + var ( + querydata []byte + replydata []byte + ) + querydata, err = pack(query) + if err != nil { + return + } + root := modelx.ContextMeasurementRootOrDefault(ctx) + root.Handler.OnMeasurement(modelx.Measurement{ + DNSQuery: &modelx.DNSQueryEvent{ + Data: querydata, + DialID: dialid.ContextDialID(ctx), + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Msg: query, + }, + }) + replydata, err = roundTrip(c.transport, querydata) + if err != nil { + return + } + reply = new(dns.Msg) + err = unpack(reply, replydata) + if err != nil { + return + } + root.Handler.OnMeasurement(modelx.Measurement{ + DNSReply: &modelx.DNSReplyEvent{ + Data: replydata, + DialID: dialid.ContextDialID(ctx), + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Msg: reply, + }, + }) + err = mapError(reply.Rcode) + return +} + +func mapError(rcode int) error { + // TODO(bassosimone): map more errors to net.DNSError names + switch rcode { + case dns.RcodeSuccess: + return nil + case dns.RcodeNameError: + return errors.New("ooniresolver: no such host") + default: + return errors.New("ooniresolver: query failed") + } +} diff --git a/netx/internal/resolver/ooniresolver/ooniresolver_test.go b/netx/internal/resolver/ooniresolver/ooniresolver_test.go new file mode 100644 index 00000000..e7ff1820 --- /dev/null +++ b/netx/internal/resolver/ooniresolver/ooniresolver_test.go @@ -0,0 +1,278 @@ +package ooniresolver + +import ( + "context" + "errors" + "net" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/ooni/probe-engine/netx/internal/resolver/dnstransport/dnsovertcp" + "github.com/ooni/probe-engine/netx/internal/resolver/dnstransport/dnsoverudp" + "github.com/ooni/probe-engine/netx/modelx" +) + +func newtransport() modelx.DNSRoundTripper { + return dnsovertcp.NewTransportTCP(&net.Dialer{}, "dns.quad9.net:53") +} + +func TestGettingTransport(t *testing.T) { + transport := newtransport() + client := New(transport) + if transport != client.Transport() { + t.Fatal("the transport is not correctly set") + } +} + +func TestLookupAddr(t *testing.T) { + client := New(newtransport()) + names, err := client.LookupAddr(context.Background(), "8.8.8.8") + if err == nil { + t.Fatal("expected an error here") + } + if names != nil { + t.Fatal("expected nil result here") + } +} + +func TestLookupCNAME(t *testing.T) { + client := New(newtransport()) + cname, err := client.LookupCNAME(context.Background(), "www.ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if cname != "" { + t.Fatal("expected empty result here") + } +} + +func TestLookupHostWithRetry(t *testing.T) { + // Because there is no server there, if there is no DNS injection + // then we are going to see several timeouts. However, this test is + // going to fail if you're under permanent DNS hijacking, which is + // what happens with Vodafone "Rete Sicura" (on by default) in Italy. + client := New(dnsoverudp.NewTransport( + &net.Dialer{}, "www.example.com:53", + )) + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + if !strings.HasSuffix(err.Error(), "i/o timeout") { + t.Fatal("not the error we expected") + } + if client.ntimeouts <= 0 { + t.Fatal("no timeouts?") + } + if addrs != nil { + t.Fatal("expected nil addr here") + } +} + +type faketransport struct{} + +func (t *faketransport) RoundTrip( + ctx context.Context, query []byte, +) (reply []byte, err error) { + return nil, errors.New("mocked error") +} + +func (t *faketransport) RequiresPadding() bool { + return true +} + +func TestLookupHostWithNonTimeoutError(t *testing.T) { + client := New(&faketransport{}) + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err == nil { + t.Fatal("expected an error here") + } + // Not a typo! Check for equality to make sure that we are + // in the case where no timeout was returned but something else. + if err.Error() == "context deadline exceeded" { + t.Fatal("not the error we expected") + } + if client.ntimeouts != 0 { + t.Fatal("we saw a timeout?") + } + if addrs != nil { + t.Fatal("expected nil addr here") + } +} + +func TestLookupHost(t *testing.T) { + client := New(newtransport()) + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupNonexistent(t *testing.T) { + client := New(newtransport()) + addrs, err := client.LookupHost(context.Background(), "nonexistent.ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if !strings.HasSuffix(err.Error(), "no such host") { + t.Fatal("not the error we expected") + } + if addrs != nil { + t.Fatal("expected nil addr here") + } +} + +func TestLookupMX(t *testing.T) { + client := New(newtransport()) + records, err := client.LookupMX(context.Background(), "ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if records != nil { + t.Fatal("expected nil result here") + } +} + +func TestLookupNS(t *testing.T) { + client := New(newtransport()) + records, err := client.LookupNS(context.Background(), "ooni.io") + if err == nil { + t.Fatal("expected an error here") + } + if records != nil { + t.Fatal("expected nil result here") + } +} + +func TestRoundTripExPackFailure(t *testing.T) { + client := New(newtransport()) + _, err := client.mockableRoundTrip( + context.Background(), nil, + func(msg *dns.Msg) ([]byte, error) { + return nil, errors.New("mocked error") + }, + func(t modelx.DNSRoundTripper, query []byte) (reply []byte, err error) { + return nil, nil + }, + func(msg *dns.Msg, data []byte) (err error) { + return nil + }, + ) + if err == nil { + t.Fatal("expeced an error here") + } +} + +func TestRoundTripExRoundTripFailure(t *testing.T) { + client := New(newtransport()) + _, err := client.mockableRoundTrip( + context.Background(), nil, + func(msg *dns.Msg) ([]byte, error) { + return nil, nil + }, + func(t modelx.DNSRoundTripper, query []byte) (reply []byte, err error) { + return nil, errors.New("mocked error") + }, + func(msg *dns.Msg, data []byte) (err error) { + return nil + }, + ) + if err == nil { + t.Fatal("expeced an error here") + } +} + +func TestRoundTripExUnpackFailure(t *testing.T) { + client := New(newtransport()) + _, err := client.mockableRoundTrip( + context.Background(), nil, + func(msg *dns.Msg) ([]byte, error) { + return nil, nil + }, + func(t modelx.DNSRoundTripper, query []byte) (reply []byte, err error) { + return nil, nil + }, + func(msg *dns.Msg, data []byte) (err error) { + return errors.New("mocked error") + }, + ) + if err == nil { + t.Fatal("expeced an error here") + } +} + +func TestLookupHostResultNoName(t *testing.T) { + addrs, err := lookupHostResult(nil, nil, nil) + if err == nil { + t.Fatal("expected an error here") + } + if addrs != nil { + t.Fatal("expected nil addrs") + } +} + +func TestLookupHostResultAAAAError(t *testing.T) { + addrs, err := lookupHostResult(nil, nil, errors.New("mocked error")) + if err == nil { + t.Fatal("expected an error here") + } + if addrs != nil { + t.Fatal("expected nil addrs") + } +} + +func TestUnitMapError(t *testing.T) { + if mapError(dns.RcodeSuccess) != nil { + t.Fatal("unexpected return value") + } + if err := mapError(dns.RcodeNameError); !strings.HasSuffix( + err.Error(), "no such host", + ) { + t.Fatal("unexpected return value") + } + if err := mapError(dns.RcodeBadName); !strings.HasSuffix( + err.Error(), "query failed", + ) { + t.Fatal("unexpected return value") + } +} + +func TestUnitPadding(t *testing.T) { + // The purpose of this unit test is to make sure that for a wide + // array of values we obtain the right query size. + getquerylen := func(domainlen int, padding bool) int { + reso := new(Resolver) + query := reso.newQueryWithQuestion(dns.Question{ + // This is not a valid name because it ends up being way + // longer than 255 octets. However, the library is allowing + // us to generate such name and we are not going to send + // it on the wire. Also, we check below that the query that + // we generate is long enough, so we should be good. + Name: dns.Fqdn(strings.Repeat("x.", domainlen)), + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }, padding) + data, err := query.Pack() + if err != nil { + t.Fatal(err) + } + return len(data) + } + for domainlen := 1; domainlen <= 4000; domainlen++ { + vanillalen := getquerylen(domainlen, false) + paddedlen := getquerylen(domainlen, true) + if vanillalen < domainlen { + t.Fatal("vanillalen is smaller than domainlen") + } + if (paddedlen % desiredBlockSize) != 0 { + t.Fatal("paddedlen is not a multiple of desiredQuerySize") + } + if paddedlen < vanillalen { + t.Fatal("paddedlen is smaller than vanillalen") + } + } +} diff --git a/netx/internal/resolver/parentresolver/parentresolver.go b/netx/internal/resolver/parentresolver/parentresolver.go new file mode 100644 index 00000000..83aad758 --- /dev/null +++ b/netx/internal/resolver/parentresolver/parentresolver.go @@ -0,0 +1,138 @@ +// Package parentresolver contains the parent resolver +package parentresolver + +import ( + "context" + "errors" + "net" + "sync/atomic" + "time" + + "github.com/ooni/probe-engine/netx/internal/dialid" + "github.com/ooni/probe-engine/netx/internal/errwrapper" + "github.com/ooni/probe-engine/netx/internal/resolver/bogondetector" + "github.com/ooni/probe-engine/netx/internal/transactionid" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Resolver is the emitter resolver +type Resolver struct { + bogonsCount int64 + resolver modelx.DNSResolver +} + +// New creates a new emitter resolver +func New(resolver modelx.DNSResolver) *Resolver { + return &Resolver{resolver: resolver} +} + +// LookupAddr returns the name of the provided IP address +func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) { + return r.resolver.LookupAddr(ctx, addr) +} + +// LookupCNAME returns the canonical name of a host +func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) { + return r.resolver.LookupCNAME(ctx, host) +} + +type queryableTransport interface { + Network() string + Address() string +} + +type queryableResolver interface { + Transport() modelx.DNSRoundTripper +} + +func (r *Resolver) queryTransport() (network string, address string) { + if reso, okay := r.resolver.(queryableResolver); okay { + if transport, okay := reso.Transport().(queryableTransport); okay { + network, address = transport.Network(), transport.Address() + } + } + return +} + +// LookupHost returns the IP addresses of a host +func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + network, address := r.queryTransport() + dialID := dialid.ContextDialID(ctx) + txID := transactionid.ContextTransactionID(ctx) + root := modelx.ContextMeasurementRootOrDefault(ctx) + root.Handler.OnMeasurement(modelx.Measurement{ + ResolveStart: &modelx.ResolveStartEvent{ + DialID: dialID, + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Hostname: hostname, + TransactionID: txID, + TransportAddress: address, + TransportNetwork: network, + }, + }) + addrs, err := r.lookupHost(ctx, hostname) + containsBogons := errors.Is(err, modelx.ErrDNSBogon) + if containsBogons { + // By default root.ErrDNSBogon is nil. Treating bogons as + // errors could prevent us from measuring, e.g., legitimate + // internal-only servers in Iran. This is why we have not + // enabled this functionality by default. Of course, it is + // instead smart to treat bogons as errors when we're using + // a website that we _know_ cannot have bogons. + // + // See also . + err = root.ErrDNSBogon + } + err = errwrapper.SafeErrWrapperBuilder{ + DialID: dialID, + Error: err, + Operation: "resolve", + TransactionID: txID, + }.MaybeBuild() + root.Handler.OnMeasurement(modelx.Measurement{ + ResolveDone: &modelx.ResolveDoneEvent{ + Addresses: addrs, + ContainsBogons: containsBogons, + DialID: dialID, + DurationSinceBeginning: time.Now().Sub(root.Beginning), + Error: err, + Hostname: hostname, + TransactionID: txID, + TransportAddress: address, + TransportNetwork: network, + }, + }) + // Respect general Go expectation that one doesn't return + // both a value and a non-nil error + if errors.Is(err, modelx.ErrDNSBogon) { + addrs = nil + } + return addrs, err +} + +func (r *Resolver) lookupHost(ctx context.Context, hostname string) ([]string, error) { + addrs, err := r.resolver.LookupHost(ctx, hostname) + for _, addr := range addrs { + if bogondetector.Check(addr) == true { + return r.detectedBogon(ctx, hostname, addrs) + } + } + return addrs, err +} + +func (r *Resolver) detectedBogon( + ctx context.Context, hostname string, addrs []string, +) ([]string, error) { + atomic.AddInt64(&r.bogonsCount, 1) + return addrs, modelx.ErrDNSBogon +} + +// LookupMX returns the MX records of a specific name +func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + return r.resolver.LookupMX(ctx, name) +} + +// LookupNS returns the NS records of a specific name +func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + return r.resolver.LookupNS(ctx, name) +} diff --git a/netx/internal/resolver/parentresolver/parentresolver_test.go b/netx/internal/resolver/parentresolver/parentresolver_test.go new file mode 100644 index 00000000..73e0fb65 --- /dev/null +++ b/netx/internal/resolver/parentresolver/parentresolver_test.go @@ -0,0 +1,153 @@ +package parentresolver + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/internal/resolver/systemresolver" + "github.com/ooni/probe-engine/netx/modelx" +) + +func TestLookupAddr(t *testing.T) { + client := New(new(net.Resolver)) + names, err := client.LookupAddr(context.Background(), "8.8.8.8") + if err != nil { + t.Fatal(err) + } + if names == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupCNAME(t *testing.T) { + client := New(new(net.Resolver)) + cname, err := client.LookupCNAME(context.Background(), "www.ooni.io") + if err != nil { + t.Fatal(err) + } + if cname == "" { + t.Fatal("expected non-empty result here") + } +} + +type emitterchecker struct { + containsBogons bool + gotResolveStart bool + gotResolveDone bool + mu sync.Mutex +} + +func (h *emitterchecker) OnMeasurement(m modelx.Measurement) { + h.mu.Lock() + defer h.mu.Unlock() + if m.ResolveStart != nil { + h.gotResolveStart = true + } + if m.ResolveDone != nil { + h.gotResolveDone = true + h.containsBogons = m.ResolveDone.ContainsBogons + } +} + +func TestLookupHost(t *testing.T) { + client := New(systemresolver.New(new(net.Resolver))) + handler := new(emitterchecker) + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + Handler: handler, + }) + addrs, err := client.LookupHost(ctx, "www.google.com") + if err != nil { + t.Fatal(err) + } + for _, addr := range addrs { + t.Log(addr) + } + handler.mu.Lock() + defer handler.mu.Unlock() + if handler.gotResolveStart == false { + t.Fatal("did not see resolve start event") + } + if handler.gotResolveDone == false { + t.Fatal("did not see resolve done event") + } + if handler.containsBogons == true { + t.Fatal("did not expect to see bogons here") + } +} + +func TestLookupHostBogonHardError(t *testing.T) { + client := New(systemresolver.New(new(net.Resolver))) + handler := new(emitterchecker) + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + ErrDNSBogon: modelx.ErrDNSBogon, + Handler: handler, + }) + addrs, err := client.LookupHost(ctx, "localhost") + if err == nil { + t.Fatal("expected an error here") + } + if err.Error() != "dns_bogon_error" { + t.Fatal("not the error that we expected") + } + if addrs != nil { + t.Fatal("expected nil addr here") + } + if handler.gotResolveDone == false { + t.Fatal("did not get the ResolveDone event") + } + if handler.containsBogons == false { + t.Fatal("expected acknowledgement of bogons") + } +} + +func TestLookupHostBogonAsWarning(t *testing.T) { + client := New(systemresolver.New(new(net.Resolver))) + handler := new(emitterchecker) + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + Handler: handler, + }) + addrs, err := client.LookupHost(ctx, "localhost") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil addr here") + } + if handler.gotResolveDone == false { + t.Fatal("did not get the ResolveDone event") + } + if handler.containsBogons == false { + t.Fatal("expected acknowledgement of bogons") + } +} + +func TestLookupMX(t *testing.T) { + client := New(new(net.Resolver)) + records, err := client.LookupMX(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupNS(t *testing.T) { + client := New(new(net.Resolver)) + records, err := client.LookupNS(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expected non-nil result here") + } +} diff --git a/netx/internal/resolver/resolver.go b/netx/internal/resolver/resolver.go new file mode 100644 index 00000000..0fff68bf --- /dev/null +++ b/netx/internal/resolver/resolver.go @@ -0,0 +1,50 @@ +// Package resolver contains code to create a resolver +package resolver + +import ( + "net" + "net/http" + + "github.com/ooni/probe-engine/netx/internal/resolver/dnstransport/dnsoverhttps" + "github.com/ooni/probe-engine/netx/internal/resolver/dnstransport/dnsovertcp" + "github.com/ooni/probe-engine/netx/internal/resolver/dnstransport/dnsoverudp" + "github.com/ooni/probe-engine/netx/internal/resolver/ooniresolver" + "github.com/ooni/probe-engine/netx/internal/resolver/parentresolver" + "github.com/ooni/probe-engine/netx/internal/resolver/systemresolver" + "github.com/ooni/probe-engine/netx/modelx" +) + +// NewResolverSystem creates a new Go/system resolver. +func NewResolverSystem() *parentresolver.Resolver { + return parentresolver.New( + systemresolver.New(new(net.Resolver)), + ) +} + +// NewResolverUDP creates a new UDP resolver. +func NewResolverUDP(dialer modelx.Dialer, address string) *parentresolver.Resolver { + return parentresolver.New( + ooniresolver.New(dnsoverudp.NewTransport(dialer, address)), + ) +} + +// NewResolverTCP creates a new TCP resolver. +func NewResolverTCP(dialer modelx.Dialer, address string) *parentresolver.Resolver { + return parentresolver.New( + ooniresolver.New(dnsovertcp.NewTransportTCP(dialer, address)), + ) +} + +// NewResolverTLS creates a new DoT resolver. +func NewResolverTLS(dialer modelx.TLSDialer, address string) *parentresolver.Resolver { + return parentresolver.New( + ooniresolver.New(dnsovertcp.NewTransportTLS(dialer, address)), + ) +} + +// NewResolverHTTPS creates a new DoH resolver. +func NewResolverHTTPS(client *http.Client, address string) *parentresolver.Resolver { + return parentresolver.New( + ooniresolver.New(dnsoverhttps.NewTransport(client, address)), + ) +} diff --git a/netx/internal/resolver/resolver_test.go b/netx/internal/resolver/resolver_test.go new file mode 100644 index 00000000..47162634 --- /dev/null +++ b/netx/internal/resolver/resolver_test.go @@ -0,0 +1,103 @@ +package resolver + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "testing" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/modelx" +) + +func testresolverquick(t *testing.T, resolver modelx.DNSResolver) { + addrs, err := resolver.LookupHost(context.Background(), "dns.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil addrs here") + } + var foundquad8 bool + for _, addr := range addrs { + if addr == "8.8.8.8" { + foundquad8 = true + } + } + if !foundquad8 { + t.Fatal("did not find 8.8.8.8 in ouput") + } +} + +func TestIntegrationDetectBogon(t *testing.T) { + resolver := NewResolverSystem() + ctx := modelx.WithMeasurementRoot( + context.Background(), &modelx.MeasurementRoot{ + Beginning: time.Now(), + ErrDNSBogon: modelx.ErrDNSBogon, + Handler: handlers.NoHandler, + }) + addrs, err := resolver.LookupHost(ctx, "localhost") + if err == nil { + t.Fatal("expected an error here") + } + if err.Error() != "dns_bogon_error" { + t.Fatal("not the error we expected to see") + } + if addrs != nil { + t.Fatal("expected nil addrs here") + } +} + +func TestIntegrationNewResolverSystem(t *testing.T) { + testresolverquick(t, NewResolverSystem()) +} + +func TestIntegrationNewResolverUDPAddress(t *testing.T) { + testresolverquick(t, NewResolverUDP( + new(net.Dialer), "8.8.8.8:53")) +} + +func TestIntegrationNewResolverUDPDomain(t *testing.T) { + testresolverquick(t, NewResolverUDP( + new(net.Dialer), "dns.google.com:53")) +} + +func TestIntegrationNewResolverTCPAddress(t *testing.T) { + testresolverquick(t, NewResolverTCP( + new(net.Dialer), "8.8.8.8:53")) +} + +func TestIntegrationNewResolverTCPDomain(t *testing.T) { + testresolverquick(t, NewResolverTCP( + new(net.Dialer), "dns.google.com:53")) +} + +func TestIntegrationNewResolverDoTAddress(t *testing.T) { + testresolverquick(t, NewResolverTLS( + &tlsdialer{}, "9.9.9.9:853")) +} + +func TestIntegrationNewResolverDoTDomain(t *testing.T) { + testresolverquick(t, NewResolverTLS( + &tlsdialer{}, "dns.quad9.net:853")) +} + +func TestIntegrationNewResolverDoH(t *testing.T) { + testresolverquick(t, NewResolverHTTPS( + http.DefaultClient, "https://cloudflare-dns.com/dns-query")) +} + +type tlsdialer struct{} + +func (*tlsdialer) DialTLS(network, address string) (net.Conn, error) { + return tls.Dial(network, address, new(tls.Config)) +} + +func (*tlsdialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return tls.Dial(network, address, new(tls.Config)) +} diff --git a/netx/internal/resolver/systemresolver/systemresolver.go b/netx/internal/resolver/systemresolver/systemresolver.go new file mode 100644 index 00000000..925df719 --- /dev/null +++ b/netx/internal/resolver/systemresolver/systemresolver.go @@ -0,0 +1,70 @@ +// Package systemresolver contains the system resolver +package systemresolver + +import ( + "context" + "errors" + "net" + + "github.com/ooni/probe-engine/netx/modelx" +) + +// Resolver is the system resolver +type Resolver struct { + resolver modelx.DNSResolver +} + +// New creates a new system resolver +func New(resolver modelx.DNSResolver) *Resolver { + return &Resolver{resolver: resolver} +} + +// LookupAddr returns the name of the provided IP address +func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) { + return r.resolver.LookupAddr(ctx, addr) +} + +// LookupCNAME returns the canonical name of a host +func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) { + return r.resolver.LookupCNAME(ctx, host) +} + +type fakeTransport struct{} + +func (*fakeTransport) RoundTrip( + ctx context.Context, query []byte, +) (reply []byte, err error) { + return nil, errors.New("not implemented") +} + +func (*fakeTransport) RequiresPadding() bool { + return false +} + +func (*fakeTransport) Network() string { + return "system" +} + +func (*fakeTransport) Address() string { + return "" +} + +// Transport returns the transport being used +func (r *Resolver) Transport() modelx.DNSRoundTripper { + return &fakeTransport{} +} + +// LookupHost returns the IP addresses of a host +func (r *Resolver) LookupHost(ctx context.Context, hostname string) ([]string, error) { + return r.resolver.LookupHost(ctx, hostname) +} + +// LookupMX returns the MX records of a specific name +func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + return r.resolver.LookupMX(ctx, name) +} + +// LookupNS returns the NS records of a specific name +func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + return r.resolver.LookupNS(ctx, name) +} diff --git a/netx/internal/resolver/systemresolver/systemresolver_test.go b/netx/internal/resolver/systemresolver/systemresolver_test.go new file mode 100644 index 00000000..88ef1ef7 --- /dev/null +++ b/netx/internal/resolver/systemresolver/systemresolver_test.go @@ -0,0 +1,99 @@ +package systemresolver + +import ( + "context" + "net" + "testing" + + "github.com/ooni/probe-engine/netx/modelx" +) + +type queryableTransport interface { + Network() string + Address() string + RequiresPadding() bool +} + +type queryableResolver interface { + Transport() modelx.DNSRoundTripper +} + +func TestCanQuery(t *testing.T) { + var client modelx.DNSResolver = New(new(net.Resolver)) + transport := client.(queryableResolver).Transport() + reply, err := transport.RoundTrip(context.Background(), nil) + if err == nil { + t.Fatal("expected an error here") + } + if err.Error() != "not implemented" { + t.Fatal("not the error we expected") + } + if reply != nil { + t.Fatal("expected nil reply here") + } + queryableTransport := transport.(queryableTransport) + if queryableTransport.Address() != "" { + t.Fatal("invalid address") + } + if queryableTransport.Network() != "system" { + t.Fatal("invalid network") + } + if queryableTransport.RequiresPadding() != false { + t.Fatal("we should require padding here") + } +} + +func TestLookupAddr(t *testing.T) { + client := New(new(net.Resolver)) + names, err := client.LookupAddr(context.Background(), "8.8.8.8") + if err != nil { + t.Fatal(err) + } + if names == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupCNAME(t *testing.T) { + client := New(new(net.Resolver)) + name, err := client.LookupCNAME(context.Background(), "www.ooni.io") + if err != nil { + t.Fatal(err) + } + if name == "" { + t.Fatal("expected non-empty result here") + } +} + +func TestLookupHost(t *testing.T) { + client := New(new(net.Resolver)) + addrs, err := client.LookupHost(context.Background(), "www.google.com") + if err != nil { + t.Fatal(err) + } + if addrs == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupMX(t *testing.T) { + client := New(new(net.Resolver)) + records, err := client.LookupMX(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expected non-nil result here") + } +} + +func TestLookupNS(t *testing.T) { + client := New(new(net.Resolver)) + records, err := client.LookupNS(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if records == nil { + t.Fatal("expected non-nil result here") + } +} diff --git a/netx/internal/transactionid/transactionid.go b/netx/internal/transactionid/transactionid.go new file mode 100644 index 00000000..b4c1e0e5 --- /dev/null +++ b/netx/internal/transactionid/transactionid.go @@ -0,0 +1,24 @@ +// Package transactionid contains code to share the transactionID +package transactionid + +import ( + "context" + "sync/atomic" +) + +type contextkey struct{} + +var id int64 + +// WithTransactionID returns a copy of ctx with TransactionID +func WithTransactionID(ctx context.Context) context.Context { + return context.WithValue( + ctx, contextkey{}, atomic.AddInt64(&id, 1), + ) +} + +// ContextTransactionID returns the TransactionID of the context, or zero +func ContextTransactionID(ctx context.Context) int64 { + id, _ := ctx.Value(contextkey{}).(int64) + return id +} diff --git a/netx/internal/transactionid/transactionid_test.go b/netx/internal/transactionid/transactionid_test.go new file mode 100644 index 00000000..577d888b --- /dev/null +++ b/netx/internal/transactionid/transactionid_test.go @@ -0,0 +1,24 @@ +package transactionid + +import ( + "context" + "testing" +) + +func TestIntegration(t *testing.T) { + ctx := context.Background() + id := ContextTransactionID(ctx) + if id != 0 { + t.Fatal("unexpected ID for empty context") + } + ctx = WithTransactionID(ctx) + id = ContextTransactionID(ctx) + if id != 1 { + t.Fatal("expected ID equal to 1") + } + ctx = WithTransactionID(ctx) + id = ContextTransactionID(ctx) + if id != 2 { + t.Fatal("expected ID equal to 2") + } +} diff --git a/netx/modelx/modelx.go b/netx/modelx/modelx.go new file mode 100644 index 00000000..c291fdca --- /dev/null +++ b/netx/modelx/modelx.go @@ -0,0 +1,807 @@ +// Package modelx contains the data modelx. +package modelx + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "math" + "net" + "net/http" + "net/url" + "time" + + "github.com/miekg/dns" +) + +// Measurement contains zero or more events. Do not assume that at any +// time a Measurement will only contain a single event. When a Measurement +// contains an event, the corresponding pointer is non nil. +// +// All events contain a time measurement, `DurationSinceBeginning`, that +// uses a monotonic clock and is relative to a preconfigured "zero". +type Measurement struct { + // DNS events + // + // These are all identifed by a DialID. A ResolveEvent optionally has + // a reference to the TransactionID that started the dial, if any. + ResolveStart *ResolveStartEvent `json:",omitempty"` + DNSQuery *DNSQueryEvent `json:",omitempty"` + DNSReply *DNSReplyEvent `json:",omitempty"` + ResolveDone *ResolveDoneEvent `json:",omitempty"` + + // Syscalls + // + // These are all identified by a ConnID. A ConnectEvent has a reference + // to the DialID that caused this connection to be attempted. + // + // Because they are syscalls, we don't split them in start/done pairs + // but we record the amount of time in which we were blocked. + Connect *ConnectEvent `json:",omitempty"` + Read *ReadEvent `json:",omitempty"` + Write *WriteEvent `json:",omitempty"` + Close *CloseEvent `json:",omitempty"` + + // TLS events + // + // Identified by either ConnID or TransactionID. In the former case + // the TLS handshake is managed by net code, in the latter case it is + // instead managed by Golang's HTTP engine. It should not happen to + // have both ConnID and TransactionID different from zero. + TLSHandshakeStart *TLSHandshakeStartEvent `json:",omitempty"` + TLSHandshakeDone *TLSHandshakeDoneEvent `json:",omitempty"` + + // HTTP roundtrip events + // + // A round trip starts when we need a connection to send a request + // and ends when we've got the response headers or an error. + // + // The identifer here is TransactionID, where the transaction is + // like the round trip except that it terminates when we've finished + // reading the whole response body. + HTTPRoundTripStart *HTTPRoundTripStartEvent `json:",omitempty"` + HTTPConnectionReady *HTTPConnectionReadyEvent `json:",omitempty"` + HTTPRequestHeader *HTTPRequestHeaderEvent `json:",omitempty"` + HTTPRequestHeadersDone *HTTPRequestHeadersDoneEvent `json:",omitempty"` + HTTPRequestDone *HTTPRequestDoneEvent `json:",omitempty"` + HTTPResponseStart *HTTPResponseStartEvent `json:",omitempty"` + HTTPRoundTripDone *HTTPRoundTripDoneEvent `json:",omitempty"` + + // HTTP body events + // + // They are identified by the TransactionID. You are not going to see + // these events if you don't fully read response bodies. But that's + // something you are supposed to do, so you should be fine. + HTTPResponseBodyPart *HTTPResponseBodyPartEvent `json:",omitempty"` + HTTPResponseDone *HTTPResponseDoneEvent `json:",omitempty"` + + // Extension events. + // + // The purpose of these events is to give us some flexibility to + // experiment with message formats before blessing something as + // part of the official API of the library. The intent however is + // to avoid keeping something as an extension for a long time. + Extension *ExtensionEvent `json:",omitempty"` +} + +// ErrWrapper is our error wrapper for Go errors. The key objective of +// this structure is to properly set Failure, which is also returned by +// the Error() method, so be one of the OONI defined strings. +type ErrWrapper struct { + // ConnID is the connection ID, or zero if not known. + ConnID int64 + + // DialID is the dial ID, or zero if not known. + DialID int64 + + // Failure is the OONI failure string. The failure strings are + // loosely backward compatible with Measurement Kit. + // + // Supported failure strings + // + // - `connection_refused`: ECONNREFUSED + // - `connection_reset`: ECONNRESET + // - `dns_bogon_error`: detected bogon in DNS reply + // - `dns_nxdomain_error`: NXDOMAIN in DNS reply + // - `eof_error`: unexpected EOF on connection + // - `generic_timeout_error`: some timer has expired + // - `ssl_invalid_hostname`: certificate not valid for SNI + // - `ssl_unknown_autority`: cannot find CA validating certificate + // - `ssl_invalid_certificate`: e.g. certificate expried + // - `unknown_failure ...`: any other error + Failure string + + // Operation is the operation that failed. If possible, it + // SHOULD be a _major_ operation. Major operations are: + // + // - `resolve`: resolving a domain name failed + // - `connect`: connecting to an IP failed + // - `tls_handshake`: TLS handshaking failed + // - `http_round_trip`: other errors during round trip + // + // Because a network connection doesn't necessarily know + // what is the current major operation we also have the + // following _minor_ operations: + // + // - `close`: CLOSE failed + // - `read`: READ failed + // - `write`: WRITE failed + // + // If an ErrWrapper referring to a major operation is wrapping + // another ErrWrapper and such ErrWrapper already refers to + // a major operation, then the new ErrWrapper should use the + // child ErrWrapper major operation. Otherwise, it should use + // its own major operation. This way, the topmost wrapper is + // supposed to refer to the major operation that failed. + Operation string + + // TransactionID is the transaction ID, or zero if not known. + TransactionID int64 + + // WrappedErr is the error that we're wrapping. + WrappedErr error +} + +// Error returns a description of the error that occurred. +func (e *ErrWrapper) Error() string { + return e.Failure +} + +// Unwrap allows to access the underlying error +func (e *ErrWrapper) Unwrap() error { + return e.WrappedErr +} + +// CloseEvent is emitted when the CLOSE syscall returns. +type CloseEvent struct { + // ConnID is the identifier of this connection. + ConnID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the error returned by CLOSE. + Error error + + // SyscallDuration is the number of nanoseconds we were + // blocked waiting for the syscall to return. + SyscallDuration time.Duration +} + +// ConnectEvent is emitted when the CONNECT syscall returns. +type ConnectEvent struct { + // ConnID is the identifier of this connection. + ConnID int64 + + // DialID is the identifier of the dial operation as + // part of which we called CONNECT. + DialID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the error returned by CONNECT. + Error error + + // Network is the network we're dialing for, e.g. "tcp" + Network string + + // RemoteAddress is the remote IP address we're dialing for + RemoteAddress string + + // SyscallDuration is the number of nanoseconds we were + // blocked waiting for the syscall to return. + SyscallDuration time.Duration + + // TransactionID is the ID of the HTTP transaction that caused the + // current dial to run, or zero if there's no such transaction. + TransactionID int64 `json:",omitempty"` +} + +// DNSQueryEvent is emitted when we send a DNS query. +type DNSQueryEvent struct { + // Data is the raw data we're sending to the server. + Data []byte + + // DialID is the identifier of the dial operation as + // part of which we're sending this query. + DialID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Msg is the parsed message we're sending to the server. + Msg *dns.Msg `json:"-"` +} + +// DNSReplyEvent is emitted when we receive byte that are +// successfully parsed into a DNS reply. +type DNSReplyEvent struct { + // Data is the raw data we've received and parsed. + Data []byte + + // DialID is the identifier of the dial operation as + // part of which we've received this query. + DialID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Msg is the received parsed message. + Msg *dns.Msg `json:"-"` +} + +// ExtensionEvent is emitted by a netx extension. +type ExtensionEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Key is the unique identifier of the event. A good rule of + // thumb is to use `${packageName}.${messageType}`. + Key string + + // Severity of the emitted message ("WARN", "INFO", "DEBUG") + Severity string + + // TransactionID is the identifier of this transaction, provided + // that we have an active one, otherwise is zero. + TransactionID int64 + + // Value is the extension dependent message. This message + // has the only requirement of being JSON serializable. + Value interface{} +} + +// HTTPRoundTripStartEvent is emitted when the HTTP transport +// starts the HTTP "round trip". That is, when the transport +// receives from the HTTP client a request to sent. The round +// trip terminates when we receive headers. What we call the +// "transaction" here starts with this event and does not finish +// until we have also finished receiving the response body. +type HTTPRoundTripStartEvent struct { + // DialID is the identifier of the dial operation that + // caused this round trip to start. Typically, this occures + // when doing DoH. If zero, means that this round trip has + // not been started by any dial operation. + DialID int64 `json:",omitempty"` + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Method is the request method + Method string + + // TransactionID is the identifier of this transaction + TransactionID int64 + + // URL is the request URL + URL string +} + +// HTTPConnectionReadyEvent is emitted when the HTTP transport has got +// a connection which is ready for sending the request. +type HTTPConnectionReadyEvent struct { + // ConnID is the identifier of the connection that is ready. Knowing + // this ID allows you to bind HTTP events to net events. + ConnID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +// HTTPRequestHeaderEvent is emitted when we have written a header, +// where written typically means just "buffered". +type HTTPRequestHeaderEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Key is the header key + Key string + + // TransactionID is the identifier of this transaction + TransactionID int64 + + // Value is the value/values of this header. + Value []string +} + +// HTTPRequestHeadersDoneEvent is emitted when we have written, or more +// correctly, "buffered" all headers. +type HTTPRequestHeadersDoneEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Headers contain the original request headers. This is included + // here to make this event actionable without needing to join it with + // other events, i.e., to simplify logging. + Headers http.Header + + // Method is the original request method. This is here + // for the same reason of Headers. + Method string + + // TransactionID is the identifier of this transaction + TransactionID int64 + + // URL is the original request URL. This is here + // for the same reason of Headers. We use an object + // rather than a string, because here you want to + // use specific subfields directly for logging. + URL *url.URL +} + +// HTTPRequestDoneEvent is emitted when we have sent the request +// body or there has been any failure in sending the request. +type HTTPRequestDoneEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is non nil if we could not write the request headers or + // some specific part of the body. When this step of writing + // the request fails, of course the whole transaction will fail + // as well. This error however tells you that the issue was + // when sending the request, not when receiving the response. + Error error + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +// HTTPResponseStartEvent is emitted when we receive the byte from +// the response on the wire. +type HTTPResponseStartEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +const defaultBodySnapSize int64 = 1 << 20 + +// ComputeBodySnapSize computes the body snap size. If snapSize is negative +// we return MaxInt64. If it's zero we return the default snap size. Otherwise +// the value of snapSize is returned. +func ComputeBodySnapSize(snapSize int64) int64 { + if snapSize < 0 { + snapSize = math.MaxInt64 + } else if snapSize == 0 { + snapSize = defaultBodySnapSize + } + return snapSize +} + +// HTTPRoundTripDoneEvent is emitted at the end of the round trip. Either +// we have an error, or a valid HTTP response. An error could be caused +// either by not being able to send the request or not being able to receive +// the response. Note that here errors are network/TLS/dialing errors or +// protocol violation errors. No status code will cause errors here. +type HTTPRoundTripDoneEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the overall result of the round trip. If non-nil, checking + // also the result of HTTPResponseDone helps to disambiguate whether the + // error was in sending the request or receiving the response. + Error error + + // RequestBodySnap contains a snap of the request body. We'll + // not read more than SnapSize bytes of the body. Because typically + // you control the request bodies that you send, perhaps think + // about saving them using other means. + RequestBodySnap []byte + + // RequestHeaders contain the original request headers. This is + // included here to make this event actionable without needing to + // join it with other events, as it's too important. + RequestHeaders http.Header + + // RequestMethod is the original request method. This is here + // for the same reason of RequestHeaders. + RequestMethod string + + // RequestURL is the original request URL. This is here + // for the same reason of RequestHeaders. + RequestURL string + + // ResponseBodySnap is like RequestBodySnap but for the response. You + // can still save the whole body by just reading it, if this + // is something that you need to do. We're using the snaps here + // mainly to log small stuff like DoH and redirects. + ResponseBodySnap []byte + + // ResponseHeaders contains the response headers if error is nil. + ResponseHeaders http.Header + + // ResponseProto contains the response protocol + ResponseProto string + + // ResponseStatusCode contains the HTTP status code if error is nil. + ResponseStatusCode int64 + + // MaxBodySnapSize is the maximum size of the bodies snapshot. + MaxBodySnapSize int64 + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +// HTTPResponseBodyPartEvent is emitted after we have received +// a part of the response body, or an error reading it. Note that +// bytes read here does not necessarily match bytes returned by +// ReadEvent because of (1) transparent gzip decompression by Go, +// (2) HTTP overhead (headers and chunked body), (3) TLS. This +// is the reason why we also want to record the error here rather +// than just recording the error in ReadEvent. +// +// Note that you are not going to see this event if you do not +// drain the response body, which you're supposed to do, tho. +type HTTPResponseBodyPartEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error indicates whether we could not read a part of the body + Error error + + // Data is a reference to the body we've just read. + Data []byte + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +// HTTPResponseDoneEvent is emitted after we have received the body, +// when the response body is being closed. +// +// Note that you are not going to see this event if you do not +// drain the response body, which you're supposed to do, tho. +type HTTPResponseDoneEvent struct { + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // TransactionID is the identifier of this transaction + TransactionID int64 +} + +// ReadEvent is emitted when the READ/RECV syscall returns. +type ReadEvent struct { + // ConnID is the identifier of this connection. + ConnID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the error returned by READ/RECV. + Error error + + // NumBytes is the number of bytes received, which may in + // principle also be nonzero on error. + NumBytes int64 + + // SyscallDuration is the number of nanoseconds we were + // blocked waiting for the syscall to return. + SyscallDuration time.Duration +} + +// ResolveStartEvent is emitted when we start resolving a domain name. +type ResolveStartEvent struct { + // DialID is the identifier of the dial operation as + // part of which we're resolving this domain. + DialID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Hostname is the domain name to resolve. + Hostname string + + // TransactionID is the ID of the HTTP transaction that caused the + // current dial to run, or zero if there's no such transaction. + TransactionID int64 `json:",omitempty"` + + // TransportNetwork is the network used by the DNS transport, which + // can be one of "doh", "dot", "tcp", "udp", or "system". + TransportNetwork string + + // TransportAddress is the address used by the DNS transport, which + // is of course relative to the TransportNetwork. + TransportAddress string +} + +// ResolveDoneEvent is emitted when we know the IP addresses of a +// specific domain name, or the resolution failed. +type ResolveDoneEvent struct { + // Addresses is the list of returned addresses (empty on error). + Addresses []string + + // ContainsBogons indicates whether Addresses contains one + // or more IP addresses that classify as bogons. + ContainsBogons bool + + // DialID is the identifier of the dial operation as + // part of which we're resolving this domain. + DialID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the result of the dial operation. + Error error + + // Hostname is the domain name to resolve. + Hostname string + + // TransactionID is the ID of the HTTP transaction that caused the + // current dial to run, or zero if there's no such transaction. + TransactionID int64 `json:",omitempty"` + + // TransportNetwork is the network used by the DNS transport, which + // can be one of "doh", "dot", "tcp", "udp", or "system". + TransportNetwork string + + // TransportAddress is the address used by the DNS transport, which + // is of course relative to the TransportNetwork. + TransportAddress string +} + +// X509Certificate is an x.509 certificate. +type X509Certificate struct { + // Data contains the certificate bytes in DER format. + Data []byte +} + +// TLSConnectionState contains the TLS connection state. +type TLSConnectionState struct { + CipherSuite uint16 + NegotiatedProtocol string + PeerCertificates []X509Certificate + Version uint16 +} + +// NewTLSConnectionState creates a new TLSConnectionState. +func NewTLSConnectionState(s tls.ConnectionState) TLSConnectionState { + return TLSConnectionState{ + CipherSuite: s.CipherSuite, + NegotiatedProtocol: s.NegotiatedProtocol, + PeerCertificates: SimplifyCerts(s.PeerCertificates), + Version: s.Version, + } +} + +// SimplifyCerts simplifies a certificate chain for archival +func SimplifyCerts(in []*x509.Certificate) (out []X509Certificate) { + for _, cert := range in { + out = append(out, X509Certificate{ + Data: cert.Raw, + }) + } + return +} + +// TLSHandshakeStartEvent is emitted when the TLS handshake starts. +type TLSHandshakeStartEvent struct { + // ConnID is the ID of the connection that started the TLS + // handshake, or zero if we don't know it. Typically, it is + // zero for connections managed by the HTTP transport, for + // which we know instead the TransactionID. + ConnID int64 `json:",omitempty"` + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // SNI is the SNI used when we force a specific SNI. + SNI string + + // TransactionID is the ID of the transaction that started + // this TLS handshake, or zero if we don't know it. Typically, + // it is zero for explicit dials, and it's nonzero instead + // when a connection is managed by HTTP code. + TransactionID int64 `json:",omitempty"` +} + +// TLSHandshakeDoneEvent is emitted when conn.Handshake returns. +type TLSHandshakeDoneEvent struct { + // ConnectionState is the TLS connection state. Depending on the + // error type, some fields may have little meaning. + ConnectionState TLSConnectionState + + // ConnID is the ID of the connection that started the TLS + // handshake, or zero if we don't know it. Typically, it is + // zero for connections managed by the HTTP transport, for + // which we know instead the TransactionID. + ConnID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the result of the TLS handshake. + Error error + + // TransactionID is the ID of the transaction that started + // this TLS handshake, or zero if we don't know it. Typically, + // it is zero for explicit dials, and it's nonzero instead + // when a connection is managed by HTTP code. + TransactionID int64 +} + +// WriteEvent is emitted when the WRITE/SEND syscall returns. +type WriteEvent struct { + // ConnID is the identifier of this connection. + ConnID int64 + + // DurationSinceBeginning is the number of nanoseconds since + // the time configured as the "zero" time. + DurationSinceBeginning time.Duration + + // Error is the error returned by WRITE/SEND. + Error error + + // NumBytes is the number of bytes sent, which may in + // principle also be nonzero on error. + NumBytes int64 + + // SyscallDuration is the number of nanoseconds we were + // blocked waiting for the syscall to return. + SyscallDuration time.Duration +} + +// Handler handles measurement events. +type Handler interface { + // OnMeasurement is called when an event occurs. There will be no + // events after the code that is using the modified Dialer, Transport, + // or Client is returned. OnMeasurement may be called by background + // goroutines and OnMeasurement calls may happen concurrently. + OnMeasurement(Measurement) +} + +// DNSResolver is a DNS resolver. The *net.Resolver used by Go implements +// this interface, but other implementations are possible. +type DNSResolver interface { + // LookupAddr performs a reverse lookup of an address. + LookupAddr(ctx context.Context, addr string) (names []string, err error) + + // LookupCNAME returns the canonical name of a given host. + LookupCNAME(ctx context.Context, host string) (cname string, err error) + + // LookupHost resolves a hostname to a list of IP addresses. + LookupHost(ctx context.Context, hostname string) (addrs []string, err error) + + // LookupMX resolves the DNS MX records for a given domain name. + LookupMX(ctx context.Context, name string) ([]*net.MX, error) + + // LookupNS resolves the DNS NS records for a given domain name. + LookupNS(ctx context.Context, name string) ([]*net.NS, error) +} + +// DNSRoundTripper represents an abstract DNS transport. +type DNSRoundTripper interface { + // RoundTrip sends a DNS query and receives the reply. + RoundTrip(ctx context.Context, query []byte) (reply []byte, err error) + + // RequiresPadding return true for DoH and DoT according to RFC8467 + RequiresPadding() bool +} + +// Dialer is a dialer for network connections. +type Dialer interface { + // Dial dials a new connection + Dial(network, address string) (net.Conn, error) + + // DialContext is like Dial but with context + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// TLSDialer is a dialer for TLS connections. +type TLSDialer interface { + // DialTLS dials a new TLS connection + DialTLS(network, address string) (net.Conn, error) + + // DialTLSContext is like DialTLS but with context + DialTLSContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// ErrDNSBogon indicates that we found a bogon address. This is the +// correct value with which to initialize MeasurementRoot.ErrDNSBogon +// to tell this library to return an error when a bogon is found. +var ErrDNSBogon = errors.New("dns: detected bogon address") + +// MeasurementRoot is the measurement root. +// +// If you attach this to a context, we'll use it rather than using +// the beginning and hndler configured with resolvers, dialers, HTTP +// clients, and HTTP transports. By attaching a measurement root to +// a context, you can naturally split events by HTTP round trip. +type MeasurementRoot struct { + // Beginning is the "zero" used to compute the elapsed time. + Beginning time.Time + + // ErrDNSBogon is the kind of error that you would like this + // library to return when a bogon IP address is found. The + // default value, nil, causes this library to consider bogons + // as valid IP addresses. Setting this field to non-nil + // error causes the library instead fail when a bogon has + // been detected. The best value with which to initialize this + // field is the ErrDNSBogon variable in this package. + ErrDNSBogon error + + // Handler is the handler that will handle events. + Handler Handler + + // MaxBodySnapSize is the maximum size after which we'll stop + // reading request and response bodies. They will of course + // be fully transmitted, but we'll save only MaxBodySnapSize + // bytes as part of the event stream. If this value is negative, + // we use math.MaxInt64. If the value is zero, we use a + // reasonable large value. Otherwise, we'll use this value. + MaxBodySnapSize int64 + + // LookupHost allows to override the host lookup for all the request + // and dials that use this measurement root. + LookupHost func(ctx context.Context, hostname string) ([]string, error) +} + +type measurementRootContextKey struct{} + +type dummyHandler struct{} + +func (*dummyHandler) OnMeasurement(Measurement) {} + +// ContextMeasurementRoot returns the MeasurementRoot configured in the +// provided context, or a nil pointer, if not set. +func ContextMeasurementRoot(ctx context.Context) *MeasurementRoot { + root, _ := ctx.Value(measurementRootContextKey{}).(*MeasurementRoot) + return root +} + +// ContextMeasurementRootOrDefault returns the MeasurementRoot configured in +// the provided context, or a working, dummy, MeasurementRoot otherwise. +func ContextMeasurementRootOrDefault(ctx context.Context) *MeasurementRoot { + root := ContextMeasurementRoot(ctx) + if root == nil { + root = &MeasurementRoot{ + Beginning: time.Now(), + Handler: &dummyHandler{}, + } + } + return root +} + +// WithMeasurementRoot returns a copy of the context with the +// configured MeasurementRoot set. Panics if the provided root +// is a nil pointer, like httptrace.WithClientTrace. +// +// Merging more than one root is not supported. Setting again +// the root is just going to replace the original root. +func WithMeasurementRoot( + ctx context.Context, root *MeasurementRoot, +) context.Context { + if root == nil { + panic("nil measurement root") + } + return context.WithValue( + ctx, measurementRootContextKey{}, root, + ) +} diff --git a/netx/modelx/modelx_test.go b/netx/modelx/modelx_test.go new file mode 100644 index 00000000..00d3fdf5 --- /dev/null +++ b/netx/modelx/modelx_test.go @@ -0,0 +1,84 @@ +package modelx + +import ( + "context" + "crypto/tls" + "errors" + "math" + "testing" + "time" +) + +func TestNewTLSConnectionState(t *testing.T) { + conn, err := tls.Dial("tcp", "www.google.com:443", nil) + if err != nil { + t.Fatal(err) + } + state := NewTLSConnectionState(conn.ConnectionState()) + if len(state.PeerCertificates) < 1 { + t.Fatal("too few certificates") + } + if state.Version < tls.VersionSSL30 || state.Version > 0x0304 /*tls.VersionTLS13*/ { + t.Fatal("unexpected TLS version") + } +} + +func TestMeasurementRoot(t *testing.T) { + ctx := context.Background() + if ContextMeasurementRoot(ctx) != nil { + t.Fatal("unexpected value for ContextMeasurementRoot") + } + if ContextMeasurementRootOrDefault(ctx) == nil { + t.Fatal("unexpected value ContextMeasurementRootOrDefault") + } + handler := &dummyHandler{} + root := &MeasurementRoot{ + Handler: handler, + Beginning: time.Time{}, + } + ctx = WithMeasurementRoot(ctx, root) + v := ContextMeasurementRoot(ctx) + if v != root { + t.Fatal("unexpected ContextMeasurementRoot value") + } + v = ContextMeasurementRootOrDefault(ctx) + if v != root { + t.Fatal("unexpected ContextMeasurementRoot value") + } +} + +func TestMeasurementRootWithMeasurementRootPanic(t *testing.T) { + defer func() { + if recover() == nil { + t.Fatal("expected panic") + } + }() + ctx := context.Background() + ctx = WithMeasurementRoot(ctx, nil) +} + +func TestErrWrapperPublicAPI(t *testing.T) { + child := errors.New("mocked error") + wrapper := &ErrWrapper{ + Failure: "moobar", + WrappedErr: child, + } + if wrapper.Error() != "moobar" { + t.Fatal("The Error() method is misbehaving") + } + if wrapper.Unwrap() != child { + t.Fatal("The Unwrap() method is misbehaving") + } +} + +func TestUnitComputeBodySnapSize(t *testing.T) { + if ComputeBodySnapSize(-1) != math.MaxInt64 { + t.Fatal("unexpected result") + } + if ComputeBodySnapSize(0) != defaultBodySnapSize { + t.Fatal("unexpected result") + } + if ComputeBodySnapSize(127) != 127 { + t.Fatal("unexpected result") + } +} diff --git a/netx/netx.go b/netx/netx.go new file mode 100644 index 00000000..e2ea8059 --- /dev/null +++ b/netx/netx.go @@ -0,0 +1,147 @@ +// Package netx contains OONI's net extensions. +// +// This package provides a replacement for net.Dialer that can Dial, +// DialContext, and DialTLS. During its lifecycle this modified Dialer +// will emit network level events on a channel. +package netx + +import ( + "context" + "net" + "time" + + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal" + "github.com/ooni/probe-engine/netx/modelx" +) + +// Dialer performs measurements while dialing. +type Dialer struct { + dialer *internal.Dialer +} + +// NewDialer returns a new Dialer instance. +func NewDialer(handler modelx.Handler) *Dialer { + return &Dialer{ + dialer: internal.NewDialer(time.Now(), handler), + } +} + +// NewDialerWithoutHandler returns a new Dialer instance. +func NewDialerWithoutHandler() *Dialer { + return &Dialer{ + dialer: internal.NewDialer(time.Now(), handlers.NoHandler), + } +} + +// ConfigureDNS configures the DNS resolver. The network argument +// selects the type of resolver. The address argument indicates the +// resolver address and depends on the network. +// +// This functionality is not goroutine safe. You should only change +// the DNS settings before starting to use the Dialer. +// +// The following is a list of all the possible network values: +// +// - "": behaves exactly like "system" +// +// - "system": this indicates that Go should use the system resolver +// and prevents us from seeing any DNS packet. The value of the +// address parameter is ignored when using "system". If you do +// not ConfigureDNS, this is the default resolver used. +// +// - "udp": indicates that we should send queries using UDP. In this +// case the address is a host, port UDP endpoint. +// +// - "tcp": like "udp" but we use TCP. +// +// - "dot": we use DNS over TLS (DoT). In this case the address is +// the domain name of the DoT server. +// +// - "doh": we use DNS over HTTPS (DoH). In this case the address is +// the URL of the DoH server. +// +// For example: +// +// d.ConfigureDNS("system", "") +// d.ConfigureDNS("udp", "8.8.8.8:53") +// d.ConfigureDNS("tcp", "8.8.8.8:53") +// d.ConfigureDNS("dot", "dns.quad9.net") +// d.ConfigureDNS("doh", "https://cloudflare-dns.com/dns-query") +func (d *Dialer) ConfigureDNS(network, address string) error { + return d.dialer.ConfigureDNS(network, address) +} + +// SetResolver is a more flexible way of configuring a resolver +// that should perhaps be used instead of ConfigureDNS. +func (d *Dialer) SetResolver(r modelx.DNSResolver) { + d.dialer.SetResolver(r) +} + +// Dial creates a TCP or UDP connection. See net.Dial docs. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.dialer.Dial(network, address) +} + +// DialContext is like Dial but the context allows to interrupt a +// pending connection attempt at any time. +func (d *Dialer) DialContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return d.dialer.DialContext(ctx, network, address) +} + +// DialTLS is like Dial, but creates TLS connections. +func (d *Dialer) DialTLS(network, address string) (conn net.Conn, err error) { + return d.DialTLSContext(context.Background(), network, address) +} + +// DialTLSContext is like DialTLS, but with context +func (d *Dialer) DialTLSContext( + ctx context.Context, network, address string, +) (net.Conn, error) { + return d.dialer.DialTLSContext(ctx, network, address) +} + +// NewResolver returns a new resolver using the same handler of this +// Dialer. The arguments have the same meaning of ConfigureDNS. The +// returned resolver will not be used by this Dialer, and will not use +// this Dialer as well. The fact that it's a method of Dialer rather +// than an independent method is an historical oddity. There is also a +// standalone NewResolver factory and you should probably use it. +func (d *Dialer) NewResolver(network, address string) (modelx.DNSResolver, error) { + return internal.NewResolver(d.dialer.Beginning, d.dialer.Handler, network, address) +} + +// NewResolver is a standalone Dialer.NewResolver +func NewResolver(handler modelx.Handler, network, address string) (modelx.DNSResolver, error) { + return internal.NewResolver(time.Now(), handler, network, address) +} + +// NewResolverWithoutHandler creates a standalone Resolver +func NewResolverWithoutHandler(network, address string) (modelx.DNSResolver, error) { + return internal.NewResolver(time.Now(), handlers.NoHandler, network, address) +} + +// SetCABundle configures the dialer to use a specific CA bundle. This +// function is not goroutine safe. Make sure you call it before starting +// to use this specific dialer. +func (d *Dialer) SetCABundle(path string) error { + return d.dialer.SetCABundle(path) +} + +// ForceSpecificSNI forces using a specific SNI. +func (d *Dialer) ForceSpecificSNI(sni string) error { + return d.dialer.ForceSpecificSNI(sni) +} + +// ForceSkipVerify forces to skip certificate verification +func (d *Dialer) ForceSkipVerify() error { + return d.dialer.ForceSkipVerify() +} + +// ChainResolvers chains a primary and a secondary resolver such that +// we can fallback to the secondary if primary is broken. +func ChainResolvers(primary, secondary modelx.DNSResolver) modelx.DNSResolver { + return internal.ChainResolvers(primary, secondary) +} diff --git a/netx/netx_test.go b/netx/netx_test.go new file mode 100644 index 00000000..85500591 --- /dev/null +++ b/netx/netx_test.go @@ -0,0 +1,155 @@ +package netx_test + +import ( + "context" + "crypto/x509" + "errors" + "net" + "testing" + + "github.com/ooni/probe-engine/netx" + "github.com/ooni/probe-engine/netx/handlers" + "github.com/ooni/probe-engine/netx/internal/resolver/brokenresolver" +) + +func TestIntegrationDialer(t *testing.T) { + dialer := netx.NewDialerWithoutHandler() + err := dialer.ConfigureDNS("udp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + conn.Close() + conn, err = dialer.DialContext( + context.Background(), "tcp", "www.google.com:80", + ) + if err != nil { + t.Fatal(err) + } + conn.Close() + conn, err = dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationDialerWithSetResolver(t *testing.T) { + dialer := netx.NewDialer(handlers.NoHandler) + dialer.SetResolver(new(net.Resolver)) + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + conn.Close() + conn, err = dialer.DialContext( + context.Background(), "tcp", "www.google.com:80", + ) + if err != nil { + t.Fatal(err) + } + conn.Close() + conn, err = dialer.DialTLS("tcp", "www.google.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestIntegrationResolver(t *testing.T) { + dialer := netx.NewDialer(handlers.NoHandler) + resolver, err := dialer.NewResolver("tcp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + addrs, err := resolver.LookupHost(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if len(addrs) < 1 { + t.Fatal("No addresses returned") + } +} + +func TestIntegrationStandaloneResolver(t *testing.T) { + resolver, err := netx.NewResolver(handlers.NoHandler, "tcp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + addrs, err := resolver.LookupHost(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if len(addrs) < 1 { + t.Fatal("No addresses returned") + } +} + +func TestIntegrationStandaloneResolverWithoutHandler(t *testing.T) { + resolver, err := netx.NewResolverWithoutHandler("tcp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + addrs, err := resolver.LookupHost(context.Background(), "ooni.io") + if err != nil { + t.Fatal(err) + } + if len(addrs) < 1 { + t.Fatal("No addresses returned") + } +} + +func TestSetCABundle(t *testing.T) { + dialer := netx.NewDialer(handlers.NoHandler) + err := dialer.SetCABundle("testdata/cacert.pem") + if err != nil { + t.Fatal(err) + } +} + +func TestForceSpecificSNI(t *testing.T) { + dialer := netx.NewDialer(handlers.NoHandler) + err := dialer.ForceSpecificSNI("www.facebook.com") + if err != nil { + t.Fatal(err) + } + conn, err := dialer.DialTLS("tcp", "www.google.com:443") + if err == nil { + t.Fatal("expected an error here") + } + var target x509.HostnameError + if errors.As(err, &target) == false { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } +} + +func TestIntegrationDialTLSForceSkipVerify(t *testing.T) { + dialer := netx.NewDialer(handlers.NoHandler) + dialer.ForceSkipVerify() + conn, err := dialer.DialTLS("tcp", "self-signed.badssl.com:443") + if err != nil { + t.Fatal(err) + } + conn.Close() +} + +func TestChainResolvers(t *testing.T) { + fallback, err := netx.NewResolver(handlers.NoHandler, "udp", "1.1.1.1:53") + if err != nil { + t.Fatal(err) + } + dialer := netx.NewDialer(handlers.NoHandler) + resolver := netx.ChainResolvers(brokenresolver.New(), fallback) + dialer.SetResolver(resolver) + conn, err := dialer.Dial("tcp", "www.google.com:80") + if err != nil { + t.Fatal(err) + } + defer conn.Close() +} diff --git a/netx/testdata/cacert.pem b/netx/testdata/cacert.pem new file mode 100644 index 00000000..3a96ced8 --- /dev/null +++ b/netx/testdata/cacert.pem @@ -0,0 +1,54 @@ +# +# The following is a minimal, valid CA bundle. We do not include +# however the certificates required to validate www.google.com +# and we check in tests that we cannot connect to it and successfully +# complete a TLS handshake. This gives us confidence that we can +# actually override the CA bundle path. +# + +emSign ECC Root CA - C3 +======================= +-----BEGIN CERTIFICATE----- +MIICKzCCAbGgAwIBAgIKe3G2gla4EnycqDAKBggqhkjOPQQDAzBaMQswCQYDVQQGEwJVUzETMBEG +A1UECxMKZW1TaWduIFBLSTEUMBIGA1UEChMLZU11ZGhyYSBJbmMxIDAeBgNVBAMTF2VtU2lnbiBF +Q0MgUm9vdCBDQSAtIEMzMB4XDTE4MDIxODE4MzAwMFoXDTQzMDIxODE4MzAwMFowWjELMAkGA1UE +BhMCVVMxEzARBgNVBAsTCmVtU2lnbiBQS0kxFDASBgNVBAoTC2VNdWRocmEgSW5jMSAwHgYDVQQD +ExdlbVNpZ24gRUNDIFJvb3QgQ0EgLSBDMzB2MBAGByqGSM49AgEGBSuBBAAiA2IABP2lYa57JhAd +6bciMK4G9IGzsUJxlTm801Ljr6/58pc1kjZGDoeVjbk5Wum739D+yAdBPLtVb4OjavtisIGJAnB9 +SMVK4+kiVCJNk7tCDK93nCOmfddhEc5lx/h//vXyqaNCMEAwHQYDVR0OBBYEFPtaSNCAIEDyqOkA +B2kZd6fmw/TPMA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MAoGCCqGSM49BAMDA2gA +MGUCMQC02C8Cif22TGK6Q04ThHK1rt0c3ta13FaPWEBaLd4gTCKDypOofu4SQMfWh0/434UCMBwU +ZOR8loMRnLDRWmFLpg9J0wD8ofzkpf9/rdcw0Md3f76BB1UwUCAU9Vc4CqgxUQ== +-----END CERTIFICATE----- + +Hongkong Post Root CA 3 +======================= +-----BEGIN CERTIFICATE----- +MIIFzzCCA7egAwIBAgIUCBZfikyl7ADJk0DfxMauI7gcWqQwDQYJKoZIhvcNAQELBQAwbzELMAkG +A1UEBhMCSEsxEjAQBgNVBAgTCUhvbmcgS29uZzESMBAGA1UEBxMJSG9uZyBLb25nMRYwFAYDVQQK +Ew1Ib25na29uZyBQb3N0MSAwHgYDVQQDExdIb25na29uZyBQb3N0IFJvb3QgQ0EgMzAeFw0xNzA2 +MDMwMjI5NDZaFw00MjA2MDMwMjI5NDZaMG8xCzAJBgNVBAYTAkhLMRIwEAYDVQQIEwlIb25nIEtv +bmcxEjAQBgNVBAcTCUhvbmcgS29uZzEWMBQGA1UEChMNSG9uZ2tvbmcgUG9zdDEgMB4GA1UEAxMX +SG9uZ2tvbmcgUG9zdCBSb290IENBIDMwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCz +iNfqzg8gTr7m1gNt7ln8wlffKWihgw4+aMdoWJwcYEuJQwy51BWy7sFOdem1p+/l6TWZ5Mwc50tf +jTMwIDNT2aa71T4Tjukfh0mtUC1Qyhi+AViiE3CWu4mIVoBc+L0sPOFMV4i707mV78vH9toxdCim +5lSJ9UExyuUmGs2C4HDaOym71QP1mbpV9WTRYA6ziUm4ii8F0oRFKHyPaFASePwLtVPLwpgchKOe +sL4jpNrcyCse2m5FHomY2vkALgbpDDtw1VAliJnLzXNg99X/NWfFobxeq81KuEXryGgeDQ0URhLj +0mRiikKYvLTGCAj4/ahMZJx2Ab0vqWwzD9g/KLg8aQFChn5pwckGyuV6RmXpwtZQQS4/t+TtbNe/ +JgERohYpSms0BpDsE9K2+2p20jzt8NYt3eEV7KObLyzJPivkaTv/ciWxNoZbx39ri1UbSsUgYT2u +y1DhCDq+sI9jQVMwCFk8mB13umOResoQUGC/8Ne8lYePl8X+l2oBlKN8W4UdKjk60FSh0Tlxnf0h ++bV78OLgAo9uliQlLKAeLKjEiafv7ZkGL7YKTE/bosw3Gq9HhS2KX8Q0NEwA/RiTZxPRN+ZItIsG +xVd7GYYKecsAyVKvQv83j+GjHno9UKtjBucVtT+2RTeUN7F+8kjDf8V1/peNRY8apxpyKBpADwID +AQABo2MwYTAPBgNVHRMBAf8EBTADAQH/MA4GA1UdDwEB/wQEAwIBBjAfBgNVHSMEGDAWgBQXnc0e +i9Y5K3DTXNSguB+wAPzFYTAdBgNVHQ4EFgQUF53NHovWOStw01zUoLgfsAD8xWEwDQYJKoZIhvcN +AQELBQADggIBAFbVe27mIgHSQpsY1Q7XZiNc4/6gx5LS6ZStS6LG7BJ8dNVI0lkUmcDrudHr9Egw +W62nV3OZqdPlt9EuWSRY3GguLmLYauRwCy0gUCCkMpXRAJi70/33MvJJrsZ64Ee+bs7Lo3I6LWld +y8joRTnU+kLBEUx3XZL7av9YROXrgZ6voJmtvqkBZss4HTzfQx/0TW60uhdG/H39h4F5ag0zD/ov ++BS5gLNdTaqX4fnkGMX41TiMJjz98iji7lpJiCzfeT2OnpA8vUFKOt1b9pq0zj8lMH8yfaIDlNDc +eqFS3m6TjRgm/VWsvY+b0s+v54Ysyx8Jb6NvqYTUc79NoXQbTiNg8swOqn+knEwlqLJmOzj/2ZQw +9nKEvmhVEA/GcywWaZMH/rFF7buiVWqw2rVKAiUnhde3t4ZEFolsgCs+l6mc1X5VTMbeRRAc6uk7 +nwNT7u56AQIWeNTowr5GdogTPyK7SBIdUgC0An4hGh6cJfTzPV4e0hz5sy229zdcxsshTrD3mUcY +hcErulWuBurQB7Lcq9CClnXO0lD+mefPL5/ndtFhKvshuzHQqp9HpLIiyhY6UFfEW0NnxWViA0kB +60PZ2Pierc+xYw5F9KBaLJstxabArahH9CdMOA0uG0k7UvToiIMrVCjU8jVStDKDYmlkDJGcn5fq +dBb9HxEGmpv0 +-----END CERTIFICATE-----