fix: data races (#224)
* fix: data race in value ping error

* fix(cli): data race in value packetloss

* fix(api): data race in value repeat byte

* fix(api): data race in value totalDataVolume

* fix(api): data race in value running

* fix lint

* fix(api): data race in value running

* revert context
r3inbowari authored Jul 13, 2024
1 parent 320467f commit 332a0d7
Showing 4 changed files with 50 additions and 25 deletions.
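
All four fixes address the same class of bug: one goroutine writes a flag, counter, or error value while other goroutines read it with no synchronization. A minimal sketch of the pattern being fixed (illustrative code, not from this repository):

package main

import "time"

func main() {
	running := true
	go func() {
		time.Sleep(10 * time.Millisecond)
		running = false // unsynchronized write...
	}()
	for running { // ...races with this read; `go run -race` reports it
		time.Sleep(time.Millisecond)
	}
}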
5 changes: 5 additions & 0 deletions speedtest.go
@@ -11,6 +11,7 @@ import (
 	"os"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -149,9 +150,12 @@ func main() {
 		SourceInterface: *source,
 	})
 
+	blocker := sync.WaitGroup{}
 	packetLossAnalyzerCtx, packetLossAnalyzerCancel := context.WithTimeout(context.Background(), time.Second*40)
 	taskManager.Run("Packet Loss Analyzer", func(task *Task) {
+		blocker.Add(1)
 		go func() {
+			defer blocker.Done()
			err = analyzer.RunWithContext(packetLossAnalyzerCtx, server.Host, func(packetLoss *transport.PLoss) {
				server.PacketLoss = *packetLoss
			})
@@ -211,6 +215,7 @@ func main() {
 			time.Sleep(time.Second * 30)
 		}
 		packetLossAnalyzerCancel()
+		blocker.Wait()
 		if !*jsonOutput {
 			taskManager.Println(server.PacketLoss.String())
 		}
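
The CLI fix above adds a sync.WaitGroup (blocker) so that after packetLossAnalyzerCancel the main goroutine waits for the analyzer goroutine to finish writing server.PacketLoss before reading it. A minimal sketch of the pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

func main() {
	var result int
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		result = 42 // this write completes before Wait returns
	}()
	wg.Wait()           // happens-before edge between the goroutines
	fmt.Println(result) // safe read, clean under -race
}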
59 changes: 39 additions & 20 deletions speedtest/data_manager.go
@@ -88,7 +88,8 @@ type DataManager struct {
 	rateCaptureFrequency time.Duration
 	nThread              int
 
-	running bool
+	running   bool
+	runningRW sync.RWMutex
 
 	download *TestDirection
 	upload   *TestDirection
@@ -114,11 +115,13 @@ func (dm *DataManager) NewDataDirection(testType int) *TestDirection {
 }
 
 func NewDataManager() *DataManager {
+	r := bytes.Repeat([]byte{0xAA}, readChunkSize) // uniformly distributed sequence of bits
 	ret := &DataManager{
 		nThread:              runtime.NumCPU(),
 		captureTime:          time.Second * 15,
 		rateCaptureFrequency: time.Millisecond * 50,
 		Snapshot:             &Snapshot{},
+		repeatByte:           &r,
 	}
 	ret.download = ret.NewDataDirection(typeDownload)
 	ret.upload = ret.NewDataDirection(typeUpload)
@@ -169,6 +172,14 @@ func (dm *DataManager) RegisterDownloadHandler(fn func()) *TestDirection {
 	return dm.download
 }
 
+func (td *TestDirection) GetTotalDataVolume() int64 {
+	return atomic.LoadInt64(&td.totalDataVolume)
+}
+
+func (td *TestDirection) AddTotalDataVolume(delta int64) int64 {
+	return atomic.AddInt64(&td.totalDataVolume, delta)
+}
+
 func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) {
 	if len(td.fns) == 0 {
 		panic("empty task stack")
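
These two accessors route every read and write of totalDataVolume through sync/atomic. The writes were already atomic (atomic.AddInt64), but plain reads such as the old td.totalDataVolume still raced with them; mixing atomic and non-atomic access to the same word is not safe. A minimal sketch of the accessor pattern:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type counter struct{ n int64 }

func (c *counter) Add(delta int64) int64 { return atomic.AddInt64(&c.n, delta) }
func (c *counter) Get() int64            { return atomic.LoadInt64(&c.n) }

func main() {
	var c counter
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.Add(1)
		}()
	}
	wg.Wait()
	fmt.Println(c.Get()) // always 8, and clean under -race
}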
@@ -200,7 +211,9 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) {
 			once.Do(func() {
 				stopCapture <- true
 				close(stopCapture)
+				td.manager.runningRW.Lock()
 				td.manager.running = false
+				td.manager.runningRW.Unlock()
 				cancel()
 				dbg.Println("FuncGroup: Stop")
 			})
@@ -212,7 +225,10 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) {
 		go func() {
 			defer wg.Done()
 			for {
-				if !td.manager.running {
+				td.manager.runningRW.RLock()
+				running := td.manager.running
+				td.manager.runningRW.RUnlock()
+				if !running {
 					return
 				}
 				td.fns[mainRequestHandlerIndex]()
@@ -232,7 +248,10 @@ func (td *TestDirection) Start(cancel context.CancelFunc, mainRequestHandlerIndex int) {
 		go func() {
 			defer wg.Done()
 			for {
-				if !td.manager.running {
+				td.manager.runningRW.RLock()
+				running := td.manager.running
+				td.manager.runningRW.RUnlock()
+				if !running {
 					return
 				}
 				td.fns[t]()
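
Both worker loops now read the running flag under RLock while the stop path takes the exclusive Lock, so many readers can poll concurrently and the single writer still synchronizes with all of them. A minimal sketch of the guarded-flag pattern (since Go 1.19, an atomic.Bool would be a lighter alternative):

package main

import (
	"fmt"
	"sync"
	"time"
)

type manager struct {
	mu      sync.RWMutex
	running bool
}

func (m *manager) isRunning() bool {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.running
}

func (m *manager) stop() {
	m.mu.Lock()
	m.running = false
	m.mu.Unlock()
}

func main() {
	m := &manager{running: true}
	go func() {
		time.Sleep(10 * time.Millisecond)
		m.stop() // the single writer takes the exclusive lock
	}()
	for m.isRunning() { // many readers share the RLock
	}
	fmt.Println("stopped")
}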
@@ -255,14 +274,14 @@ func (td *TestDirection) rateCapture() chan bool {
 		for {
 			select {
 			case <-t.C:
-				newTotalDataVolume := td.totalDataVolume
+				newTotalDataVolume := td.GetTotalDataVolume()
 				deltaDataVolume := newTotalDataVolume - prevTotalDataVolume
 				prevTotalDataVolume = newTotalDataVolume
 				if deltaDataVolume != 0 {
 					td.RateSequence = append(td.RateSequence, deltaDataVolume)
 				}
 				// anyway we update the measuring instrument
-				globalAvg := (float64(td.totalDataVolume)) / float64(time.Since(sTime).Milliseconds()) * 1000
+				globalAvg := (float64(td.GetTotalDataVolume())) / float64(time.Since(sTime).Milliseconds()) * 1000
 				if td.welford.Update(globalAvg, float64(deltaDataVolume)) {
 					go td.closeFunc()
 				}
@@ -290,19 +309,19 @@ func (dm *DataManager) NewChunk() Chunk {
 }
 
 func (dm *DataManager) AddTotalDownload(value int64) {
-	atomic.AddInt64(&dm.download.totalDataVolume, value)
+	dm.download.AddTotalDataVolume(value)
 }
 
 func (dm *DataManager) AddTotalUpload(value int64) {
-	atomic.AddInt64(&dm.upload.totalDataVolume, value)
+	dm.upload.AddTotalDataVolume(value)
 }
 
 func (dm *DataManager) GetTotalDownload() int64 {
-	return dm.download.totalDataVolume
+	return dm.download.GetTotalDataVolume()
 }
 
 func (dm *DataManager) GetTotalUpload() int64 {
-	return dm.upload.totalDataVolume
+	return dm.upload.GetTotalDataVolume()
 }
 
 func (dm *DataManager) SetRateCaptureFrequency(duration time.Duration) Manager {
@@ -337,7 +356,7 @@ func (dm *DataManager) Reset() {
 
 func (dm *DataManager) GetAvgDownloadRate() float64 {
 	unit := float64(dm.captureTime / time.Millisecond)
-	return float64(dm.download.totalDataVolume*8/1000) / unit
+	return float64(dm.download.GetTotalDataVolume()*8/1000) / unit
 }
 
 func (dm *DataManager) GetEWMADownloadRate() float64 {
@@ -349,7 +368,7 @@ func (dm *DataManager) GetEWMADownloadRate() float64 {
 
 func (dm *DataManager) GetAvgUploadRate() float64 {
 	unit := float64(dm.captureTime / time.Millisecond)
-	return float64(dm.upload.totalDataVolume*8/1000) / unit
+	return float64(dm.upload.GetTotalDataVolume()*8/1000) / unit
 }
 
 func (dm *DataManager) GetEWMAUploadRate() float64 {
@@ -405,14 +424,17 @@ func (dc *DataChunk) DownloadHandler(r io.Reader) error {
 	defer blackHolePool.Put(bufP)
 	readSize := 0
 	for {
-		if !dc.manager.running {
+		dc.manager.runningRW.RLock()
+		running := dc.manager.running
+		dc.manager.runningRW.RUnlock()
+		if !running {
 			return nil
 		}
 		readSize, dc.err = r.Read(*bufP)
 		rs := int64(readSize)
 
 		dc.remainOrDiscardSize += rs
-		atomic.AddInt64(&dc.manager.download.totalDataVolume, rs)
+		dc.manager.download.AddTotalDataVolume(rs)
 		if dc.err != nil {
 			if dc.err == io.EOF {
 				return nil
@@ -434,12 +456,6 @@ func (dc *DataChunk) UploadHandler(size int64) Chunk {
 	dc.ContentLength = size
 	dc.remainOrDiscardSize = size
 	dc.dateType = typeUpload
-
-	if dc.manager.repeatByte == nil {
-		r := bytes.Repeat([]byte{0xAA}, readChunkSize) // uniformly distributed sequence of bits
-		dc.manager.repeatByte = &r
-	}
-
 	dc.startTime = time.Now()
 	return dc
 }
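
The deleted block was a check-then-act race: two upload goroutines could both observe repeatByte == nil and write it concurrently. Initializing the buffer once in NewDataManager (see the earlier hunk) removes the race entirely; sync.Once is the usual alternative when construction-time initialization is not an option. A sketch of that alternative, with hypothetical names:

package main

import (
	"bytes"
	"sync"
)

const readChunkSize = 1024

type manager struct {
	once       sync.Once
	repeatByte *[]byte
}

// repeat initializes the shared buffer exactly once, even when called
// from many goroutines, replacing the racy `if m.repeatByte == nil` check.
func (m *manager) repeat() *[]byte {
	m.once.Do(func() {
		r := bytes.Repeat([]byte{0xAA}, readChunkSize)
		m.repeatByte = &r
	})
	return m.repeatByte
}

func main() {
	m := &manager{}
	_ = m.repeat()
}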
@@ -453,7 +469,10 @@ func (dc *DataChunk) WriteTo(w io.Writer) (written int64, err error) {
 	nw := 0
 	nr := readChunkSize
 	for {
-		if !dc.manager.running || dc.remainOrDiscardSize <= 0 {
+		dc.manager.runningRW.RLock()
+		running := dc.manager.running
+		dc.manager.runningRW.RUnlock()
+		if !running || dc.remainOrDiscardSize <= 0 {
 			dc.endTime = time.Now()
 			return written, io.EOF
 		}
2 changes: 1 addition & 1 deletion speedtest/data_manager_test.go
@@ -35,7 +35,7 @@ func TestDataManager_AddTotalDownload(t *testing.T) {
 		}()
 	}
 	wg.Wait()
-	if dmp.download.totalDataVolume != 43521000000 {
+	if dmp.download.GetTotalDataVolume() != 43521000000 {
 		t.Fatal()
 	}
 }
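
Because this test already hammers the counter from many goroutines, it doubles as a regression test for the races fixed here when run with the detector enabled (standard Go toolchain assumed):

go test -race ./speedtest/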
9 changes: 5 additions & 4 deletions speedtest/server.go
@@ -291,14 +291,15 @@ func (s *Speedtest) FetchServerListContext(ctx context.Context) (Servers, error)
 		wg.Add(1)
 		go func(gs *Server) {
 			var latency []int64
+			var errPing error
 			if s.config.PingMode == TCP {
-				latency, err = gs.TCPPing(pCtx, 1, time.Millisecond, nil)
+				latency, errPing = gs.TCPPing(pCtx, 1, time.Millisecond, nil)
 			} else if s.config.PingMode == ICMP {
-				latency, err = gs.ICMPPing(pCtx, 4*time.Second, 1, time.Millisecond, nil)
+				latency, errPing = gs.ICMPPing(pCtx, 4*time.Second, 1, time.Millisecond, nil)
 			} else {
-				latency, err = gs.HTTPPing(pCtx, 1, time.Millisecond, nil)
+				latency, errPing = gs.HTTPPing(pCtx, 1, time.Millisecond, nil)
 			}
-			if err != nil || len(latency) < 1 {
+			if errPing != nil || len(latency) < 1 {
 				gs.Latency = PingTimeout
 			} else {
 				gs.Latency = time.Duration(latency[0]) * time.Nanosecond
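
Here each ping goroutine gets its own errPing instead of assigning to the enclosing function's shared err, which all of the spawned goroutines were writing concurrently (and which also let one server's failure leak into another's result). A minimal sketch of the bug's shape and the fix, with illustrative names:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// ping stands in for TCPPing/ICMPPing/HTTPPing; even IDs simulate a timeout.
func ping(id int) (int64, error) {
	if id%2 == 0 {
		return 0, errors.New("timeout")
	}
	return int64(id), nil
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// Assigning to a shared `err` from the enclosing scope would
			// race here; a per-goroutine variable keeps results independent.
			latency, errPing := ping(id)
			if errPing != nil {
				fmt.Println(id, "timeout")
				return
			}
			fmt.Println(id, "latency:", latency)
		}(i)
	}
	wg.Wait()
}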
