Skip to content

Commit

Permalink
maybe fixed read limit
Browse files Browse the repository at this point in the history
  • Loading branch information
zyxkad committed Feb 22, 2024
1 parent 37ca115 commit 8ce5c6b
Showing 1 changed file with 111 additions and 60 deletions.
171 changes: 111 additions & 60 deletions limited_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,46 @@ type RateController struct {
writeRate int // bytes per second
minWriteRate int

closed atomic.Bool
closeCh chan struct{}
mux sync.Mutex
lastWrite time.Time
wroteCount int
lastRead time.Time
readCount int
closed atomic.Bool
closeCh chan struct{}
rmux, wmux sync.Mutex
lastRead time.Time
preReadCount int
readCount int
lastWrite time.Time
wroteCount int
}

func NewRateController(maxConn int, readRate, writeRate int) *RateController {
return &RateController{
Semaphore: NewSemaphore(maxConn),
writeRate: writeRate,
minWriteRate: 256,
closeCh: make(chan struct{}, 0),
readRate: readRate,
minReadRate: 256,
closeCh: make(chan struct{}, 0),
writeRate: writeRate,
minWriteRate: 256,
}
}

func (l *RateController) ReadRate() int {
return l.readRate
}

func (l *RateController) SetReadRate(rate int) {
l.readRate = rate
if rate > 0 && rate < l.minReadRate {
l.minReadRate = rate
}
}

func (l *RateController) MinReadRate() int {
return l.minReadRate
}

func (l *RateController) SetMinReadRate(rate int) {
l.minReadRate = rate
}

func (l *RateController) WriteRate() int {
return l.writeRate
}
Expand All @@ -75,23 +95,82 @@ func (l *RateController) SetMinWriteRate(rate int) {
l.minWriteRate = rate
}

func (l *RateController) ReadRate() int {
return l.readRate
}
func (l *RateController) preRead(n int) int {
if n <= 0 {
return n
}
readRate := l.readRate
if readRate <= 0 {
return n
}

func (l *RateController) SetReadRate(rate int) {
l.readRate = rate
if rate > 0 && rate < l.minReadRate {
l.minReadRate = rate
l.rmux.Lock()
defer l.rmux.Unlock()

avgRate := readRate
if ln := l.Len(); ln > 0 {
avgRate /= ln
}
if n > avgRate {
n = avgRate
}
}

func (l *RateController) MinReadRate() int {
return l.minReadRate
now := time.Now()
diff := time.Second - now.Sub(l.lastRead)
if diff <= 0 {
l.lastRead = now
l.preReadCount = 0
l.readCount = 0
}
if l.preReadCount >= readRate {
m := l.minReadRate
if m > n {
m = n
}
return m
}
l.preReadCount += n
return n
}

func (l *RateController) SetMinReadRate(rate int) {
l.minReadRate = rate
func (l *RateController) afterRead(n int, less int) time.Duration {
readRate := l.readRate
if readRate <= 0 {
return 0
}
if n < 0 {
n = 0
}

l.rmux.Lock()
defer l.rmux.Unlock()

avgRate := readRate
if ln := l.Len(); ln > 0 {
avgRate /= ln
}
if n > avgRate {
n = avgRate
}

now := time.Now()
diff := time.Second - now.Sub(l.lastRead)
if diff <= 0 {
l.lastRead = now
l.preReadCount = 0
l.readCount = n
} else {
l.preReadCount -= less
l.readCount += n
}
if n >= avgRate {
// TODO: replace the magic number 3
return time.Second / 3
}
if l.readCount >= readRate {
return diff
}
return 0
}

func (l *RateController) preWrite(n int) (int, time.Duration) {
Expand All @@ -103,8 +182,8 @@ func (l *RateController) preWrite(n int) (int, time.Duration) {
return n, 0
}

l.mux.Lock()
defer l.mux.Unlock()
l.wmux.Lock()
defer l.wmux.Unlock()

now := time.Now()
diff := time.Second - now.Sub(l.lastWrite)
Expand All @@ -126,38 +205,6 @@ func (l *RateController) preWrite(n int) (int, time.Duration) {
return n, 0
}

func (l *RateController) preRead(n int) (int, time.Duration) {
if n <= 0 {
return n, 0
}
readRate := l.readRate
if readRate <= 0 {
return n, 0
}

l.mux.Lock()
defer l.mux.Unlock()

now := time.Now()
diff := time.Second - now.Sub(l.lastRead)
if diff <= 0 {
l.lastRead = now
l.readCount = 0
} else if l.readCount >= readRate {
m := l.minReadRate
if m > n {
m = n
}
return m, diff
}
if n > readRate {
l.readCount += readRate
return readRate, time.Second
}
l.readCount += n
return n, 0
}

// Close will interrupted the incoming operations
// it will not close or interrupt the proxied connections and its operations
func (l *RateController) Close() error {
Expand Down Expand Up @@ -245,13 +292,15 @@ func (r *LimitedReader) Read(buf []byte) (n int, err error) {
time.Sleep(dur)
}
}
m, dur := r.controller.preRead(len(buf))
m := r.controller.preRead(len(buf))
n, err = r.Reader.Read(buf[:m])
dur := r.controller.afterRead(n, m-n)
if dur > 0 {
r.readAfter = time.Now().Add(dur)
} else {
r.readAfter = time.Time{}
}
return r.Reader.Read(buf[:m])
return
}

func (r *LimitedReader) Close() error {
Expand Down Expand Up @@ -346,13 +395,15 @@ func (c *LimitedConn) Read(buf []byte) (n int, err error) {
time.Sleep(dur)
}
}
m, dur := c.controller.preRead(len(buf))
m := c.controller.preRead(len(buf))
n, err = c.Conn.Read(buf[:m])
dur := c.controller.afterRead(n, m-n)
if dur > 0 {
c.readAfter = time.Now().Add(dur)
} else {
c.readAfter = time.Time{}
}
return c.Conn.Read(buf[:m])
return
}

func (c *LimitedConn) Write(buf []byte) (n int, err error) {
Expand Down

0 comments on commit 8ce5c6b

Please sign in to comment.