Skip to content

Commit

Permalink
Complete AXFR support
Browse files Browse the repository at this point in the history
  • Loading branch information
buglloc authored Aug 6, 2023
1 parent dbc4ede commit 865ba8a
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 41 deletions.
2 changes: 1 addition & 1 deletion example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ listener:
clients:
- name: tst.
secret: NzBjOTU4OTVlOTZlOTg5OGQwYTUxYTdjNWYzNTI3NzA5YjIyZTIxNWVjOTc3NWMxNzIxZjdjN2ExNjliNDc1ZCAgLQo=
axfr_allowed: true
xfr_allowed: true
zones:
test.lala.
upstream:
Expand Down
20 changes: 10 additions & 10 deletions internal/config/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func (k ListenerKind) MarshalText() ([]byte, error) {
}

type Client struct {
Name string `koanf:"name"`
Secret string `koanf:"secret"`
AxfrAllowed bool `koanf:"axfr_allowed"`
AutoDelete bool `koanf:"auto_delete"`
Zones []string `koanf:"zones"`
Name string `koanf:"name"`
Secret string `koanf:"secret"`
XFRAllowed bool `koanf:"xfr_allowed"`
AutoDelete bool `koanf:"auto_delete"`
Zones []string `koanf:"zones"`
}

type RFC2136Listener struct {
Expand Down Expand Up @@ -99,11 +99,11 @@ func (r *Runtime) newRFC2136Listener(u upstream.Upstream, cfg RFC2136Listener) (

for _, cl := range cfg.Clients {
lCfg.AppendClient(lrfc2136.Client{
Name: cl.Name,
Secret: cl.Secret,
Zones: cl.Zones,
AxfrAllowed: cl.AxfrAllowed,
AutoDelete: cl.AutoDelete,
Name: cl.Name,
Secret: cl.Secret,
Zones: cl.Zones,
XFRAllowed: cl.XFRAllowed,
AutoDelete: cl.AutoDelete,
})
}

Expand Down
21 changes: 7 additions & 14 deletions internal/listener/lrfc2136/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@ package lrfc2136
import (
"fmt"
"strings"

"github.com/miekg/dns"
)

type Client struct {
Name string
Secret string
AxfrAllowed bool
AutoDelete bool
Zones []string
Name string
Secret string
XFRAllowed bool
AutoDelete bool
Zones []string
}

type Clients struct {
Expand All @@ -23,13 +21,8 @@ func (a *Clients) ShouldAutoDelete(clientName string) bool {
return a.clients[clientName].AutoDelete
}

func (a *Clients) IsQTypeAllowed(clientName string, qtype uint16) bool {
switch qtype {
case dns.TypeAXFR:
return a.clients[clientName].AxfrAllowed
}

return true
func (a *Clients) IsXFRAllowed(clientName string) bool {
return a.clients[clientName].XFRAllowed
}

func (a *Clients) IsNameAllowed(clientName, name string) bool {
Expand Down
150 changes: 134 additions & 16 deletions internal/listener/lrfc2136/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"math"
"sync"
"time"

Expand Down Expand Up @@ -114,7 +115,6 @@ func (a *Listener) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}

func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
now := time.Now()
tsig := r.IsTsig()
if tsig == nil {
return errors.New("missing TSIG")
Expand All @@ -124,33 +124,146 @@ func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r *
return fmt.Errorf("invalid TSIG: %w", err)
}

m := new(dns.Msg)
m.SetReply(r)
m.SetTsig(tsig.Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix())
m.Compress = false

switch r.Opcode {
case dns.OpcodeQuery:
a.handleQuery(ctx, m, r)
if isXRFRequest(r) {
return a.logRequest(ctx, w, r, a.lockedServeXFR)
}
return a.logRequest(ctx, w, r, a.lockedServeQuery)

case dns.OpcodeUpdate:
a.handleUpdates(ctx, m, r)
return a.logRequest(ctx, w, r, a.lockedServeUpdate)
}

return fmt.Errorf("unsupported opcode: %s", dns.OpcodeToString[r.Opcode])
}

func (a *Listener) lockedServeXFR(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
if !a.clients.IsXFRAllowed(r.IsTsig().Hdr.Name) {
return fmt.Errorf("XFR is not allowed for client %q", r.IsTsig().Hdr.Name)
}

if !isXRFRequest(r) {
return errors.New("invalid XFR request")
}

ch := make(chan *dns.Envelope)
tr := new(dns.Transfer)
done := make(chan struct{})
go func() {
defer close(done)

if err := tr.Out(w, r, ch); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to write transfer")
}
}()

a.handleXFR(ctx, r.Question[0], ch)
close(ch)

<-done
_ = w.Close()
return nil
}

func (a *Listener) lockedServeQuery(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
m := new(dns.Msg)
m.SetReply(r)
m.SetTsig(r.IsTsig().Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix())
a.handleQuery(ctx, m, r)
if err := w.WriteMsg(m); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("write failed")
}

return nil
}

func (a *Listener) lockedServeUpdate(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error {
m := new(dns.Msg)
m.SetReply(r)
m.SetTsig(r.IsTsig().Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix())
a.handleUpdates(ctx, m, r)
if err := w.WriteMsg(m); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("write failed")
return nil
}

log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished")
return nil
}

func (a *Listener) logRequest(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, fn func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error) error {
now := time.Now()
err := fn(ctx, w, r)
log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished")
return err
}

func (a *Listener) handleXFR(ctx context.Context, q dns.Question, out chan *dns.Envelope) {
log.Ctx(ctx).Info().Str("name", q.Name).Msg("handle XFR request")

rules, err := a.upsc.Query(ctx, upstream.Rule{
Name: dns.Fqdn(q.Name),
Type: q.Qtype,
})
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to get rules from upstream")
out <- &dns.Envelope{
Error: err,
}
return
}

const xfrTTL = 60
xfrMarker := &dns.Envelope{
RR: []dns.RR{
&dns.SOA{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeSOA,
Class: dns.ClassINET,
},
Ns: q.Name,
Mbox: q.Name,
Serial: uint32(time.Now().Unix() % math.MaxUint32),
Refresh: xfrTTL,
Retry: xfrTTL,
Expire: xfrTTL,
},
},
}

out <- xfrMarker
defer func() { out <- xfrMarker }()

const chunkSize = 64
for i := 0; i < len(rules); i += chunkSize {
end := i + chunkSize

if end > len(rules) {
end = len(rules)
}

rrs := make([]dns.RR, 0, chunkSize)
for _, rule := range rules[i:end] {
rr, err := rule.RR()
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("name", rule.Name).Msg("unable to generate rr")
continue
}

rrs = append(rrs, rr)
}

out <- &dns.Envelope{
RR: rrs,
}
}
}

func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg, r *dns.Msg) {
log.Ctx(ctx).Info().Msg("handle query")
for _, q := range r.Question {
if !a.clients.IsQTypeAllowed(m.IsTsig().Hdr.Name, q.Qtype) {
log.Ctx(ctx).Warn().Str("qtype", dns.TypeToString[q.Qtype]).Msg("qtype is dissallowed")
if isXRFQuestion(q) {
log.Ctx(ctx).Warn().Msg("unexpected XFR request ignored")
continue
}

Expand All @@ -159,6 +272,7 @@ func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg, r *dns.Msg) {
Type: q.Qtype,
})
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("unable to get rules from upstream")
continue
}

Expand Down Expand Up @@ -229,10 +343,6 @@ func (a *Listener) handleUpdates(ctx context.Context, m *dns.Msg, r *dns.Msg) {
return nil
}

if len(r.Question) > 0 {
a.handleQuery(ctx, m, r)
}

for _, rr := range r.Ns {
if err := handleUpdate(rr); err != nil {
log.Error().Err(err).Any("rr", rr).Msg("update failed")
Expand Down Expand Up @@ -273,3 +383,11 @@ func dnsMsgAcceptFunc(dh dns.Header) dns.MsgAcceptAction {
}
return dns.MsgAccept
}

func isXRFRequest(r *dns.Msg) bool {
return len(r.Question) == 1 && isXRFQuestion(r.Question[0])
}

func isXRFQuestion(q dns.Question) bool {
return q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR
}
1 change: 1 addition & 0 deletions internal/upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (r *Rule) RR() (dns.RR, error) {
hdr := dns.RR_Header{
Name: r.Name,
Rrtype: r.Type,
Class: dns.ClassINET,
}

switch r.Type {
Expand Down

0 comments on commit 865ba8a

Please sign in to comment.