diff --git a/pkg/statistics/hot_cache.go b/pkg/statistics/hot_cache.go index 3f076734a7b..86f7d7d6b08 100644 --- a/pkg/statistics/hot_cache.go +++ b/pkg/statistics/hot_cache.go @@ -113,18 +113,20 @@ func (w *HotCache) IsRegionHot(region *core.RegionInfo, minHotDegree int) bool { succ1 := w.CheckWriteAsync(checkRegionHotWriteTask) succ2 := w.CheckReadAsync(checkRegionHotReadTask) if succ1 && succ2 { - select { - case <-w.ctx.Done(): - return false - case r := <-retWrite: - return r - case r := <-retRead: - return r - } + return waitRet(w.ctx, retWrite) || waitRet(w.ctx, retRead) } return false } +func waitRet(ctx context.Context, ret chan bool) bool { + select { + case <-ctx.Done(): + return false + case r := <-ret: + return r + } +} + // GetHotPeerStat returns hot peer stat with specified regionID and storeID. func (w *HotCache) GetHotPeerStat(kind utils.RWType, regionID, storeID uint64) *HotPeerStat { ret := make(chan *HotPeerStat, 1) diff --git a/pkg/statistics/hot_cache_test.go b/pkg/statistics/hot_cache_test.go new file mode 100644 index 00000000000..9794a4a8968 --- /dev/null +++ b/pkg/statistics/hot_cache_test.go @@ -0,0 +1,36 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package statistics + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/statistics/utils" +) + +func TestIsHot(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cache := NewHotCache(ctx) + region := buildRegion(utils.Read, 3, 60) + stats := cache.CheckReadPeerSync(region, region.GetPeers(), []float64{100000000, 1000, 1000}, 60) + cache.Update(stats[0], utils.Read) + for i := 0; i < 100; i++ { + re.True(cache.IsRegionHot(region, 1)) + } +}