Skip to content

Commit

Permalink
proxy: imp code, algo
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Dec 15, 2023
1 parent 068bcd8 commit 2ad8274
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 45 deletions.
55 changes: 28 additions & 27 deletions proxy/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,26 @@ func (p *Proxy) exchangeUpstreams(
case dns.TypeA, dns.TypeAAAA:
return p.fastestAddr.ExchangeFastest(req, ups)
default:
// Fallthrough to the load-balancing mode.
// Go on to the load-balancing mode.
}
default:
// Load-balancing mode goes on.
// Go on to the load-balancing mode.
}

if len(ups) == 1 {
u = ups[0]
resp, _, err = exchange(u, req, p.time)
// TODO(e.burkov): p.updateRTT(u.Address(), elapsed)

return resp, u, err
}

w := sampleuv.NewWeighted(p.calcWeights(ups), p.randSrc)

var errs []error
for i, ok := w.Take(); ok; i, ok = w.Take() {
u = ups[i]

var elapsed uint64
var elapsed time.Duration
resp, elapsed, err = exchange(u, req, p.time)
if err == nil {
p.updateRTT(u.Address(), elapsed)
Expand All @@ -56,7 +56,7 @@ func (p *Proxy) exchangeUpstreams(

// TODO(e.burkov): Use the actual configured timeout or, perhaps, the
// actual measured elapsed time.
p.updateRTT(u.Address(), uint64(defaultTimeout/time.Millisecond))
p.updateRTT(u.Address(), defaultTimeout)
}

// TODO(e.burkov): Use [errors.Join].
Expand All @@ -66,53 +66,53 @@ func (p *Proxy) exchangeUpstreams(
// exchange returns the result of the DNS request exchange with the given
// upstream and the elapsed time in milliseconds. It uses the given clock to
// measure the request duration.
func exchange(u upstream.Upstream, req *dns.Msg, c Clock) (resp *dns.Msg, dur uint64, err error) {
func exchange(u upstream.Upstream, req *dns.Msg, c Clock) (resp *dns.Msg, dur time.Duration, err error) {
startTime := c.Now()

reply, err := u.Exchange(req)

// Don't use [time.Since] because it uses [time.Now].
elapsed := c.Now().Sub(startTime)
dur = c.Now().Sub(startTime)

addr := u.Address()
if err != nil {
log.Error(
"dnsproxy: upstream %s failed to exchange %s in %s: %s",
addr,
req.Question[0].String(),
elapsed,
dur,
err,
)
} else {
log.Debug(
"dnsproxy: upstream %s successfully finished exchange of %s; elapsed %s",
addr,
req.Question[0].String(),
elapsed,
dur,
)
}

return reply, uint64(elapsed.Milliseconds()), err
return reply, dur, err
}

// upstreamRTTStats is the statistics for a single upstream's round-trip time.
type upstreamRTTStats struct {
// rttSum is the sum of round-trip times for all requests to the upstream.
rttSum uint64
// avgRTT is the current average round-trip time in seconds. The float64 is
// the returning type of [time.Duration.Seconds] method and is used to avoid
// unnecessary divisions.
avgRTT float64

// reqNum is the number of requests to the upstream.
reqNum uint64
}

// avg returns the average round-trip time for the upstream. It returns 1 if
// there were no requests to the upstream or the sum of round-trip times is 0 to
// avoid division by zero.
func (s upstreamRTTStats) avg() (avg float64) {
if s.reqNum == 0 || s.rttSum == 0 {
return 1
// update returns updated stats after adding given RTT.
func (stats upstreamRTTStats) update(rtt time.Duration) (updated upstreamRTTStats) {
return upstreamRTTStats{
// See https://en.wikipedia.org/wiki/Moving_average#Cumulative_average.
avgRTT: (rtt.Seconds() + stats.avgRTT*float64(stats.reqNum)) / float64(stats.reqNum+1),
reqNum: stats.reqNum + 1,
}

return float64(s.rttSum) / float64(s.reqNum)
}

// calcWeights returns the slice of weights, each corresponding to the upstream
Expand All @@ -125,25 +125,26 @@ func (p *Proxy) calcWeights(ups []upstream.Upstream) (weights []float64) {

for _, u := range ups {
stat := p.upstreamRTTStats[u.Address()]
weights = append(weights, 1/stat.avg())
if stat.avgRTT == 0 {
// Use 1 as the default weight.
weights = append(weights, 1)
} else {
weights = append(weights, 1/stat.avgRTT)
}
}

return weights
}

// updateRTT updates the round-trip time in [upstreamRTTStats] for given
// address.
func (p *Proxy) updateRTT(address string, rtt uint64) {
func (p *Proxy) updateRTT(address string, rtt time.Duration) {
p.rttLock.Lock()
defer p.rttLock.Unlock()

if p.upstreamRTTStats == nil {
p.upstreamRTTStats = map[string]upstreamRTTStats{}
}

stat := p.upstreamRTTStats[address]
p.upstreamRTTStats[address] = upstreamRTTStats{
rttSum: stat.rttSum + rtt,
reqNum: stat.reqNum + 1,
}
p.upstreamRTTStats[address] = p.upstreamRTTStats[address].update(rtt)
}
31 changes: 15 additions & 16 deletions proxy/exchange_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ type measuredUpstream struct {
}

// type check
var _ upstream.Upstream = (*measuredUpstream)(nil)
var _ upstream.Upstream = measuredUpstream{}

// Exchange implements the [upstream.Upstream] interface for *countedUpstream.
// Exchange implements the [upstream.Upstream] interface for measuredUpstream.
func (u measuredUpstream) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
u.stats[u.Upstream]++

Expand All @@ -70,12 +70,11 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
requestsNum = 10_000
)

zeroTime := time.Unix(0, 0)
currentNow := zeroTime

// zeroingClock returns the value of currentNow and sets it back to
// zeroTime, so that all the calls since the second one return the same zero
// value until currentNow is modified elsewhere.
zeroTime := time.Unix(0, 0)
currentNow := zeroTime
zeroingClock := &fakeClock{
onNow: func() (now time.Time) {
now, currentNow = currentNow, zeroTime
Expand Down Expand Up @@ -154,18 +153,18 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
servers []upstream.Upstream
}{{
wantStat: map[upstream.Upstream]int64{
fastUps: 8917,
slowerUps: 911,
slowestUps: 172,
fastUps: 8906,
slowerUps: 920,
slowestUps: 174,
},
clock: zeroingClock,
name: "all_good",
servers: []upstream.Upstream{slowestUps, slowerUps, fastUps},
}, {
wantStat: map[upstream.Upstream]int64{
fastUps: 9081,
slowerUps: 919,
err1Ups: 7,
fastUps: 9074,
slowerUps: 926,
err1Ups: 8,
},
clock: zeroingClock,
name: "one_bad",
Expand All @@ -180,18 +179,18 @@ func TestProxy_Exchange_loadBalance(t *testing.T) {
servers: []upstream.Upstream{err2Ups, err1Ups},
}, {
wantStat: map[upstream.Upstream]int64{
fastUps: 7803,
slowerUps: 833,
fastUps: 7806,
slowerUps: 830,
fastestUps: 1365,
},
clock: zeroingClock,
name: "error_once",
servers: []upstream.Upstream{fastUps, slowerUps, fastestUps},
}, {
wantStat: map[upstream.Upstream]int64{
each200: 5316,
each100: 3090,
each50: 1683,
each200: 5308,
each100: 3099,
each50: 1682,
},
clock: constClock,
name: "error_each_nth",
Expand Down
5 changes: 3 additions & 2 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,9 @@ type Proxy struct {
// Upstream
// --

// upstreamRTTStats is a map of upstream addresses and their rtt. Used to
// sort upstreams by their latency.
// upstreamRTTStats maps the upstream address to its round-trip time
// statistics. It's holds the statistics for all upstreams to perform a
// weighted random selection when using the load balancing mode.
upstreamRTTStats map[string]upstreamRTTStats

// rttLock protects upstreamRTTStats.
Expand Down
13 changes: 13 additions & 0 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
Expand Down Expand Up @@ -765,6 +766,18 @@ func (u *fakeUpstream) Address() (addr string) { return u.onAddress() }
// Close implements upstream.Upstream interface for *funcUpstream.
func (u *fakeUpstream) Close() (err error) { return u.onClose() }

// type check
var _ fmt.GoStringer = (*fakeUpstream)(nil)

// GoString implements the [fmt.GoStringer] interface for measuredUpstream.
func (u *fakeUpstream) GoString() (s string) {
if u.onAddress != nil {
return u.onAddress()
}

return fmt.Sprintf("%#v", u)
}

func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) {
dnsProxy := createTestProxy(t, nil)
require.NoError(t, dnsProxy.Start())
Expand Down

0 comments on commit 2ad8274

Please sign in to comment.