Skip to content

Commit

Permalink
add doq upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
IrineSistiana committed Sep 17, 2023
1 parent f02d377 commit ae546cd
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 5 deletions.
30 changes: 30 additions & 0 deletions pkg/pool/msg_buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
package pool

import (
"encoding/binary"
"fmt"

"github.com/miekg/dns"
)

Expand All @@ -40,3 +43,30 @@ func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
}
return wire, buf, nil
}

// PackBuffer packs the dns msg m to wire format, with to bytes length header.
// Callers should release the buf by calling ReleaseBuf.
func PackTCPBuffer(m *dns.Msg) (buf *[]byte, err error) {
b := GetBuf(packBufSize)
wire, err := m.PackBuffer((*b)[2:])
if err != nil {
ReleaseBuf(b)
return nil, err
}

l := len(wire)
if l > dns.MaxMsgSize {
ReleaseBuf(b)
return nil, fmt.Errorf("dns payload size %d is too large", l)
}

if &((*b)[0]) != &wire[0] { // reallocated
ReleaseBuf(b)
b = GetBuf(l + 2)
binary.BigEndian.PutUint16((*b)[:2], uint16(l))
copy((*b)[2:], wire)
return b, nil
}
*b = (*b)[:2+l]
return b, nil
}
3 changes: 2 additions & 1 deletion pkg/pool/msg_buf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
package pool

import (
"github.com/miekg/dns"
"testing"

"github.com/miekg/dns"
)

func TestPackBuffer_No_Allocation(t *testing.T) {
Expand Down
283 changes: 283 additions & 0 deletions pkg/upstream/doq/upstream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
/*
* Copyright (C) 2020-2022, IrineSistiana
*
* This file is part of mosdns.
*
* mosdns is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* mosdns is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/

package doq

import (
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"sync"
"sync/atomic"
"time"

"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)

const (
defaultDoQTimeout = time.Second * 5
dialTimeout = time.Second * 3
connectionLostThreshold = time.Second * 5
)

var (
doqAlpn = []string{"doq"}
)

type Upstream struct {
t *quic.Transport
addr string
tlsConfig *tls.Config
quicConfig *quic.Config

cm sync.Mutex
lc *lazyConn
}

// tlsConfig cannot be nil, it should have Servername or InsecureSkipVerify.
func NewUpstream(addr string, lc *net.UDPConn, tlsConfig *tls.Config, quicConfig *quic.Config) (*Upstream, error) {
srk, err := initSrk()
if err != nil {
return nil, err
}
if tlsConfig == nil {
return nil, errors.New("nil tls config")
}

tlsConfig = tlsConfig.Clone()
tlsConfig.NextProtos = doqAlpn

return &Upstream{
t: &quic.Transport{
Conn: lc,
StatelessResetKey: (*quic.StatelessResetKey)(srk),
},
addr: addr,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
}, nil
}

func initSrk() (*[32]byte, error) {
var b [32]byte
_, err := rand.Read(b[:])
if err != nil {
return nil, err
}
return &b, nil
}

func (u *Upstream) Close() error {
return u.t.Close()
}

func (u *Upstream) newStream(ctx context.Context) (quic.Stream, *lazyConn, error) {
var lc *lazyConn
u.cm.Lock()
if u.lc == nil { // First dial.
u.lc = u.asyncDialConn()
} else {
select {
case <-u.lc.dialFinished:
if u.lc.err != nil { // previous dial failed
u.lc = u.asyncDialConn()
} else {
err := u.lc.c.Context().Err()
if err != nil { // previous connection is dead or closed
u.lc = u.asyncDialConn()
}
// previous connection looks good
}
default:
// still dialing
}
}
lc = u.lc
u.cm.Unlock()

select {
case <-lc.dialFinished:
if lc.c == nil {
return nil, nil, lc.err
}
s, err := lc.c.OpenStream()
if err != nil {
lc.queryFailedAndCloseIfConnLost(time.Now())
return nil, nil, err
}
return s, lc, nil
case <-ctx.Done():
return nil, nil, ctx.Err()
}
}

type lazyConn struct {
cancelDial func()

m sync.Mutex
closed bool
dialFinished chan struct{}
c quic.Connection
err error

latestRecvMs atomic.Int64
}

func (u *Upstream) asyncDialConn() *lazyConn {
ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
lc := &lazyConn{
cancelDial: cancel,
dialFinished: make(chan struct{}),
}

go func() {
defer cancel()

ua, err := net.ResolveUDPAddr("udp", u.addr) // TODO: Support bootstrap.
if err != nil {
lc.err = err
return
}

c, err := u.t.Dial(ctx, ua, u.tlsConfig, u.quicConfig)

var closeC bool
lc.m.Lock()
if lc.closed {
closeC = true // lc was closed, nothing to do
} else {
if err != nil {
lc.err = err
} else {
lc.c = c
lc.saveLatestRespRecvTime(time.Now())
}
close(lc.dialFinished)
}
lc.m.Unlock()

if closeC && c != nil { // lc was closed while dialing.
c.CloseWithError(0, "")
}
}()
return lc
}

func (lc *lazyConn) close() {
lc.m.Lock()
defer lc.m.Unlock()

if lc.closed {
return
}
lc.closed = true

select {
case <-lc.dialFinished:
if lc.c != nil {
lc.c.CloseWithError(0, "")
}
default:
lc.cancelDial()
lc.err = net.ErrClosed
close(lc.dialFinished)
}
}

func (lc *lazyConn) saveLatestRespRecvTime(t time.Time) {
lc.latestRecvMs.Store(t.UnixMilli())
}

func (lc *lazyConn) latestRespRecvTime() time.Time {
return time.UnixMilli(lc.latestRecvMs.Load())
}

func (lc *lazyConn) queryFailedAndCloseIfConnLost(t time.Time) {
// No msg received for a quit long time. This connection may be lost.
if time.Since(lc.latestRespRecvTime()) > connectionLostThreshold {
lc.close()
}
}

func (u *Upstream) ExchangeContext(ctx context.Context, q *dns.Msg) (*dns.Msg, error) {
payload, err := pool.PackTCPBuffer(q)
if err != nil {
return nil, err
}
// 4.2.1. DNS Message IDs
// When sending queries over a QUIC connection, the DNS Message ID MUST
// be set to 0. The stream mapping for DoQ allows for unambiguous
// correlation of queries and responses, so the Message ID field is not
// required.
(*payload)[2], (*payload)[3] = 0, 0

s, lc, err := u.newStream(ctx)
if err != nil {
return nil, err
}

type res struct {
resp *dns.Msg
err error
}
rc := make(chan res, 1)
go func() {
defer pool.ReleaseBuf(payload)
defer s.Close()
r, err := u.exchange(s, *payload)
rc <- res{resp: r, err: err}
if err != nil {
lc.queryFailedAndCloseIfConnLost(time.Now())
} else {
lc.saveLatestRespRecvTime(time.Now())
}
}()

select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case r := <-rc:
resp := r.resp
err := r.err
if resp != nil {
resp.Id = q.Id
}
return resp, err
}
}

func (u *Upstream) exchange(s quic.Stream, payload []byte) (*dns.Msg, error) {
s.SetDeadline(time.Now().Add(defaultDoQTimeout))

_, err := s.Write(payload)
if err != nil {
return nil, err
}

resp, _, err := dnsutils.ReadMsgFromTCP(s)
if err != nil {
return nil, err
}
return resp, nil
}
36 changes: 32 additions & 4 deletions pkg/upstream/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/IrineSistiana/mosdns/v5/pkg/dnsutils"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/bootstrap"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/doh"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/doq"
"github.com/IrineSistiana/mosdns/v5/pkg/upstream/transport"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
Expand Down Expand Up @@ -181,10 +182,8 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
}
return transport.NewReuseConnTransport(transport.ReuseConnOpts{IOOpts: to}), nil
case "tls":
var tlsConfig *tls.Config
if opt.TLSConfig != nil {
tlsConfig = opt.TLSConfig.Clone()
} else {
tlsConfig := opt.TLSConfig.Clone()
if tlsConfig == nil {
tlsConfig = new(tls.Config)
}
if len(tlsConfig.ServerName) == 0 {
Expand Down Expand Up @@ -287,6 +286,35 @@ func NewUpstream(addr string, opt Opt) (Upstream, error) {
Client: &http.Client{Transport: t},
AddOnCloser: addonCloser,
}, nil
case "quic", "doq":
tlsConfig := opt.TLSConfig.Clone()
if tlsConfig == nil {
tlsConfig = new(tls.Config)
}
if len(tlsConfig.ServerName) == 0 {
tlsConfig.ServerName = tryRemovePort(addrURL.Host)
}

quicConfig := &quic.Config{
TokenStore: quic.NewLRUTokenStore(4, 8),
InitialStreamReceiveWindow: 4 * 1024,
MaxStreamReceiveWindow: 4 * 1024,
InitialConnectionReceiveWindow: 8 * 1024,
MaxConnectionReceiveWindow: 64 * 1024,
}

lc := net.ListenConfig{Control: getSocketControlFunc(socketOpts{so_mark: opt.SoMark, bind_to_device: opt.BindToDevice})}
uc, err := lc.ListenPacket(context.Background(), "udp", "")
if err != nil {
return nil, fmt.Errorf("failed to init udp socket for quic")
}

dialAddr := getDialAddrWithPort(addrURL.Host, opt.DialAddr, 853)
u, err := doq.NewUpstream(dialAddr, uc.(*net.UDPConn), tlsConfig, quicConfig)
if err != nil {
return nil, fmt.Errorf("failed to setup doq upstream, %w", err)
}
return u, nil
default:
return nil, fmt.Errorf("unsupported protocol [%s]", addrURL.Scheme)
}
Expand Down

0 comments on commit ae546cd

Please sign in to comment.