diff --git a/README.md b/README.md index 8449f00..6362720 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,18 @@ Forked from [looterz/grimd](https://github.com/looterz/grimd) +# Features +- [x] DNS over UTP +- [x] DNS over TCP +- [x] DNS over HTTP(S) (DoH as per [RFC-8484](https://datatracker.ietf.org/doc/html/rfc8484)) +- [x] Prometheus metrics API +- [x] Custom DNS records supports +- [x] Blocklist fetching +- [x] Hardcoded blocklist config +- [x] Hardcoded whitelist config +- [x] Fast startup _(so it can be used with templating for service discovery)_ +- [x] Small memory footprint (~50MBs with metrics and DoH enabled) + # Installation ``` go install github.com/cottand/grimd@latest @@ -46,7 +58,7 @@ Usage of grimd: # Building Requires golang 1.7 or higher, you build grimd like any other golang application, for example to build for linux x64 ```shell -env GOOS=linux GOARCH=amd64 go build -v github.com/looterz/grimd +env GOOS=linux GOARCH=amd64 go build -v github.com/cottand/grimd ``` # Building Docker @@ -78,8 +90,8 @@ These are some of the things I would like to contribute in this fork: - [x] ~~Fix multi-record responses issue#5~~ - [ ] DNS record flattening issue#1 - [ ] Service discovery integrations? issue#4 -- [ ] Prometheus metrics exporter issue#3 -- [ ] DNS over HTTPS #2 +- [x] Prometheus metrics exporter issue#3 +- [x] DNS over HTTPS #2 - [ ] Add lots of docs ## Non-objectives diff --git a/config.go b/config.go index 507be7e..4917243 100644 --- a/config.go +++ b/config.go @@ -1,18 +1,23 @@ package main import ( + cTls "crypto/tls" + "errors" "fmt" + "github.com/cottand/grimd/tls" "github.com/jonboulle/clockwork" "github.com/pelletier/go-toml/v2" "log" "os" + "path/filepath" + "strings" ) // BuildVersion returns the build version of grimd, this should be incremented every new release -var BuildVersion = "2.2.1" +var BuildVersion = "1.3.0" // ConfigVersion returns the version of grimd, this should be incremented every time the config changes so grimd presents a warning -var ConfigVersion = "2.2.1" +var ConfigVersion = "1.3.0" // Config holds the configuration parameters type Config struct { @@ -40,6 +45,7 @@ type Config struct { APIDebug bool DoH string Metrics Metrics `toml:"metrics"` + DnsOverHttpServer DnsOverHttpServer } type Metrics struct { @@ -47,6 +53,26 @@ type Metrics struct { Path string } +type DnsOverHttpServer struct { + Enabled bool + Bind string + TimeoutMs int64 + TLS TlsConfig + parsedTls *cTls.Config +} + +type TlsConfig struct { + certPath, keyPath, caPath string + enabled bool +} + +func (c TlsConfig) parsedConfig() (*cTls.Config, error) { + if !c.enabled { + return nil, nil + } + return tls.NewTLSConfig(c.certPath, c.keyPath, c.caPath) +} + var defaultConfig = ` # version this config was generated from version = "%s" @@ -133,13 +159,26 @@ togglename = "" # having been turned off. reactivationdelay = 300 -#Dns over HTTPS provider to use. +# Dns over HTTPS upstream provider to use DoH = "https://cloudflare-dns.com/dns-query" -# Prometheus metrics - enable +# Prometheus metrics - disabled by default [Metrics] enabled = false path = "/metrics" + +[DnsOverHttpServer] + enabled = false + bind = "0.0.0.0:80" + timeoutMs = 5000 + + # TLS config is not required for DoH if you have some proxy (ie, caddy, nginx, traefik...) manage HTTPS for you + [DnsOverHttpServer.TLS] + enabled = false + certPath = "" + keyPath = "" + # if empty, system CAs will be used + caPath = "" ` func parseDefaultConfig() Config { @@ -147,7 +186,7 @@ func parseDefaultConfig() Config { err := toml.Unmarshal([]byte(defaultConfig), &config) if err != nil { - logger.Fatalf("There was an error parsing the default config %v", err) + logger.Fatalf("There was an error parsing the default config: %v", err) } config.Version = ConfigVersion @@ -157,6 +196,19 @@ func parseDefaultConfig() Config { // WallClock is the wall clock var WallClock = clockwork.NewRealClock() +func contextualisedParsingErrorFrom(err error) error { + errString := strings.Builder{} + var derr *toml.DecodeError + _, _ = fmt.Fprint(&errString, "could not load config:", err) + if errors.As(err, &derr) { + errString.WriteByte('\n') + _, _ = fmt.Fprintln(&errString, derr.String()) + row, col := derr.Position() + _, _ = fmt.Fprintln(&errString, "error occurred at row", row, "column", col) + } + return errors.New(errString.String()) +} + // LoadConfig loads the given config file func LoadConfig(path string) (*Config, error) { @@ -167,9 +219,28 @@ func LoadConfig(path string) (*Config, error) { return &config, nil } - if err := toml.Unmarshal([]byte(path), &config); err != nil { - return nil, fmt.Errorf("could not load config: %s", err) + path = filepath.Clean(path) + file, err := os.Open(path) + if err != nil { + log.Printf("warning, failed to open config (%v) - using defaults", err) + return &config, nil + } + + defer func(file *os.File) { + _ = file.Close() + }(file) + + d := toml.NewDecoder(file) + + if err := d.Decode(&config); err != nil { + return nil, contextualisedParsingErrorFrom(err) + } + + dohTls, err := config.DnsOverHttpServer.TLS.parsedConfig() + if err != nil { + return nil, fmt.Errorf("could not load TLS config: %s", err) } + config.DnsOverHttpServer.parsedTls = dohTls if config.Version != ConfigVersion { if config.Version == "" { diff --git a/config_test.go b/config_test.go index 99522a7..ee0802c 100644 --- a/config_test.go +++ b/config_test.go @@ -3,6 +3,7 @@ package main import ( "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/assert" + "strings" "testing" ) @@ -35,3 +36,22 @@ path = "/voo" } assert.Equal(t, true, config.Metrics.Enabled, "expected overridden value for config.bind") } + +func TestFriendlyErrors(t *testing.T) { + config := parseDefaultConfig() + + badConfig := ` +[metrics] +enabled = 3 +` + + err := toml.Unmarshal([]byte(badConfig), &config) + if err == nil { + t.Fatalf("expected an error!") + } + err = contextualisedParsingErrorFrom(err) + + if !(strings.Contains(err.Error(), "enabled = 3") && strings.Contains(err.Error(), "row 3 column 11")) { + t.Fatalf("expected error string to contain contextual info, but was %v", err.Error()) + } +} diff --git a/doh.go b/doh.go new file mode 100644 index 0000000..8c4a8ba --- /dev/null +++ b/doh.go @@ -0,0 +1,255 @@ +package main + +import ( + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "github.com/cottand/grimd/internal/metric" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "io" + stdlog "log" + "net" + "net/http" + "time" +) + +/** +This implementation is heavily inspired by CoreDNS and used as per their Apache 2 license +see https://github.com/coredns/coredns/blob/v1.11.1/core/dnsserver/server_https.go + +There is no NOTICE redistribution as, at the time of producing the derivative work, CoreDNS did +not distribute such a notice with their work. +*/ + +const mimeTypeDOH = "application/dns-message" + +// pathDOH is the URL path that should be used. +const pathDOH = "/dns-query" + +// ServerHTTPS represents an instance of a DNS-over-HTTPS server. +type ServerHTTPS struct { + Net string + handler dns.Handler + httpsServer *http.Server + tlsConfig *tls.Config + validRequest func(*http.Request) bool + bind string + ttl time.Duration +} + +// loggerAdapter is a simple adapter around CoreDNS logger made to implement io.Writer in order to log errors from HTTP server +type loggerAdapter struct { +} + +func (l *loggerAdapter) Write(p []byte) (n int, err error) { + logger.Debugf("Writing HTTP request=%v", string(p)) + return len(p), nil +} + +// NewServerHTTPS returns a new HTTPS server capable of performing DoH with dns +func NewServerHTTPS( + dns dns.Handler, + bind string, + timeout time.Duration, + ttl time.Duration, + tls *tls.Config, +) (*ServerHTTPS, error) { + + // http/2 is recommended when using DoH. We need to specify it in next protos + // or the upgrade won't happen. + if tls != nil { + tls.NextProtos = []string{"h2", "http/1.1"} + } + + // Use a custom request validation func or use the standard DoH path check. + + srv := &http.Server{ + ReadTimeout: timeout, + WriteTimeout: timeout, + ErrorLog: stdlog.New(&loggerAdapter{}, "", 0), + Addr: bind, + } + sh := &ServerHTTPS{ + handler: dns, httpsServer: srv, ttl: ttl, bind: bind, + } + srv.Handler = sh + + return sh, nil +} + +func (s *ServerHTTPS) ListenAndServe() error { + return s.httpsServer.ListenAndServe() +} + +// Stop stops the server. It blocks until the server is totally stopped. +func (s *ServerHTTPS) Stop() error { + if s.httpsServer != nil { + _ = s.httpsServer.Shutdown(context.Background()) + } + return nil +} + +// ServeHTTP is the handler that gets the HTTP request and converts to the dns format, calls the resolver, +// converts it back and write it to the client. +func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !(r.URL.Path == pathDOH) { + http.Error(w, "", http.StatusNotFound) + countResponse(http.StatusNotFound) + return + } + + msg, err := requestToMsg(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + countResponse(http.StatusBadRequest) + logger.Noticef("error when serving DoH request: %v", err) + return + } + + var writer = DohResponseWriter{remoteAddr: r.RemoteAddr, host: r.Host, delegate: w, completed: make(chan empty, 1)} + s.handler.ServeDNS(&writer, msg) + _, ok := <-writer.completed + if writer.err != nil || ok != true { + return + } + + age := s.ttl // seconds + + w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%v", age.Seconds())) +} + +func countResponse(status int) { + metric.DohResponseCount.With(prometheus.Labels{"status": fmt.Sprint(status)}) +} + +// Shutdown stops the server (non gracefully). +func (s *ServerHTTPS) Shutdown() { + if s.httpsServer != nil { + _ = s.httpsServer.Shutdown(context.Background()) + } +} + +func requestToMsg(req *http.Request) (*dns.Msg, error) { + if req.Method == "GET" { + return getRequestToMsg(req) + } + if req.Method == "POST" { + return postRequestToMsg(req) + } + return nil, fmt.Errorf("unexpected method for DoH request %v", req.Method) +} + +// postRequestToMsg extracts the dns message from the request body. +func postRequestToMsg(req *http.Request) (*dns.Msg, error) { + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(req.Body) + + buf, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + m := new(dns.Msg) + err = m.Unpack(buf) + return m, err +} + +// getRequestToMsg extract the dns message from the GET request. +func getRequestToMsg(req *http.Request) (*dns.Msg, error) { + values := req.URL.Query() + b64, ok := values["dns"] + if !ok { + return nil, fmt.Errorf("no 'dns' query parameter found") + } + if len(b64) != 1 { + return nil, fmt.Errorf("multiple 'dns' query values found") + } + return base64ToMsg(b64[0]) +} + +func base64ToMsg(b64 string) (*dns.Msg, error) { + buf, err := base64.RawURLEncoding.DecodeString(b64) + if err != nil { + return nil, err + } + + m := new(dns.Msg) + err = m.Unpack(buf) + + return m, err +} + +type empty struct{} + +// DohResponseWriter implements dns.ResponseWriter +type DohResponseWriter struct { + msg *dns.Msg + remoteAddr string + delegate http.ResponseWriter + host string + err error + completed chan empty +} + +// See section 4.2.1 of RFC 8484. +// We are using code 500 to indicate an unexpected situation when the chain +// handler has not provided any response message. +func (w *DohResponseWriter) handleErr(err error) { + logger.Warningf("error when replying to DoH: %v", err) + http.Error(w.delegate, "No response", http.StatusInternalServerError) + countResponse(http.StatusInternalServerError) + w.err = err + return +} + +func (w *DohResponseWriter) LocalAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", w.remoteAddr) + return addr +} + +func (w *DohResponseWriter) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", w.remoteAddr) + return addr +} + +func (w *DohResponseWriter) WriteMsg(msg *dns.Msg) error { + defer func() { + w.completed <- empty{} + close(w.completed) + }() + w.msg = msg + buf, err := msg.Pack() + if err != nil { + w.handleErr(err) + return err + } + w.delegate.Header().Set("Content-Type", mimeTypeDOH) + _, err = w.Write(buf) + if err != nil { + w.handleErr(err) + return err + } + countResponse(http.StatusOK) + return nil +} + +func (w *DohResponseWriter) Write(bytes []byte) (int, error) { + return w.delegate.Write(bytes) +} + +func (w *DohResponseWriter) Close() error { + return nil +} + +func (w *DohResponseWriter) TsigStatus() error { + return nil +} + +func (w *DohResponseWriter) TsigTimersOnly(_ bool) { +} + +func (w *DohResponseWriter) Hijack() { + return +} diff --git a/doh_test.go b/doh_test.go new file mode 100644 index 0000000..468135b --- /dev/null +++ b/doh_test.go @@ -0,0 +1,77 @@ +package main + +import ( + "github.com/miekg/dns" + "net/http" + "strings" + "testing" + "time" +) + +func dnsAQuestion(question string) (msg *dns.Msg) { + msg = new(dns.Msg) + msg.SetQuestion(question, dns.TypeA) + return msg +} + +func TestDohHappyPath(t *testing.T) { + handler := dns.NewServeMux() + custom := NewCustomDNSRecordsFromText([]string{"example.com. IN A 10.0.0.0 "}) + handler.HandleFunc("example.com", custom[0].serve(nil)) + + dohTest(t, handler, func(r Resolver, bind string) { + response, err := r.DoHLookup("http://"+bind+"/dns-query", 1, dnsAQuestion("example.com.")) + + if err != nil { + t.Fatalf("unexpected error during lookup %v", err) + } + + if !strings.Contains(response.Answer[0].String(), "10.0.0.0") { + t.Fatalf("failed to answer dns query for example.org - expected 10.0.0.0 but got %v", response.Answer) + } + }) + +} + +func TestDoh404(t *testing.T) { + handler := dns.NewServeMux() + custom := NewCustomDNSRecordsFromText([]string{"example.com A 10.0.0.0"}) + handler.HandleFunc("example.com", custom[0].serve(nil)) + + dohTest(t, handler, func(r Resolver, bind string) { + resp, err := http.Get("http://" + bind + "/unknown-path") + + if resp.StatusCode != 404 { + t.Fatalf("expected 404 but got %v", resp.StatusCode) + } + + if err != nil { + t.Fatalf("unexpected error during lookup %v", err) + } + }) +} +func dohTest(t *testing.T, handler dns.Handler, doTest func(r Resolver, bind string)) { + bind := "localhost:7698" + config := parseDefaultConfig() + loggingState, _ := loggerInit(config.LogConfig) + config.DnsOverHttpServer.Bind = bind + defer func() { + loggingState.cleanUp() + }() + doh, err := NewServerHTTPS(handler, bind, time.Second*5, time.Second*5, nil) + defer doh.Shutdown() + + if err != nil { + t.Fatalf("error when tarting server %v", err) + } + + go func() { + _ = doh.httpsServer.ListenAndServe() + }() + + time.Sleep(100 * time.Millisecond) + + r := Resolver{} + + doTest(r, bind) +} diff --git a/go.mod b/go.mod index 8cf2d8f..b69e428 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/cottand/grimd require ( github.com/gin-gonic/gin v1.9.1 github.com/jonboulle/clockwork v0.3.0 - github.com/miekg/dns v1.1.50 + github.com/miekg/dns v1.1.56 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 github.com/pelletier/go-toml/v2 v2.1.0 github.com/prometheus/client_golang v1.16.0 @@ -43,10 +43,10 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.14.0 // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/text v0.13.0 // indirect - golang.org/x/tools v0.6.0 // indirect + golang.org/x/tools v0.13.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3ffcc21..e083221 100644 --- a/go.sum +++ b/go.sum @@ -56,6 +56,8 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zk github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= +github.com/miekg/dns v1.1.56 h1:5imZaSeoRNvpM9SzWNhEcP9QliKiz20/dA2QabIGVnE= +github.com/miekg/dns v1.1.56/go.mod h1:cRm6Oo2C8TY9ZS/TqsSrseAcncm74lfK5G+ikN2SWWY= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -105,6 +107,8 @@ golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= @@ -138,6 +142,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/grimd_test.go b/grimd_test.go index a16f866..bfa1042 100644 --- a/grimd_test.go +++ b/grimd_test.go @@ -3,6 +3,8 @@ package main import ( "fmt" "github.com/pelletier/go-toml/v2" + "io" + "net/http" "strings" "testing" "time" @@ -148,6 +150,58 @@ func Test2in3DifferentARecords(t *testing.T) { ) } +func TestDohIntegration(t *testing.T) { + dohBind := "localhost:8181" + integrationTest(func(c *Config) { + c.DnsOverHttpServer.Bind = dohBind + c.DnsOverHttpServer.Enabled = true + c.CustomDNSRecords = []string{"example.com IN A 10.10.0.1 "} + }, func(_ *dns.Client, _ string) { + r := Resolver{} + + response, err := r.DoHLookup("http://"+dohBind+"/dns-query", 1, dnsAQuestion("example.com.")) + + if err != nil { + t.Fatalf("unexpected error during lookup %v", err) + } + + if !strings.Contains(response.Answer[0].String(), "10.10.0.1") { + t.Fatalf("failed to answer dns query for example.org - expected 10.10.0.1 but got %v", response.Answer) + } + + }) +} + +// TestDohAsProxy checks that DoH works for non-custom records +func TestDohAsProxy(t *testing.T) { + dohBind := "localhost:8181" + integrationTest(func(c *Config) { + c.DnsOverHttpServer.Bind = dohBind + c.DnsOverHttpServer.Enabled = true + }, func(_ *dns.Client, _ string) { + resp, err := http.Get("http://" + dohBind + "/dns-query?dns=AAABAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB") + + if err != nil { + t.Fatalf("unexpected error during lookup %v", err) + } + respPacket, err := io.ReadAll(resp.Body) + + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + msg := dns.Msg{} + err = msg.Unpack(respPacket) + if err != nil { + t.Fatalf("unexpected error during lookup %v (response len=%vB)", err, len(respPacket)) + } + + if len(msg.Answer) < 1 { + t.Fatalf("failed to answer dns query for example.com - expected some answer but got nothing") + } + + }) +} func TestConfigReloadForCustomRecords(t *testing.T) { testDnsHost := "127.0.0.1:5300" var config Config diff --git a/handler.go b/handler.go index a68ccb7..424eb4d 100644 --- a/handler.go +++ b/handler.go @@ -104,6 +104,8 @@ func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCa var remote net.IP if Net == "tcp" { remote = w.RemoteAddr().(*net.TCPAddr).IP + } else if Net == "http" { + remote = w.RemoteAddr().(*net.TCPAddr).IP } else { remote = w.RemoteAddr().(*net.UDPAddr).IP } @@ -275,19 +277,27 @@ func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCa // DoTCP begins a tcp query func (h *DNSHandler) DoTCP(w dns.ResponseWriter, req *dns.Msg) { h.muActive.RLock() + defer h.muActive.RUnlock() if h.active { h.requestChannel <- DNSOperationData{"tcp", w, req} } - h.muActive.RUnlock() } // DoUDP begins a udp query func (h *DNSHandler) DoUDP(w dns.ResponseWriter, req *dns.Msg) { h.muActive.RLock() + defer h.muActive.RUnlock() if h.active { h.requestChannel <- DNSOperationData{"udp", w, req} } - h.muActive.RUnlock() +} + +func (h *DNSHandler) DoHTTP(w dns.ResponseWriter, req *dns.Msg) { + h.muActive.RLock() + defer h.muActive.RUnlock() + if h.active { + h.requestChannel <- DNSOperationData{"http", w, req} + } } // HandleFailed handles dns failures diff --git a/internal/metric/metric.go b/internal/metric/metric.go index f45684b..8b533e6 100644 --- a/internal/metric/metric.go +++ b/internal/metric/metric.go @@ -53,6 +53,13 @@ var ( Name: "config_reload_customdns", Help: "Custom DNS config reloads", }) + + DohResponseCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: Namespace, + Name: "doh_response_count", + Help: "Successful DoH responses", + }, []string{"status"}) ) func init() { diff --git a/resolver.go b/resolver.go index 6f18232..8ca25f4 100644 --- a/resolver.go +++ b/resolver.go @@ -71,8 +71,8 @@ func (r *Resolver) Lookup(net string, req *dns.Msg, timeout int, interval int, n c := &dns.Client{ Net: net, - ReadTimeout: r.Timeout(timeout), - WriteTimeout: r.Timeout(timeout), + ReadTimeout: time.Duration(timeout) * time.Second, + WriteTimeout: time.Duration(timeout) * time.Second, } qname := req.Question[0].Name @@ -131,11 +131,6 @@ func (r *Resolver) Lookup(net string, req *dns.Msg, timeout int, interval int, n } } -// Timeout returns the resolver timeout -func (r *Resolver) Timeout(timeout int) time.Duration { - return time.Duration(timeout) * time.Second -} - // DoHLookup performs a DNS lookup over https func (r *Resolver) DoHLookup(url string, timeout int, req *dns.Msg) (msg *dns.Msg, err error) { qname := req.Question[0].Name @@ -155,7 +150,7 @@ func (r *Resolver) DoHLookup(url string, timeout int, req *dns.Msg) (msg *dns.Ms //Make the request to the server client := http.Client{ - Timeout: r.Timeout(timeout), + Timeout: time.Duration(timeout) * time.Second, } reader := bytes.NewReader(data) diff --git a/server.go b/server.go index a3be837..8283dd7 100644 --- a/server.go +++ b/server.go @@ -15,8 +15,10 @@ type Server struct { handler *DNSHandler udpServer *dns.Server tcpServer *dns.Server + httpServer *ServerHTTPS udpHandler *dns.ServeMux tcpHandler *dns.ServeMux + httpHandler *dns.ServeMux activeHandlerPatterns []string } @@ -35,18 +37,23 @@ func (s *Server) Run( udpHandler := dns.NewServeMux() udpHandler.HandleFunc(".", s.handler.DoUDP) + httpHandler := dns.NewServeMux() + httpHandler.HandleFunc(".", s.handler.DoHTTP) + handlerPatterns := make([]string, len(config.CustomDNSRecords)) for _, record := range NewCustomDNSRecordsFromText(config.CustomDNSRecords) { dnsHandler := record.serve(s.handler) tcpHandler.HandleFunc(record.name, dnsHandler) udpHandler.HandleFunc(record.name, dnsHandler) + httpHandler.HandleFunc(record.name, dnsHandler) handlerPatterns = append(handlerPatterns, record.name) } s.activeHandlerPatterns = handlerPatterns s.tcpHandler = tcpHandler s.udpHandler = udpHandler + s.httpHandler = httpHandler s.tcpServer = &dns.Server{ Addr: s.host, @@ -65,6 +72,16 @@ func (s *Server) Run( WriteTimeout: s.wTimeout, } + if config.DnsOverHttpServer.Enabled { + var err error + timeout := time.Duration(config.DnsOverHttpServer.TimeoutMs) * time.Millisecond + ttl := time.Duration(config.TTL) * time.Second + s.httpServer, err = NewServerHTTPS(httpHandler, config.DnsOverHttpServer.Bind, timeout, ttl, config.DnsOverHttpServer.parsedTls) + if err != nil { + logger.Criticalf("failed to create http server %v", err) + } + go s.startHttp(config.DnsOverHttpServer.Bind) + } go s.start(s.udpServer) go s.start(s.tcpServer) } @@ -77,6 +94,14 @@ func (s *Server) start(ds *dns.Server) { } } +func (s *Server) startHttp(addr string) { + logger.Criticalf("start http listener on %s\n", addr) + + if err := s.httpServer.ListenAndServe(); err != nil { + logger.Criticalf("start http listener on %s failed or was closed: %s\n", addr, err.Error()) + } +} + // Stop stops the server func (s *Server) Stop() { if s.handler != nil { @@ -97,6 +122,13 @@ func (s *Server) Stop() { logger.Critical(err) } } + + if s.httpServer != nil { + err := s.httpServer.Stop() + if err != nil { + logger.Critical(err) + } + } } // ReloadConfig only supports reloading the customDnsRecords section of the config for now @@ -118,12 +150,14 @@ func (s *Server) ReloadConfig(config *Config) { for _, deleted := range deletedRecords { s.tcpHandler.HandleRemove(deleted) s.udpHandler.HandleRemove(deleted) + s.httpHandler.HandleRemove(deleted) } for _, record := range newRecords { dnsHandler := record.serve(s.handler) s.tcpHandler.HandleFunc(record.name, dnsHandler) s.udpHandler.HandleFunc(record.name, dnsHandler) + s.httpHandler.HandleFunc(record.name, dnsHandler) } s.activeHandlerPatterns = newRecordsPatterns } diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..e9017fe --- /dev/null +++ b/shell.nix @@ -0,0 +1,4 @@ +{ pkgs ? import (builtins.fetchTarball "https://api.github.com/repos/nixos/nixpkgs/tarball/nixos-unstable") {} }: + pkgs.mkShell { + nativeBuildInputs = with pkgs; [ go_1_21 ]; +} diff --git a/tls/tls.go b/tls/tls.go new file mode 100644 index 0000000..64b2075 --- /dev/null +++ b/tls/tls.go @@ -0,0 +1,62 @@ +package tls + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "path/filepath" +) + +func setTLSDefaults(ctls *tls.Config) { + ctls.MinVersion = tls.VersionTLS12 + ctls.MaxVersion = tls.VersionTLS13 + ctls.CipherSuites = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + } +} + +// NewTLSConfig returns a TLS config that includes a certificate +// Use for server TLS config or when using a client certificate +// If caPath is empty, system CAs will be used +func NewTLSConfig(certPath, keyPath, caPath string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("could not load TLS cert: %s", err) + } + + roots, err := loadRoots(caPath) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}, RootCAs: roots, MinVersion: tls.VersionTLS12} + setTLSDefaults(tlsConfig) + + return tlsConfig, nil +} + +func loadRoots(caPath string) (*x509.CertPool, error) { + if caPath == "" { + return nil, nil + } + + roots := x509.NewCertPool() + pem, err := os.ReadFile(filepath.Clean(caPath)) + if err != nil { + return nil, fmt.Errorf("error reading %s: %s", caPath, err) + } + ok := roots.AppendCertsFromPEM(pem) + if !ok { + return nil, fmt.Errorf("could not read root certs: %s", err) + } + return roots, nil +}