From ae546cd757ee6f4a792776215d9ecc293cbe1bb7 Mon Sep 17 00:00:00 2001 From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com> Date: Sun, 17 Sep 2023 19:06:53 +0800 Subject: [PATCH] add doq upstream --- pkg/pool/msg_buf.go | 30 ++++ pkg/pool/msg_buf_test.go | 3 +- pkg/upstream/doq/upstream.go | 283 +++++++++++++++++++++++++++++++++++ pkg/upstream/upstream.go | 36 ++++- 4 files changed, 347 insertions(+), 5 deletions(-) create mode 100644 pkg/upstream/doq/upstream.go diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go index 980a08b53..f132cc277 100644 --- a/pkg/pool/msg_buf.go +++ b/pkg/pool/msg_buf.go @@ -20,6 +20,9 @@ package pool import ( + "encoding/binary" + "fmt" + "github.com/miekg/dns" ) @@ -40,3 +43,30 @@ func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) { } return wire, buf, nil } + +// PackBuffer packs the dns msg m to wire format, with to bytes length header. +// Callers should release the buf by calling ReleaseBuf. +func PackTCPBuffer(m *dns.Msg) (buf *[]byte, err error) { + b := GetBuf(packBufSize) + wire, err := m.PackBuffer((*b)[2:]) + if err != nil { + ReleaseBuf(b) + return nil, err + } + + l := len(wire) + if l > dns.MaxMsgSize { + ReleaseBuf(b) + return nil, fmt.Errorf("dns payload size %d is too large", l) + } + + if &((*b)[0]) != &wire[0] { // reallocated + ReleaseBuf(b) + b = GetBuf(l + 2) + binary.BigEndian.PutUint16((*b)[:2], uint16(l)) + copy((*b)[2:], wire) + return b, nil + } + *b = (*b)[:2+l] + return b, nil +} diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go index 02d734880..97d9a7615 100644 --- a/pkg/pool/msg_buf_test.go +++ b/pkg/pool/msg_buf_test.go @@ -20,8 +20,9 @@ package pool import ( - "github.com/miekg/dns" "testing" + + "github.com/miekg/dns" ) func TestPackBuffer_No_Allocation(t *testing.T) { diff --git a/pkg/upstream/doq/upstream.go b/pkg/upstream/doq/upstream.go new file mode 100644 index 000000000..43241f483 --- /dev/null +++ b/pkg/upstream/doq/upstream.go @@ -0,0 +1,283 @@ +/* + * Copyright (C) 2020-2022, IrineSistiana + * + * This file is part of mosdns. + * + * mosdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * mosdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package doq + +import ( + "context" + "crypto/rand" + "crypto/tls" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" + "github.com/IrineSistiana/mosdns/v5/pkg/pool" + "github.com/miekg/dns" + "github.com/quic-go/quic-go" +) + +const ( + defaultDoQTimeout = time.Second * 5 + dialTimeout = time.Second * 3 + connectionLostThreshold = time.Second * 5 +) + +var ( + doqAlpn = []string{"doq"} +) + +type Upstream struct { + t *quic.Transport + addr string + tlsConfig *tls.Config + quicConfig *quic.Config + + cm sync.Mutex + lc *lazyConn +} + +// tlsConfig cannot be nil, it should have Servername or InsecureSkipVerify. +func NewUpstream(addr string, lc *net.UDPConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*Upstream, error) { + srk, err := initSrk() + if err != nil { + return nil, err + } + if tlsConfig == nil { + return nil, errors.New("nil tls config") + } + + tlsConfig = tlsConfig.Clone() + tlsConfig.NextProtos = doqAlpn + + return &Upstream{ + t: &quic.Transport{ + Conn: lc, + StatelessResetKey: (*quic.StatelessResetKey)(srk), + }, + addr: addr, + tlsConfig: tlsConfig, + quicConfig: quicConfig, + }, nil +} + +func initSrk() (*[32]byte, error) { + var b [32]byte + _, err := rand.Read(b[:]) + if err != nil { + return nil, err + } + return &b, nil +} + +func (u *Upstream) Close() error { + return u.t.Close() +} + +func (u *Upstream) newStream(ctx context.Context) (quic.Stream, *lazyConn, error) { + var lc *lazyConn + u.cm.Lock() + if u.lc == nil { // First dial. + u.lc = u.asyncDialConn() + } else { + select { + case <-u.lc.dialFinished: + if u.lc.err != nil { // previous dial failed + u.lc = u.asyncDialConn() + } else { + err := u.lc.c.Context().Err() + if err != nil { // previous connection is dead or closed + u.lc = u.asyncDialConn() + } + // previous connection looks good + } + default: + // still dialing + } + } + lc = u.lc + u.cm.Unlock() + + select { + case <-lc.dialFinished: + if lc.c == nil { + return nil, nil, lc.err + } + s, err := lc.c.OpenStream() + if err != nil { + lc.queryFailedAndCloseIfConnLost(time.Now()) + return nil, nil, err + } + return s, lc, nil + case <-ctx.Done(): + return nil, nil, ctx.Err() + } +} + +type lazyConn struct { + cancelDial func() + + m sync.Mutex + closed bool + dialFinished chan struct{} + c quic.Connection + err error + + latestRecvMs atomic.Int64 +} + +func (u *Upstream) asyncDialConn() *lazyConn { + ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) + lc := &lazyConn{ + cancelDial: cancel, + dialFinished: make(chan struct{}), + } + + go func() { + defer cancel() + + ua, err := net.ResolveUDPAddr("udp", u.addr) // TODO: Support bootstrap. + if err != nil { + lc.err = err + return + } + + c, err := u.t.Dial(ctx, ua, u.tlsConfig, u.quicConfig) + + var closeC bool + lc.m.Lock() + if lc.closed { + closeC = true // lc was closed, nothing to do + } else { + if err != nil { + lc.err = err + } else { + lc.c = c + lc.saveLatestRespRecvTime(time.Now()) + } + close(lc.dialFinished) + } + lc.m.Unlock() + + if closeC && c != nil { // lc was closed while dialing. + c.CloseWithError(0, "") + } + }() + return lc +} + +func (lc *lazyConn) close() { + lc.m.Lock() + defer lc.m.Unlock() + + if lc.closed { + return + } + lc.closed = true + + select { + case <-lc.dialFinished: + if lc.c != nil { + lc.c.CloseWithError(0, "") + } + default: + lc.cancelDial() + lc.err = net.ErrClosed + close(lc.dialFinished) + } +} + +func (lc *lazyConn) saveLatestRespRecvTime(t time.Time) { + lc.latestRecvMs.Store(t.UnixMilli()) +} + +func (lc *lazyConn) latestRespRecvTime() time.Time { + return time.UnixMilli(lc.latestRecvMs.Load()) +} + +func (lc *lazyConn) queryFailedAndCloseIfConnLost(t time.Time) { + // No msg received for a quit long time. This connection may be lost. + if time.Since(lc.latestRespRecvTime()) > connectionLostThreshold { + lc.close() + } +} + +func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, error) { + payload, err := pool.PackTCPBuffer(q) + if err != nil { + return nil, err + } + // 4.2.1. DNS Message IDs + // When sending queries over a QUIC connection, the DNS Message ID MUST + // be set to 0. The stream mapping for DoQ allows for unambiguous + // correlation of queries and responses, so the Message ID field is not + // required. + (*payload)[2], (*payload)[3] = 0, 0 + + s, lc, err := u.newStream(ctx) + if err != nil { + return nil, err + } + + type res struct { + resp *dns.Msg + err error + } + rc := make(chan res, 1) + go func() { + defer pool.ReleaseBuf(payload) + defer s.Close() + r, err := u.exchange(s, *payload) + rc <- res{resp: r, err: err} + if err != nil { + lc.queryFailedAndCloseIfConnLost(time.Now()) + } else { + lc.saveLatestRespRecvTime(time.Now()) + } + }() + + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case r := <-rc: + resp := r.resp + err := r.err + if resp != nil { + resp.Id = q.Id + } + return resp, err + } +} + +func (u *Upstream) exchange(s quic.Stream, payload []byte) (*dns.Msg, error) { + s.SetDeadline(time.Now().Add(defaultDoQTimeout)) + + _, err := s.Write(payload) + if err != nil { + return nil, err + } + + resp, _, err := dnsutils.ReadMsgFromTCP(s) + if err != nil { + return nil, err + } + return resp, nil +} diff --git a/pkg/upstream/upstream.go b/pkg/upstream/upstream.go index 84865634f..1bc80a2b8 100644 --- a/pkg/upstream/upstream.go +++ b/pkg/upstream/upstream.go @@ -35,6 +35,7 @@ import ( "github.com/IrineSistiana/mosdns/v5/pkg/dnsutils" "github.com/IrineSistiana/mosdns/v5/pkg/upstream/bootstrap" "github.com/IrineSistiana/mosdns/v5/pkg/upstream/doh" + "github.com/IrineSistiana/mosdns/v5/pkg/upstream/doq" "github.com/IrineSistiana/mosdns/v5/pkg/upstream/transport" "github.com/miekg/dns" "github.com/quic-go/quic-go" @@ -181,10 +182,8 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { } return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil case "tls": - var tlsConfig *tls.Config - if opt.TLSConfig != nil { - tlsConfig = opt.TLSConfig.Clone() - } else { + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { tlsConfig = new(tls.Config) } if len(tlsConfig.ServerName) == 0 { @@ -287,6 +286,35 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) { Client: &http.Client{Transport: t}, AddOnCloser: addonCloser, }, nil + case "quic", "doq": + tlsConfig := opt.TLSConfig.Clone() + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + if len(tlsConfig.ServerName) == 0 { + tlsConfig.ServerName = tryRemovePort(addrURL.Host) + } + + quicConfig := &quic.Config{ + TokenStore: quic.NewLRUTokenStore(4, 8), + InitialStreamReceiveWindow: 4 * 1024, + MaxStreamReceiveWindow: 4 * 1024, + InitialConnectionReceiveWindow: 8 * 1024, + MaxConnectionReceiveWindow: 64 * 1024, + } + + lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})} + uc, err := lc.ListenPacket(context.Background(), "udp", "") + if err != nil { + return nil, fmt.Errorf("failed to init udp socket for quic") + } + + dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853) + u, err := doq.NewUpstream(dialAddr, uc.(*net.UDPConn), tlsConfig, quicConfig) + if err != nil { + return nil, fmt.Errorf("failed to setup doq upstream, %w", err) + } + return u, nil default: return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme) }