From 8ce5c6bedd439cb0f36c8b6ed17ff885eec919ee Mon Sep 17 00:00:00 2001 From: zyxkad Date: Wed, 21 Feb 2024 20:37:26 -0700 Subject: [PATCH] maybe fixed read limit --- limited_conn.go | 171 +++++++++++++++++++++++++++++++----------------- 1 file changed, 111 insertions(+), 60 deletions(-) diff --git a/limited_conn.go b/limited_conn.go index b4b3e68b..4d919398 100644 --- a/limited_conn.go +++ b/limited_conn.go @@ -36,26 +36,46 @@ type RateController struct { writeRate int // bytes per second minWriteRate int - closed atomic.Bool - closeCh chan struct{} - mux sync.Mutex - lastWrite time.Time - wroteCount int - lastRead time.Time - readCount int + closed atomic.Bool + closeCh chan struct{} + rmux, wmux sync.Mutex + lastRead time.Time + preReadCount int + readCount int + lastWrite time.Time + wroteCount int } func NewRateController(maxConn int, readRate, writeRate int) *RateController { return &RateController{ Semaphore: NewSemaphore(maxConn), - writeRate: writeRate, - minWriteRate: 256, + closeCh: make(chan struct{}, 0), readRate: readRate, minReadRate: 256, - closeCh: make(chan struct{}, 0), + writeRate: writeRate, + minWriteRate: 256, } } +func (l *RateController) ReadRate() int { + return l.readRate +} + +func (l *RateController) SetReadRate(rate int) { + l.readRate = rate + if rate > 0 && rate < l.minReadRate { + l.minReadRate = rate + } +} + +func (l *RateController) MinReadRate() int { + return l.minReadRate +} + +func (l *RateController) SetMinReadRate(rate int) { + l.minReadRate = rate +} + func (l *RateController) WriteRate() int { return l.writeRate } @@ -75,23 +95,82 @@ func (l *RateController) SetMinWriteRate(rate int) { l.minWriteRate = rate } -func (l *RateController) ReadRate() int { - return l.readRate -} +func (l *RateController) preRead(n int) int { + if n <= 0 { + return n + } + readRate := l.readRate + if readRate <= 0 { + return n + } -func (l *RateController) SetReadRate(rate int) { - l.readRate = rate - if rate > 0 && rate < l.minReadRate { - l.minReadRate = rate + l.rmux.Lock() + defer l.rmux.Unlock() + + avgRate := readRate + if ln := l.Len(); ln > 0 { + avgRate /= ln + } + if n > avgRate { + n = avgRate } -} -func (l *RateController) MinReadRate() int { - return l.minReadRate + now := time.Now() + diff := time.Second - now.Sub(l.lastRead) + if diff <= 0 { + l.lastRead = now + l.preReadCount = 0 + l.readCount = 0 + } + if l.preReadCount >= readRate { + m := l.minReadRate + if m > n { + m = n + } + return m + } + l.preReadCount += n + return n } -func (l *RateController) SetMinReadRate(rate int) { - l.minReadRate = rate +func (l *RateController) afterRead(n int, less int) time.Duration { + readRate := l.readRate + if readRate <= 0 { + return 0 + } + if n < 0 { + n = 0 + } + + l.rmux.Lock() + defer l.rmux.Unlock() + + avgRate := readRate + if ln := l.Len(); ln > 0 { + avgRate /= ln + } + if n > avgRate { + n = avgRate + } + + now := time.Now() + diff := time.Second - now.Sub(l.lastRead) + if diff <= 0 { + l.lastRead = now + l.preReadCount = 0 + l.readCount = n + } else { + l.preReadCount -= less + l.readCount += n + } + if n >= avgRate { + // TODO: replace the magic number 3 + return time.Second / 3 + } + if l.readCount >= readRate { + return diff + } + return 0 } func (l *RateController) preWrite(n int) (int, time.Duration) { @@ -103,8 +182,8 @@ func (l *RateController) preWrite(n int) (int, time.Duration) { return n, 0 } - l.mux.Lock() - defer l.mux.Unlock() + l.wmux.Lock() + defer l.wmux.Unlock() now := time.Now() diff := time.Second - now.Sub(l.lastWrite) @@ -126,38 +205,6 @@ func (l *RateController) preWrite(n int) (int, time.Duration) { return n, 0 } -func (l *RateController) preRead(n int) (int, time.Duration) { - if n <= 0 { - return n, 0 - } - readRate := l.readRate - if readRate <= 0 { - return n, 0 - } - - l.mux.Lock() - defer l.mux.Unlock() - - now := time.Now() - diff := time.Second - now.Sub(l.lastRead) - if diff <= 0 { - l.lastRead = now - l.readCount = 0 - } else if l.readCount >= readRate { - m := l.minReadRate - if m > n { - m = n - } - return m, diff - } - if n > readRate { - l.readCount += readRate - return readRate, time.Second - } - l.readCount += n - return n, 0 -} - // Close will interrupted the incoming operations // it will not close or interrupt the proxied connections and its operations func (l *RateController) Close() error { @@ -245,13 +292,15 @@ func (r *LimitedReader) Read(buf []byte) (n int, err error) { time.Sleep(dur) } } - m, dur := r.controller.preRead(len(buf)) + m := r.controller.preRead(len(buf)) + n, err = r.Reader.Read(buf[:m]) + dur := r.controller.afterRead(n, m-n) if dur > 0 { r.readAfter = time.Now().Add(dur) } else { r.readAfter = time.Time{} } - return r.Reader.Read(buf[:m]) + return } func (r *LimitedReader) Close() error { @@ -346,13 +395,15 @@ func (c *LimitedConn) Read(buf []byte) (n int, err error) { time.Sleep(dur) } } - m, dur := c.controller.preRead(len(buf)) + m := c.controller.preRead(len(buf)) + n, err = c.Conn.Read(buf[:m]) + dur := c.controller.afterRead(n, m-n) if dur > 0 { c.readAfter = time.Now().Add(dur) } else { c.readAfter = time.Time{} } - return c.Conn.Read(buf[:m]) + return } func (c *LimitedConn) Write(buf []byte) (n int, err error) {