Skip to content

Commit

Permalink
feat(RFC-1034): implement CNAME following
Browse files Browse the repository at this point in the history
  • Loading branch information
cottand committed Nov 7, 2023
1 parent 2add052 commit a46714f
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 80 deletions.
20 changes: 11 additions & 9 deletions grimd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func integrationTest(changeConfig func(c *Config), test func(client *dns.Client,

go startActivation(actChannel, quitActivation, config.ReactivationDelay)
grimdActivation = <-actChannel
grimdActive = true
close(actChannel)

server := &Server{
Expand All @@ -52,6 +53,9 @@ func integrationTest(changeConfig func(c *Config), test func(client *dns.Client,

// BlockCache contains all blocked domains
blockCache := &MemoryBlockCache{Backend: make(map[string]bool)}
for _, blocked := range config.Blocklist {
_ = blockCache.Set(blocked, true)
}
// QuestionCache contains all queries to the dns server
questionCache := makeQuestionCache(config.QuestionCacheCap)

Expand Down Expand Up @@ -162,10 +166,10 @@ func TestCnameFollowHappyPath(t *testing.T) {
func(c *Config) {
c.CustomDNSRecords = []string{
"first.com IN CNAME second.com ",
"second.com IN CNAME first.com ",
"second.com IN CNAME third.com ",
"third.com IN A 10.10.0.42 ",
}

c.Timeout = 10000
},
func(client *dns.Client, target string) {
c := new(dns.Client)
Expand All @@ -175,8 +179,7 @@ func TestCnameFollowHappyPath(t *testing.T) {
m.SetQuestion(dns.Fqdn("first.com"), dns.TypeA)
reply, _, err := c.Exchange(m, target)
if err != nil {
t.Error(err)
t.FailNow()
t.Fatalf("failed to exchange %v", err)
}
if l := len(reply.Answer); l != 3 {
t.Fatalf("Expected 3 returned records but had %v: %v", l, reply.Answer)
Expand All @@ -195,10 +198,9 @@ func TestCnameFollowWithBlocked(t *testing.T) {
func(c *Config) {
c.CustomDNSRecords = []string{
"first.com IN CNAME second.com ",
"second.com IN CNAME first.com ",
"third.com IN A 10.10.0.42 ",
"second.com IN CNAME example.com ",
}
c.Blocklist = []string{"second.com"}
c.Blocklist = []string{"example.com"}

},
func(client *dns.Client, target string) {
Expand All @@ -212,8 +214,8 @@ func TestCnameFollowWithBlocked(t *testing.T) {
t.Error(err)
t.FailNow()
}
if slices.ContainsFunc(reply.Answer, contains("10.10.0.42")) {
t.Fatalf("Expected right A address to be blocked, but got %v", reply.Answer[0])
if !slices.ContainsFunc(reply.Answer, contains("0.0.0.0")) {
t.Fatalf("Expected right A address to be blocked, but got \n%v", reply.String())
}
},
)
Expand Down
69 changes: 45 additions & 24 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type EventLoop struct {
config *Config
blockCache *MemoryBlockCache
questionCache *MemoryQuestionCache
customDns *CustomRecordsResolver
}

// DNSOperationData type
Expand Down Expand Up @@ -87,6 +88,7 @@ func NewEventLoop(config *Config, blockCache *MemoryBlockCache, questionCache *M
questionCache: questionCache,
active: true,
config: config,
customDns: NewCustomRecordsResolver(NewCustomDNSRecordsFromText(config.CustomDNSRecords)),
}

go handler.do()
Expand All @@ -105,7 +107,20 @@ func (h *EventLoop) do() {
}

// responseFor has side-effects, like writing to h's caches, so avoid calling it concurrently
func (h *EventLoop) responseFor(Net string, req *dns.Msg, remote net.IP) (_ *dns.Msg, success bool) {
func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (_ *dns.Msg, success bool) {

var remote net.IP
if Net == "tcp" || Net == "http" {
remote = _remote.(*net.TCPAddr).IP
} else {
remote = _remote.(*net.UDPAddr).IP
}

// first of all, check custom DNS. No need to cache it because it is already in-mem and precedes the blocking
if custom := h.customDns.Resolve(req, _local, _remote); custom != nil {
return custom, true
}

q := req.Question[0]
Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]}
logger.Infof("%s lookup %s\n", remote, Q.String())
Expand Down Expand Up @@ -272,16 +287,7 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) {
}
}(w)

var remote net.IP
if Net == "tcp" {
remote = w.RemoteAddr().(*net.TCPAddr).IP
} else if Net == "http" {
remote = w.RemoteAddr().(*net.TCPAddr).IP
} else {
remote = w.RemoteAddr().(*net.UDPAddr).IP
}

resp, ok := h.responseFor(Net, req, remote)
resp, ok := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr())

if !ok {
m := new(dns.Msg)
Expand All @@ -292,15 +298,24 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) {
}

depthSoFar := uint32(0)
for cnames, ok := canFollow(resp); h.config.FollowCnameDepth > depthSoFar && ok; {
for h.config.FollowCnameDepth > depthSoFar {
cnames, ok := canFollow(req, resp)
depthSoFar++
if !ok {
break
}
for _, cname := range cnames {
r := dns.Msg{}
r.SetQuestion(cname.Target, req.Question[0].Qtype)
followed, ok := h.responseFor(Net, &r, remote)
if ok {
resp.Answer = append(resp.Answer, followed.Answer...)
followed, ok := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr())
for _, fAnswer := range followed.Answer {
containsNewAnswer := func(rr dns.RR) bool {
return rr.String() == fAnswer.String()
}
if ok && !slices.ContainsFunc(resp.Answer, containsNewAnswer) {
resp.Answer = append(resp.Answer, fAnswer)
}
}
depthSoFar++
}
}

Expand All @@ -309,28 +324,34 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) {
}

// determines if resp contains no A records but some CNAME record
func canFollow(resp *dns.Msg) (cnames []dns.CNAME, ok bool) {
func canFollow(req *dns.Msg, resp *dns.Msg) (cnames []*dns.CNAME, ok bool) {
if req.Question[0].Qtype == dns.TypeCNAME {
return []*dns.CNAME{}, false
}

isAnswer := func(rr dns.RR) bool {
return rr.Header().Rrtype == dns.TypeA && rr.Header().Rrtype == dns.TypeAAAA
isA := func(rr dns.RR) bool {
return rr.Header().Rrtype == dns.TypeA || rr.Header().Rrtype == dns.TypeAAAA
}

isCname := func(rr dns.RR) bool {
return rr.Header().Rrtype == dns.TypeCNAME
}

ok = !slices.ContainsFunc(resp.Answer, isA) && slices.ContainsFunc(resp.Answer, isCname)
for _, rr := range resp.Answer {
if asCname, ok := rr.(*dns.CNAME); isCname(rr) && ok {
cnames = append(cnames, *asCname)
cnames = append(cnames, asCname)
}
}

ok = !slices.ContainsFunc(resp.Answer, isAnswer) && slices.ContainsFunc(resp.Answer, isCname)

return cnames, ok

return cnames, ok && len(cnames) != 0
}

// msg:
// Q: A fst.com
// A: CN snd.com, thrd.com
//

// DoTCP begins a tcp query
func (h *EventLoop) DoTCP(w dns.ResponseWriter, req *dns.Msg) {
h.muActive.RLock()
Expand Down
61 changes: 61 additions & 0 deletions records.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"github.com/cottand/grimd/internal/metric"
"github.com/miekg/dns"
"net"
)

type CustomDNSRecords struct {
Expand Down Expand Up @@ -57,3 +58,63 @@ func (records CustomDNSRecords) asHandler() func(dns.ResponseWriter, *dns.Msg) {
metric.ReportDNSResponse(writer, m, false)
}
}

// CustomRecordsResolver allows faking an in-mem DNS server just for custom records
type CustomRecordsResolver struct {
mux *dns.ServeMux
}

func NewCustomRecordsResolver(records []CustomDNSRecords) *CustomRecordsResolver {
mux := dns.NewServeMux()
for _, r := range records {
mux.HandleFunc(r.name, r.asHandler())
}
return &CustomRecordsResolver{mux}
}

// Resolve returns nil when there was no result found
func (r *CustomRecordsResolver) Resolve(req *dns.Msg, local net.Addr, remote net.Addr) *dns.Msg {
writer := roResponseWriter{local: local, remote: remote}
r.mux.ServeDNS(&writer, req)
if writer.result.Rcode == dns.RcodeRefused {
return nil
} else {
return writer.result
}
}

// roResponseWriter implements dns.ResponseWriter,
// but does not allow calling any method with
// side effects.
// It allows wrapping a dns.ResponseWriter in order
// to recover the final written dns.Msg
type roResponseWriter struct {
local net.Addr
remote net.Addr
result *dns.Msg
}

func (w *roResponseWriter) LocalAddr() net.Addr {
return w.local
}

func (w *roResponseWriter) RemoteAddr() net.Addr {
return w.remote
}

func (w *roResponseWriter) WriteMsg(msg *dns.Msg) error {
w.result = msg
return nil
}
func (w *roResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *roResponseWriter) Close() error {
return nil
}
func (w *roResponseWriter) TsigStatus() error {
return nil
}
func (w *roResponseWriter) TsigTimersOnly(_ bool) {}
func (w *roResponseWriter) Hijack() {
}
58 changes: 11 additions & 47 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,16 @@ import (

// Server type
type Server struct {
host string
rTimeout time.Duration
wTimeout time.Duration
eventLoop *EventLoop
udpServer *dns.Server
tcpServer *dns.Server
httpServer *ServerHTTPS
udpHandler *dns.ServeMux
tcpHandler *dns.ServeMux
httpHandler *dns.ServeMux
activeHandlerPatterns []string
host string
rTimeout time.Duration
wTimeout time.Duration
eventLoop *EventLoop
udpServer *dns.Server
tcpServer *dns.Server
httpServer *ServerHTTPS
udpHandler *dns.ServeMux
tcpHandler *dns.ServeMux
httpHandler *dns.ServeMux
}

// Run starts the server
Expand All @@ -40,17 +39,6 @@ func (s *Server) Run(
httpHandler := dns.NewServeMux()
httpHandler.HandleFunc(".", s.eventLoop.DoHTTP)

handlerPatterns := make([]string, len(config.CustomDNSRecords))

for _, record := range NewCustomDNSRecordsFromText(config.CustomDNSRecords) {
dnsHandler := record.asHandler()
tcpHandler.HandleFunc(record.name, dnsHandler)
udpHandler.HandleFunc(record.name, dnsHandler)
httpHandler.HandleFunc(record.name, dnsHandler)
handlerPatterns = append(handlerPatterns, record.name)
}
s.activeHandlerPatterns = handlerPatterns

s.tcpHandler = tcpHandler
s.udpHandler = udpHandler
s.httpHandler = httpHandler
Expand Down Expand Up @@ -133,31 +121,7 @@ func (s *Server) Stop() {

// ReloadConfig only supports reloading the customDnsRecords section of the config for now
func (s *Server) ReloadConfig(config *Config) {
oldRecords := s.activeHandlerPatterns
newRecords := NewCustomDNSRecordsFromText(config.CustomDNSRecords)
newRecordsPatterns := make([]string, len(newRecords))
for _, r := range newRecords {
newRecordsPatterns = append(newRecordsPatterns, r.name)
}
if testEq(oldRecords, newRecordsPatterns) {
// no changes - nothing to reload
return
}
s.eventLoop.customDns = NewCustomRecordsResolver(newRecords)
defer metric.CustomDNSConfigReload.Inc()

deletedRecords := difference(oldRecords, newRecordsPatterns)

for _, deleted := range deletedRecords {
s.tcpHandler.HandleRemove(deleted)
s.udpHandler.HandleRemove(deleted)
s.httpHandler.HandleRemove(deleted)
}

for _, record := range newRecords {
dnsHandler := record.asHandler()
s.tcpHandler.HandleFunc(record.name, dnsHandler)
s.udpHandler.HandleFunc(record.name, dnsHandler)
s.httpHandler.HandleFunc(record.name, dnsHandler)
}
s.activeHandlerPatterns = newRecordsPatterns
}

0 comments on commit a46714f

Please sign in to comment.