diff --git a/internal/listener/lrfc2136/listener.go b/internal/listener/lrfc2136/listener.go index 4f2df10..e08b70b 100644 --- a/internal/listener/lrfc2136/listener.go +++ b/internal/listener/lrfc2136/listener.go @@ -105,9 +105,7 @@ func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r * a.handleQuery(ctx, m) case dns.OpcodeUpdate: - for _, rr := range r.Ns { - a.handleUpdate(ctx, m, rr) - } + a.handleUpdates(ctx, m) } if err := w.WriteMsg(m); err != nil { @@ -122,82 +120,77 @@ func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r * func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg) { log.Ctx(ctx).Info().Msg("handle query") for _, q := range m.Question { - rule, err := a.upsc.Query(ctx, q.Name, q.Qtype) + rules, err := a.upsc.Query(ctx, upstream.Rule{ + Name: q.Name, + Type: q.Qtype, + }) if err != nil { continue } - rr, err := rule.RR() - if err != nil { - log.Ctx(ctx).Error().Err(err).Str("name", rule.Name).Msg("unable to generate rr") - continue - } + for _, rule := range rules { + rr, err := rule.RR() + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("name", rule.Name).Msg("unable to generate rr") + continue + } - m.Answer = append(m.Answer, rr) + m.Answer = append(m.Answer, rr) + } } } -func (a *Listener) handleUpdate(ctx context.Context, m *dns.Msg, r dns.RR) { - header := r.Header() - name := header.Name - l := log.Ctx(ctx).With().Str("name", name).Logger() +func (a *Listener) handleUpdates(ctx context.Context, m *dns.Msg) { + log.Ctx(ctx).Info().Msg("handle updates") - if _, ok := dns.IsDomainName(name); !ok { - l.Warn().Msg("skip non-domain name") + tx, err := a.upsc.Tx(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("create tx") return } - if !a.acl.IsAllow(m.IsTsig().Hdr.Name, name) { - l.Warn().Str("client", m.IsTsig().Hdr.Name).Msg("skip not allowed domain name") - return + handleUpdate := func(rr dns.RR) error { + header := rr.Header() + name := header.Name + l := log.Ctx(ctx).With().Str("name", name).Logger() + + if _, ok := dns.IsDomainName(name); !ok { + return errors.New("invalid domain name") + } + + if !a.acl.IsAllow(m.IsTsig().Hdr.Name, name) { + return fmt.Errorf("%q is not allowed for client %q", name, m.IsTsig().Hdr.Name) + } + + if header.Class == dns.ClassANY && header.Rdlength == 0 { + err := tx.Delete(upstream.Rule{ + Name: name, + Type: header.Rrtype, + }) + if err != nil { + return fmt.Errorf("unable to delete: %w", err) + } + l.Info().Str("name", name).Msg("deleted") + return nil + } + + rule, err := upstream.RuleFromRR(rr) + if err != nil { + return err + } + + return tx.Append(rule) } - if header.Class == dns.ClassANY && header.Rdlength == 0 { - fmt.Println("delete", name, "!!!!!!!!!") - //if err := deleteRecord(name, rtype); err != nil { - // l.Error().Err(err).Msg("unable to delete record") - //} else { - // l.Info().Msg("deleted") - //} - return + for _, rr := range m.Ns { + if err := handleUpdate(rr); err != nil { + log.Error().Err(err).Any("rr", rr).Msg("update failed") + } } - switch rr := r.(type) { - case *dns.A: - fmt.Println("A", rr) - case *dns.AAAA: - fmt.Println("AAAA", rr) - case *dns.CNAME: - fmt.Println("CNAME", rr) - case *dns.TXT: - fmt.Println("TXT", rr) - default: - l.Warn().Type("type", r).Msg("ignore unsupported request") - } - // - //if a, ok := r.(*dns.A); ok { - // rrr, err := getRecord(name, rtype) - // if err == nil { - // rr = rrr.(*dns.A) - // } else { - // rr = new(dns.A) - // } - // - // ip = a.A - // rr.(*dns.A).Hdr = rheader - // rr.(*dns.A).A = ip - //} else if a, ok := r.(*dns.AAAA); ok { - // rrr, err := getRecord(name, rtype) - // if err == nil { - // rr = rrr.(*dns.AAAA) - // } else { - // rr = new(dns.AAAA) - // } - // - // ip = a.AAAA - // rr.(*dns.AAAA).Hdr = rheader - // rr.(*dns.AAAA).AAAA = ip - //} + if err := tx.Commit(context.Background()); err != nil { + log.Error().Err(err).Msg("commit failed") + } } func dnsMsgAcceptFunc(dh dns.Header) dns.MsgAcceptAction { diff --git a/internal/upstream/upstream.go b/internal/upstream/upstream.go index e3d3fad..a63a3a8 100644 --- a/internal/upstream/upstream.go +++ b/internal/upstream/upstream.go @@ -31,13 +31,40 @@ type Rule struct { ValueStr string } -func NewRule(name string, typ RType, value RValue) *Rule { - return &Rule{ - Name: name, - Type: typ, +func RuleFromRR(rr dns.RR) (Rule, error) { + var value any + switch v := rr.(type) { + case *dns.A: + value = v.A + + case *dns.AAAA: + value = v.AAAA + + case *dns.CNAME: + value = v.Target + + case *dns.MX: + value = v + + case *dns.PTR: + value = v.Ptr + + case *dns.TXT: + value = v.Txt + + case *dns.SRV: + value = v + + default: + return Rule{}, fmt.Errorf("unsupported rr: %s", rr) + } + + return Rule{ + Name: rr.Header().Name, + Type: rr.Header().Rrtype, Value: value, ValueStr: fmt.Sprint(value), - } + }, nil } func (r *Rule) RR() (dns.RR, error) {