diff --git a/statistics/cmsketch_bench_test.go b/statistics/cmsketch_bench_test.go index 08666c4c2c3db..e68d102ca57da 100644 --- a/statistics/cmsketch_bench_test.go +++ b/statistics/cmsketch_bench_test.go @@ -123,10 +123,7 @@ func benchmarkMergeGlobalStatsTopNByConcurrencyWithHists(partitions int, b *test h.Buckets = append(h.Buckets, statistics.Bucket{Repeat: 10, Count: 40}) hists = append(hists, h) } - wrapper := &statistics.StatsWrapper{ - AllTopN: topNs, - AllHg: hists, - } + wrapper := statistics.NewStatsWrapper(hists, topNs) const mergeConcurrency = 4 batchSize := len(wrapper.AllTopN) / mergeConcurrency if batchSize < 1 { diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 26db0a5d4d3de..67382845edf45 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -15,7 +15,6 @@ package handle import ( - "bytes" "context" "encoding/json" "fmt" @@ -921,19 +920,15 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra // handle Error hasErr := false + errMsg := make([]string, 0) for resp := range respCh { if resp.Err != nil { hasErr = true + errMsg = append(errMsg, resp.Err.Error()) } resps = append(resps, resp) } if hasErr { - errMsg := make([]string, 0) - for _, resp := range resps { - if resp.Err != nil { - errMsg = append(errMsg, resp.Err.Error()) - } - } return nil, nil, nil, errors.New(strings.Join(errMsg, ",")) } @@ -945,17 +940,6 @@ func MergeGlobalStatsTopNByConcurrency(mergeConcurrency, mergeBatchSize int, wra sorted = append(sorted, resp.TopN.TopN...) } leftTopn = append(leftTopn, resp.PopedTopn...) - for i, removeTopn := range resp.RemoveVals { - // Remove the value from the Hists. - if len(removeTopn) > 0 { - tmp := removeTopn - slices.SortFunc(tmp, func(i, j statistics.TopNMeta) bool { - cmpResult := bytes.Compare(i.Encoded, j.Encoded) - return cmpResult < 0 - }) - wrapper.AllHg[i].RemoveVals(tmp) - } - } } globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n) diff --git a/statistics/merge_worker.go b/statistics/merge_worker.go index 9ddbd95788a1e..3c3a3db4ba9c0 100644 --- a/statistics/merge_worker.go +++ b/statistics/merge_worker.go @@ -15,6 +15,7 @@ package statistics import ( + "sync" "sync/atomic" "time" @@ -44,6 +45,8 @@ type topnStatsMergeWorker struct { respCh chan<- *TopnStatsMergeResponse // the stats in the wrapper should only be read during the worker statsWrapper *StatsWrapper + // shardMutex is used to protect `statsWrapper.AllHg` + shardMutex []sync.Mutex } // NewTopnStatsMergeWorker returns topn merge worker @@ -57,6 +60,7 @@ func NewTopnStatsMergeWorker( respCh: respCh, } worker.statsWrapper = wrapper + worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg)) worker.killed = killed return worker } @@ -77,10 +81,9 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask { // TopnStatsMergeResponse indicates topn merge worker response type TopnStatsMergeResponse struct { - TopN *TopN - PopedTopn []TopNMeta - RemoveVals [][]TopNMeta - Err error + Err error + TopN *TopN + PopedTopn []TopNMeta } // Run runs topn merge like statistics.MergePartTopN2GlobalTopN @@ -99,7 +102,6 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, return } partNum := len(allTopNs) - removeVals := make([][]TopNMeta, partNum) // Different TopN structures may hold the same value, we have to merge them. counter := make(map[hack.MutableString]float64) // datumMap is used to store the mapping from the string type to datum type. @@ -168,13 +170,13 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, if count != 0 { counter[encodedVal] += count // Remove the value corresponding to encodedVal from the histogram. - removeVals[j] = append(removeVals[j], TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) + worker.shardMutex[j].Lock() + worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)}) + worker.shardMutex[j].Unlock() } } } } - // record remove values - resp.RemoveVals = removeVals numTop := len(counter) if numTop == 0 {