diff --git a/example/config.yaml b/example/config.yaml index 794d520..5dbe0a5 100644 --- a/example/config.yaml +++ b/example/config.yaml @@ -5,7 +5,7 @@ listener: clients: - name: tst. secret: NzBjOTU4OTVlOTZlOTg5OGQwYTUxYTdjNWYzNTI3NzA5YjIyZTIxNWVjOTc3NWMxNzIxZjdjN2ExNjliNDc1ZCAgLQo= - axfr_allowed: true + xfr_allowed: true zones: test.lala. upstream: diff --git a/internal/config/listener.go b/internal/config/listener.go index b790c1f..63fbb6e 100644 --- a/internal/config/listener.go +++ b/internal/config/listener.go @@ -36,11 +36,11 @@ func (k ListenerKind) MarshalText() ([]byte, error) { } type Client struct { - Name string `koanf:"name"` - Secret string `koanf:"secret"` - AxfrAllowed bool `koanf:"axfr_allowed"` - AutoDelete bool `koanf:"auto_delete"` - Zones []string `koanf:"zones"` + Name string `koanf:"name"` + Secret string `koanf:"secret"` + XFRAllowed bool `koanf:"xfr_allowed"` + AutoDelete bool `koanf:"auto_delete"` + Zones []string `koanf:"zones"` } type RFC2136Listener struct { @@ -99,11 +99,11 @@ func (r *Runtime) newRFC2136Listener(u upstream.Upstream, cfg RFC2136Listener) ( for _, cl := range cfg.Clients { lCfg.AppendClient(lrfc2136.Client{ - Name: cl.Name, - Secret: cl.Secret, - Zones: cl.Zones, - AxfrAllowed: cl.AxfrAllowed, - AutoDelete: cl.AutoDelete, + Name: cl.Name, + Secret: cl.Secret, + Zones: cl.Zones, + XFRAllowed: cl.XFRAllowed, + AutoDelete: cl.AutoDelete, }) } diff --git a/internal/listener/lrfc2136/clients.go b/internal/listener/lrfc2136/clients.go index ff7fbfa..9af1f3a 100644 --- a/internal/listener/lrfc2136/clients.go +++ b/internal/listener/lrfc2136/clients.go @@ -3,16 +3,14 @@ package lrfc2136 import ( "fmt" "strings" - - "github.com/miekg/dns" ) type Client struct { - Name string - Secret string - AxfrAllowed bool - AutoDelete bool - Zones []string + Name string + Secret string + XFRAllowed bool + AutoDelete bool + Zones []string } type Clients struct { @@ -23,13 +21,8 @@ func (a *Clients) ShouldAutoDelete(clientName string) bool { return a.clients[clientName].AutoDelete } -func (a *Clients) IsQTypeAllowed(clientName string, qtype uint16) bool { - switch qtype { - case dns.TypeAXFR: - return a.clients[clientName].AxfrAllowed - } - - return true +func (a *Clients) IsXFRAllowed(clientName string) bool { + return a.clients[clientName].XFRAllowed } func (a *Clients) IsNameAllowed(clientName, name string) bool { diff --git a/internal/listener/lrfc2136/listener.go b/internal/listener/lrfc2136/listener.go index f57ea03..faaf61b 100644 --- a/internal/listener/lrfc2136/listener.go +++ b/internal/listener/lrfc2136/listener.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math" "sync" "time" @@ -114,7 +115,6 @@ func (a *Listener) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error { - now := time.Now() tsig := r.IsTsig() if tsig == nil { return errors.New("missing TSIG") @@ -124,33 +124,146 @@ func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r * return fmt.Errorf("invalid TSIG: %w", err) } - m := new(dns.Msg) - m.SetReply(r) - m.SetTsig(tsig.Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix()) - m.Compress = false - switch r.Opcode { case dns.OpcodeQuery: - a.handleQuery(ctx, m, r) + if isXRFRequest(r) { + return a.logRequest(ctx, w, r, a.lockedServeXFR) + } + return a.logRequest(ctx, w, r, a.lockedServeQuery) case dns.OpcodeUpdate: - a.handleUpdates(ctx, m, r) + return a.logRequest(ctx, w, r, a.lockedServeUpdate) + } + + return fmt.Errorf("unsupported opcode: %s", dns.OpcodeToString[r.Opcode]) +} + +func (a *Listener) lockedServeXFR(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error { + if !a.clients.IsXFRAllowed(r.IsTsig().Hdr.Name) { + return fmt.Errorf("XFR is not allowed for client %q", r.IsTsig().Hdr.Name) } + if !isXRFRequest(r) { + return errors.New("invalid XFR request") + } + + ch := make(chan *dns.Envelope) + tr := new(dns.Transfer) + done := make(chan struct{}) + go func() { + defer close(done) + + if err := tr.Out(w, r, ch); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("unable to write transfer") + } + }() + + a.handleXFR(ctx, r.Question[0], ch) + close(ch) + + <-done + _ = w.Close() + return nil +} + +func (a *Listener) lockedServeQuery(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error { + m := new(dns.Msg) + m.SetReply(r) + m.SetTsig(r.IsTsig().Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix()) + a.handleQuery(ctx, m, r) + if err := w.WriteMsg(m); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("write failed") + } + + return nil +} + +func (a *Listener) lockedServeUpdate(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error { + m := new(dns.Msg) + m.SetReply(r) + m.SetTsig(r.IsTsig().Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix()) + a.handleUpdates(ctx, m, r) if err := w.WriteMsg(m); err != nil { log.Ctx(ctx).Error().Err(err).Msg("write failed") - return nil } - log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished") return nil } +func (a *Listener) logRequest(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, fn func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error) error { + now := time.Now() + err := fn(ctx, w, r) + log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished") + return err +} + +func (a *Listener) handleXFR(ctx context.Context, q dns.Question, out chan *dns.Envelope) { + log.Ctx(ctx).Info().Str("name", q.Name).Msg("handle XFR request") + + rules, err := a.upsc.Query(ctx, upstream.Rule{ + Name: dns.Fqdn(q.Name), + Type: q.Qtype, + }) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("unable to get rules from upstream") + out <- &dns.Envelope{ + Error: err, + } + return + } + + const xfrTTL = 60 + xfrMarker := &dns.Envelope{ + RR: []dns.RR{ + &dns.SOA{ + Hdr: dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypeSOA, + Class: dns.ClassINET, + }, + Ns: q.Name, + Mbox: q.Name, + Serial: uint32(time.Now().Unix() % math.MaxUint32), + Refresh: xfrTTL, + Retry: xfrTTL, + Expire: xfrTTL, + }, + }, + } + + out <- xfrMarker + defer func() { out <- xfrMarker }() + + const chunkSize = 64 + for i := 0; i < len(rules); i += chunkSize { + end := i + chunkSize + + if end > len(rules) { + end = len(rules) + } + + rrs := make([]dns.RR, 0, chunkSize) + for _, rule := range rules[i:end] { + rr, err := rule.RR() + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("name", rule.Name).Msg("unable to generate rr") + continue + } + + rrs = append(rrs, rr) + } + + out <- &dns.Envelope{ + RR: rrs, + } + } +} + func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg, r *dns.Msg) { log.Ctx(ctx).Info().Msg("handle query") for _, q := range r.Question { - if !a.clients.IsQTypeAllowed(m.IsTsig().Hdr.Name, q.Qtype) { - log.Ctx(ctx).Warn().Str("qtype", dns.TypeToString[q.Qtype]).Msg("qtype is dissallowed") + if isXRFQuestion(q) { + log.Ctx(ctx).Warn().Msg("unexpected XFR request ignored") continue } @@ -159,6 +272,7 @@ func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg, r *dns.Msg) { Type: q.Qtype, }) if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("unable to get rules from upstream") continue } @@ -229,10 +343,6 @@ func (a *Listener) handleUpdates(ctx context.Context, m *dns.Msg, r *dns.Msg) { return nil } - if len(r.Question) > 0 { - a.handleQuery(ctx, m, r) - } - for _, rr := range r.Ns { if err := handleUpdate(rr); err != nil { log.Error().Err(err).Any("rr", rr).Msg("update failed") @@ -273,3 +383,11 @@ func dnsMsgAcceptFunc(dh dns.Header) dns.MsgAcceptAction { } return dns.MsgAccept } + +func isXRFRequest(r *dns.Msg) bool { + return len(r.Question) == 1 && isXRFQuestion(r.Question[0]) +} + +func isXRFQuestion(q dns.Question) bool { + return q.Qtype == dns.TypeAXFR || q.Qtype == dns.TypeIXFR +} diff --git a/internal/upstream/upstream.go b/internal/upstream/upstream.go index d92b2e4..cf24229 100644 --- a/internal/upstream/upstream.go +++ b/internal/upstream/upstream.go @@ -81,6 +81,7 @@ func (r *Rule) RR() (dns.RR, error) { hdr := dns.RR_Header{ Name: r.Name, Rrtype: r.Type, + Class: dns.ClassINET, } switch r.Type {