diff --git a/common/clock/ratelimiter.go b/common/clock/ratelimiter.go index 74f7a9a59e8..cf46f9d085f 100644 --- a/common/clock/ratelimiter.go +++ b/common/clock/ratelimiter.go @@ -20,6 +20,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +//go:generate mockgen -package=$GOPACKAGE -destination=ratelimiter_mock.go github.com/uber/cadence/common/clock Reservation + package clock import ( diff --git a/common/clock/ratelimiter_mock.go b/common/clock/ratelimiter_mock.go new file mode 100644 index 00000000000..f8a72327779 --- /dev/null +++ b/common/clock/ratelimiter_mock.go @@ -0,0 +1,88 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/uber/cadence/common/clock (interfaces: Reservation) +// +// Generated by this command: +// +// mockgen -package=clock -destination=ratelimiter_mock.go github.com/uber/cadence/common/clock Reservation +// + +// Package clock is a generated GoMock package. +package clock + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockReservation is a mock of Reservation interface. +type MockReservation struct { + ctrl *gomock.Controller + recorder *MockReservationMockRecorder + isgomock struct{} +} + +// MockReservationMockRecorder is the mock recorder for MockReservation. +type MockReservationMockRecorder struct { + mock *MockReservation +} + +// NewMockReservation creates a new mock instance. +func NewMockReservation(ctrl *gomock.Controller) *MockReservation { + mock := &MockReservation{ctrl: ctrl} + mock.recorder = &MockReservationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReservation) EXPECT() *MockReservationMockRecorder { + return m.recorder +} + +// Allow mocks base method. +func (m *MockReservation) Allow() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Allow") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Allow indicates an expected call of Allow. +func (mr *MockReservationMockRecorder) Allow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Allow", reflect.TypeOf((*MockReservation)(nil).Allow)) +} + +// Used mocks base method. +func (m *MockReservation) Used(wasUsed bool) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Used", wasUsed) +} + +// Used indicates an expected call of Used. +func (mr *MockReservationMockRecorder) Used(wasUsed any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Used", reflect.TypeOf((*MockReservation)(nil).Used), wasUsed) +} diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index a5bfcc7d841..58623edfe13 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -943,6 +943,7 @@ const ( // Default value: 1 // Allowed filters: N/A TaskSchedulerDispatcherCount + TaskSchedulerGlobalDomainRPS // TaskCriticalRetryCount is the critical retry count for background tasks // when task attempt exceeds this threshold: // - task attempt metrics and additional error logs will be emitted @@ -1702,6 +1703,8 @@ const ( // Default value: false // Allowed filters: N/A TransferProcessorEnableValidator + TaskSchedulerEnableRateLimiter + TaskSchedulerEnableRateLimiterShadowMode // EnableAdminProtection is whether to enable admin checking // KeyName: history.enableAdminProtection // Value type: Bool @@ -3392,6 +3395,11 @@ var IntKeys = map[IntKey]DynamicInt{ Description: "TaskSchedulerDispatcherCount is the number of task dispatcher in task scheduler (only applies to host level task scheduler)", DefaultValue: 1, }, + TaskSchedulerGlobalDomainRPS: { + KeyName: "history.taskSchedulerGlobalDomainRPS", + Description: "TaskSchedulerGlobalDomainRPS is the task scheduling domain rate limit per second for the whole Cadence cluster", + DefaultValue: 1000, + }, TaskCriticalRetryCount: { KeyName: "history.taskCriticalRetryCount", Description: "TaskCriticalRetryCount is the critical retry count for background tasks, when task attempt exceeds this threshold:- task attempt metrics and additional error logs will be emitted- task priority will be lowered", @@ -4063,6 +4071,16 @@ var BoolKeys = map[BoolKey]DynamicBool{ Description: "TransferProcessorEnableValidator is whether validator should be enabled for transferQueueProcessor", DefaultValue: false, }, + TaskSchedulerEnableRateLimiter: { + KeyName: "history.taskSchedulerEnableRateLimiter", + Description: "TaskSchedulerEnableRateLimiter indicates whether the task scheduler rate limiter is enabled", + DefaultValue: false, + }, + TaskSchedulerEnableRateLimiterShadowMode: { + KeyName: "history.taskSchedulerEnableRateLimiterShadowMode", + Description: "TaskSchedulerEnableRateLimiterShadowMode indicates whether the task scheduler rate limiter is in shadow mode", + DefaultValue: true, + }, EnableAdminProtection: { KeyName: "history.enableAdminProtection", Description: "EnableAdminProtection is whether to enable admin checking", diff --git a/common/metrics/defs.go b/common/metrics/defs.go index 2edb6983d95..a6c41f44bce 100644 --- a/common/metrics/defs.go +++ b/common/metrics/defs.go @@ -815,6 +815,8 @@ const ( ParallelTaskProcessingScope // TaskSchedulerScope is used by task scheduler logic TaskSchedulerScope + // TaskSchedulerRateLimiterScope is used by task scheduler rate limiter logic + TaskSchedulerRateLimiterScope // HistoryArchiverScope is used by history archivers HistoryArchiverScope @@ -1755,6 +1757,7 @@ var ScopeDefs = map[ServiceIdx]map[int]scopeDefinition{ SequentialTaskProcessingScope: {operation: "SequentialTaskProcessing"}, ParallelTaskProcessingScope: {operation: "ParallelTaskProcessing"}, TaskSchedulerScope: {operation: "TaskScheduler"}, + TaskSchedulerRateLimiterScope: {operation: "TaskSchedulerRateLimiter"}, HistoryArchiverScope: {operation: "HistoryArchiver"}, VisibilityArchiverScope: {operation: "VisibilityArchiver"}, @@ -2334,6 +2337,7 @@ const ( TransferTaskMissingEventCounterPerDomain ReplicationTasksAppliedPerDomain WorkflowTerminateCounterPerDomain + TaskSchedulerThrottledCounterPerDomain TaskRedispatchQueuePendingTasksTimer @@ -3045,6 +3049,7 @@ var MetricDefs = map[ServiceIdx]map[int]metricDefinition{ TransferTaskMissingEventCounterPerDomain: {metricName: "transfer_task_missing_event_counter_per_domain", metricRollupName: "transfer_task_missing_event_counter", metricType: Counter}, ReplicationTasksAppliedPerDomain: {metricName: "replication_tasks_applied_per_domain", metricRollupName: "replication_tasks_applied", metricType: Counter}, WorkflowTerminateCounterPerDomain: {metricName: "workflow_terminate_counter_per_domain", metricRollupName: "workflow_terminate_counter", metricType: Counter}, + TaskSchedulerThrottledCounterPerDomain: {metricName: "task_scheduler_throttled_counter_per_domain", metricRollupName: "task_scheduler_throttled_counter", metricType: Counter}, TaskBatchCompleteCounter: {metricName: "task_batch_complete_counter", metricType: Counter}, TaskBatchCompleteFailure: {metricName: "task_batch_complete_error", metricType: Counter}, diff --git a/common/quotas/collection.go b/common/quotas/collection.go index ba8c8c1d7da..f68a72f1ba6 100644 --- a/common/quotas/collection.go +++ b/common/quotas/collection.go @@ -18,6 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. +//go:generate mockgen -package=$GOPACKAGE -destination=collection_mock.go github.com/uber/cadence/common/quotas ICollection + package quotas import "sync" diff --git a/common/quotas/collection_mock.go b/common/quotas/collection_mock.go new file mode 100644 index 00000000000..7058a327c79 --- /dev/null +++ b/common/quotas/collection_mock.go @@ -0,0 +1,76 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/uber/cadence/common/quotas (interfaces: ICollection) +// +// Generated by this command: +// +// mockgen -package=quotas -destination=collection_mock.go github.com/uber/cadence/common/quotas ICollection +// + +// Package quotas is a generated GoMock package. +package quotas + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockICollection is a mock of ICollection interface. +type MockICollection struct { + ctrl *gomock.Controller + recorder *MockICollectionMockRecorder + isgomock struct{} +} + +// MockICollectionMockRecorder is the mock recorder for MockICollection. +type MockICollectionMockRecorder struct { + mock *MockICollection +} + +// NewMockICollection creates a new mock instance. +func NewMockICollection(ctrl *gomock.Controller) *MockICollection { + mock := &MockICollection{ctrl: ctrl} + mock.recorder = &MockICollectionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockICollection) EXPECT() *MockICollectionMockRecorder { + return m.recorder +} + +// For mocks base method. +func (m *MockICollection) For(key string) Limiter { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "For", key) + ret0, _ := ret[0].(Limiter) + return ret0 +} + +// For indicates an expected call of For. +func (mr *MockICollectionMockRecorder) For(key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockICollection)(nil).For), key) +} diff --git a/service/history/config/config.go b/service/history/config/config.go index 95418971a98..b82240146fd 100644 --- a/service/history/config/config.go +++ b/service/history/config/config.go @@ -88,20 +88,23 @@ type Config struct { StandbyTaskMissingEventsDiscardDelay dynamicconfig.DurationPropertyFn // Task process settings - TaskProcessRPS dynamicconfig.IntPropertyFnWithDomainFilter - TaskSchedulerType dynamicconfig.IntPropertyFn - TaskSchedulerWorkerCount dynamicconfig.IntPropertyFn - TaskSchedulerShardWorkerCount dynamicconfig.IntPropertyFn - TaskSchedulerQueueSize dynamicconfig.IntPropertyFn - TaskSchedulerShardQueueSize dynamicconfig.IntPropertyFn - TaskSchedulerDispatcherCount dynamicconfig.IntPropertyFn - TaskSchedulerRoundRobinWeights dynamicconfig.MapPropertyFn - TaskCriticalRetryCount dynamicconfig.IntPropertyFn - ActiveTaskRedispatchInterval dynamicconfig.DurationPropertyFn - StandbyTaskRedispatchInterval dynamicconfig.DurationPropertyFn - StandbyTaskReReplicationContextTimeout dynamicconfig.DurationPropertyFnWithDomainIDFilter - EnableDropStuckTaskByDomainID dynamicconfig.BoolPropertyFnWithDomainIDFilter - ResurrectionCheckMinDelay dynamicconfig.DurationPropertyFnWithDomainFilter + TaskProcessRPS dynamicconfig.IntPropertyFnWithDomainFilter + TaskSchedulerType dynamicconfig.IntPropertyFn + TaskSchedulerWorkerCount dynamicconfig.IntPropertyFn + TaskSchedulerShardWorkerCount dynamicconfig.IntPropertyFn + TaskSchedulerQueueSize dynamicconfig.IntPropertyFn + TaskSchedulerShardQueueSize dynamicconfig.IntPropertyFn + TaskSchedulerDispatcherCount dynamicconfig.IntPropertyFn + TaskSchedulerRoundRobinWeights dynamicconfig.MapPropertyFn + TaskSchedulerGlobalDomainRPS dynamicconfig.IntPropertyFnWithDomainFilter + TaskSchedulerEnableRateLimiter dynamicconfig.BoolPropertyFn + TaskSchedulerEnableRateLimiterShadowMode dynamicconfig.BoolPropertyFnWithDomainFilter + TaskCriticalRetryCount dynamicconfig.IntPropertyFn + ActiveTaskRedispatchInterval dynamicconfig.DurationPropertyFn + StandbyTaskRedispatchInterval dynamicconfig.DurationPropertyFn + StandbyTaskReReplicationContextTimeout dynamicconfig.DurationPropertyFnWithDomainIDFilter + EnableDropStuckTaskByDomainID dynamicconfig.BoolPropertyFnWithDomainIDFilter + ResurrectionCheckMinDelay dynamicconfig.DurationPropertyFnWithDomainFilter // QueueProcessor settings QueueProcessorEnableSplit dynamicconfig.BoolPropertyFn @@ -365,20 +368,23 @@ func New(dc *dynamicconfig.Collection, numberOfShards int, maxMessageSize int, i DeleteHistoryEventContextTimeout: dc.GetIntProperty(dynamicconfig.DeleteHistoryEventContextTimeout), MaxResponseSize: maxMessageSize, - TaskProcessRPS: dc.GetIntPropertyFilteredByDomain(dynamicconfig.TaskProcessRPS), - TaskSchedulerType: dc.GetIntProperty(dynamicconfig.TaskSchedulerType), - TaskSchedulerWorkerCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerWorkerCount), - TaskSchedulerShardWorkerCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerShardWorkerCount), - TaskSchedulerQueueSize: dc.GetIntProperty(dynamicconfig.TaskSchedulerQueueSize), - TaskSchedulerShardQueueSize: dc.GetIntProperty(dynamicconfig.TaskSchedulerShardQueueSize), - TaskSchedulerDispatcherCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerDispatcherCount), - TaskSchedulerRoundRobinWeights: dc.GetMapProperty(dynamicconfig.TaskSchedulerRoundRobinWeights), - TaskCriticalRetryCount: dc.GetIntProperty(dynamicconfig.TaskCriticalRetryCount), - ActiveTaskRedispatchInterval: dc.GetDurationProperty(dynamicconfig.ActiveTaskRedispatchInterval), - StandbyTaskRedispatchInterval: dc.GetDurationProperty(dynamicconfig.StandbyTaskRedispatchInterval), - StandbyTaskReReplicationContextTimeout: dc.GetDurationPropertyFilteredByDomainID(dynamicconfig.StandbyTaskReReplicationContextTimeout), - EnableDropStuckTaskByDomainID: dc.GetBoolPropertyFilteredByDomainID(dynamicconfig.EnableDropStuckTaskByDomainID), - ResurrectionCheckMinDelay: dc.GetDurationPropertyFilteredByDomain(dynamicconfig.ResurrectionCheckMinDelay), + TaskProcessRPS: dc.GetIntPropertyFilteredByDomain(dynamicconfig.TaskProcessRPS), + TaskSchedulerType: dc.GetIntProperty(dynamicconfig.TaskSchedulerType), + TaskSchedulerWorkerCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerWorkerCount), + TaskSchedulerShardWorkerCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerShardWorkerCount), + TaskSchedulerQueueSize: dc.GetIntProperty(dynamicconfig.TaskSchedulerQueueSize), + TaskSchedulerShardQueueSize: dc.GetIntProperty(dynamicconfig.TaskSchedulerShardQueueSize), + TaskSchedulerDispatcherCount: dc.GetIntProperty(dynamicconfig.TaskSchedulerDispatcherCount), + TaskSchedulerRoundRobinWeights: dc.GetMapProperty(dynamicconfig.TaskSchedulerRoundRobinWeights), + TaskSchedulerGlobalDomainRPS: dc.GetIntPropertyFilteredByDomain(dynamicconfig.TaskSchedulerGlobalDomainRPS), + TaskSchedulerEnableRateLimiter: dc.GetBoolProperty(dynamicconfig.TaskSchedulerEnableRateLimiter), + TaskSchedulerEnableRateLimiterShadowMode: dc.GetBoolPropertyFilteredByDomain(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode), + TaskCriticalRetryCount: dc.GetIntProperty(dynamicconfig.TaskCriticalRetryCount), + ActiveTaskRedispatchInterval: dc.GetDurationProperty(dynamicconfig.ActiveTaskRedispatchInterval), + StandbyTaskRedispatchInterval: dc.GetDurationProperty(dynamicconfig.StandbyTaskRedispatchInterval), + StandbyTaskReReplicationContextTimeout: dc.GetDurationPropertyFilteredByDomainID(dynamicconfig.StandbyTaskReReplicationContextTimeout), + EnableDropStuckTaskByDomainID: dc.GetBoolPropertyFilteredByDomainID(dynamicconfig.EnableDropStuckTaskByDomainID), + ResurrectionCheckMinDelay: dc.GetDurationPropertyFilteredByDomain(dynamicconfig.ResurrectionCheckMinDelay), QueueProcessorEnableSplit: dc.GetBoolProperty(dynamicconfig.QueueProcessorEnableSplit), QueueProcessorSplitMaxLevel: dc.GetIntProperty(dynamicconfig.QueueProcessorSplitMaxLevel), diff --git a/service/history/config/config_test.go b/service/history/config/config_test.go index 80625109b72..c85cddcae97 100644 --- a/service/history/config/config_test.go +++ b/service/history/config/config_test.go @@ -253,6 +253,9 @@ func TestNewConfig(t *testing.T) { "GlobalRatelimiterUpdateInterval": {dynamicconfig.GlobalRatelimiterUpdateInterval, time.Second}, "GlobalRatelimiterDecayAfter": {dynamicconfig.HistoryGlobalRatelimiterDecayAfter, time.Second}, "GlobalRatelimiterGCAfter": {dynamicconfig.HistoryGlobalRatelimiterGCAfter, time.Second}, + "TaskSchedulerGlobalDomainRPS": {dynamicconfig.TaskSchedulerGlobalDomainRPS, 97}, + "TaskSchedulerEnableRateLimiterShadowMode": {dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false}, + "TaskSchedulerEnableRateLimiter": {dynamicconfig.TaskSchedulerEnableRateLimiter, true}, "HostName": {nil, hostname}, } client := dynamicconfig.NewInMemoryClient() diff --git a/service/history/handler/handler.go b/service/history/handler/handler.go index b1bd0fce322..c874ba4308a 100644 --- a/service/history/handler/handler.go +++ b/service/history/handler/handler.go @@ -128,7 +128,14 @@ func (h *handlerImpl) Start() { h.config, ) - h.queueTaskProcessor, err = task.NewProcessor( + h.controller = shard.NewShardController( + h.Resource, + h, + h.config, + ) + + var taskProcessor task.Processor + taskProcessor, err = task.NewProcessor( taskPriorityAssigner, h.config, h.GetLogger(), @@ -138,13 +145,16 @@ func (h *handlerImpl) Start() { if err != nil { h.GetLogger().Fatal("Creating priority task processor failed", tag.Error(err)) } - h.queueTaskProcessor.Start() - - h.controller = shard.NewShardController( - h.Resource, - h, + taskRateLimiter := task.NewRateLimiter( + h.GetLogger(), + h.GetMetricsClient(), + h.GetDomainCache(), h.config, + h.controller, ) + h.queueTaskProcessor = task.NewRateLimitedProcessor(taskProcessor, taskRateLimiter) + h.queueTaskProcessor.Start() + h.historyEventNotifier = events.NewNotifier(h.GetTimeSource(), h.GetMetricsClient(), h.config.GetShardID) // events notifier must starts before controller h.historyEventNotifier.Start() diff --git a/service/history/task/rate_limited_processor.go b/service/history/task/rate_limited_processor.go new file mode 100644 index 00000000000..e1e46301250 --- /dev/null +++ b/service/history/task/rate_limited_processor.go @@ -0,0 +1,88 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package task + +import ( + "context" + "sync/atomic" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/service/history/shard" +) + +type rateLimitedProcessor struct { + baseProcessor Processor + rateLimiter RateLimiter + cancelCtx context.Context + cancelFn context.CancelFunc + status int32 +} + +func NewRateLimitedProcessor( + baseProcessor Processor, + rateLimiter RateLimiter, +) Processor { + ctx, cancel := context.WithCancel(context.Background()) + return &rateLimitedProcessor{ + baseProcessor: baseProcessor, + rateLimiter: rateLimiter, + cancelCtx: ctx, + cancelFn: cancel, + status: common.DaemonStatusInitialized, + } +} + +func (p *rateLimitedProcessor) Start() { + if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) { + return + } + + p.baseProcessor.Start() +} + +func (p *rateLimitedProcessor) Stop() { + if !atomic.CompareAndSwapInt32(&p.status, common.DaemonStatusStarted, common.DaemonStatusStopped) { + return + } + + p.cancelFn() + p.baseProcessor.Stop() +} + +func (p *rateLimitedProcessor) StopShardProcessor(s shard.Context) { + p.baseProcessor.StopShardProcessor(s) +} + +func (p *rateLimitedProcessor) Submit(t Task) error { + if err := p.rateLimiter.Wait(p.cancelCtx, t); err != nil { + return err + } + return p.baseProcessor.Submit(t) +} + +func (p *rateLimitedProcessor) TrySubmit(t Task) (bool, error) { + if ok := p.rateLimiter.Allow(t); !ok { + return false, nil + } + return p.baseProcessor.TrySubmit(t) +} diff --git a/service/history/task/rate_limited_processor_test.go b/service/history/task/rate_limited_processor_test.go new file mode 100644 index 00000000000..d7ef59960c0 --- /dev/null +++ b/service/history/task/rate_limited_processor_test.go @@ -0,0 +1,174 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package task + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + + "github.com/uber/cadence/service/history/shard" +) + +type rateLimitedProcessorMockDeps struct { + mockProcessor *MockProcessor + mockRateLimiter *MockRateLimiter +} + +func setupMocksForRateLimitedProcessor(t *testing.T) (*rateLimitedProcessor, *rateLimitedProcessorMockDeps) { + ctrl := gomock.NewController(t) + + deps := &rateLimitedProcessorMockDeps{ + mockProcessor: NewMockProcessor(ctrl), + mockRateLimiter: NewMockRateLimiter(ctrl), + } + + processor := NewRateLimitedProcessor(deps.mockProcessor, deps.mockRateLimiter) + rp, ok := processor.(*rateLimitedProcessor) + require.True(t, ok) + return rp, deps +} + +func TestRateLimitedProcessorLifecycle(t *testing.T) { + rp, deps := setupMocksForRateLimitedProcessor(t) + + deps.mockProcessor.EXPECT().Start().Times(1) + rp.Start() + + var shard shard.Context + deps.mockProcessor.EXPECT().StopShardProcessor(shard).Times(1) + rp.StopShardProcessor(shard) + + deps.mockProcessor.EXPECT().Stop().Times(1) + rp.Stop() +} + +func TestRateLimitedProcessorSubmit(t *testing.T) { + testCases := []struct { + name string + task Task + setupMocks func(*rateLimitedProcessorMockDeps) + expectError bool + expectedError string + }{ + { + name: "success", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Wait(gomock.Any(), gomock.Any()).Return(nil).Times(1) + deps.mockProcessor.EXPECT().Submit(gomock.Any()).Return(nil).Times(1) + }, + }, + { + name: "rate limiter error", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Wait(gomock.Any(), gomock.Any()).Return(errors.New("rate limited")).Times(1) + }, + expectError: true, + expectedError: "rate limited", + }, + { + name: "processor error", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Wait(gomock.Any(), gomock.Any()).Return(nil).Times(1) + deps.mockProcessor.EXPECT().Submit(gomock.Any()).Return(errors.New("processor error")).Times(1) + }, + expectError: true, + expectedError: "processor error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rp, deps := setupMocksForRateLimitedProcessor(t) + tc.setupMocks(deps) + + err := rp.Submit(tc.task) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRateLimitedProcessorTrySubmit(t *testing.T) { + testCases := []struct { + name string + task Task + setupMocks func(*rateLimitedProcessorMockDeps) + expected bool + expectError bool + expectedError string + }{ + { + name: "success", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Allow(gomock.Any()).Return(true) + deps.mockProcessor.EXPECT().TrySubmit(gomock.Any()).Return(true, nil) + }, + expected: true, + }, + { + name: "rate limited", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Allow(gomock.Any()).Return(false) + }, + expected: false, + }, + { + name: "error", + task: &noopTask{}, + setupMocks: func(deps *rateLimitedProcessorMockDeps) { + deps.mockRateLimiter.EXPECT().Allow(gomock.Any()).Return(true) + deps.mockProcessor.EXPECT().TrySubmit(gomock.Any()).Return(false, errors.New("submit error")) + }, + expected: false, + expectError: true, + expectedError: "submit error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rp, deps := setupMocksForRateLimitedProcessor(t) + tc.setupMocks(deps) + + submitted, err := rp.TrySubmit(tc.task) + if tc.expectError { + assert.ErrorContains(t, err, tc.expectedError) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, submitted) + } + }) + } +} diff --git a/service/history/task/task_rate_limiter.go b/service/history/task/task_rate_limiter.go new file mode 100644 index 00000000000..2ce8b38eba6 --- /dev/null +++ b/service/history/task/task_rate_limiter.go @@ -0,0 +1,119 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +//go:generate mockgen -package $GOPACKAGE -destination task_rate_limiter_mock.go github.com/uber/cadence/service/history/task RateLimiter + +package task + +import ( + "context" + + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/quotas" + "github.com/uber/cadence/service/history/config" + "github.com/uber/cadence/service/history/shard" +) + +type ( + RateLimiter interface { + Allow(Task) bool + Wait(context.Context, Task) error + } + + taskRateLimiterImpl struct { + logger log.Logger + metricsScope metrics.Scope + domainCache cache.DomainCache + limiters quotas.ICollection + enabled dynamicconfig.BoolPropertyFn + enableShadowMode dynamicconfig.BoolPropertyFnWithDomainFilter + } +) + +func NewRateLimiter( + logger log.Logger, + metricsClient metrics.Client, + domainCache cache.DomainCache, + config *config.Config, + controller shard.Controller, +) RateLimiter { + rps := func(domain string) int { + totalShards := float64(config.NumberOfShards) + totalRPS := float64(config.TaskSchedulerGlobalDomainRPS(domain)) + numShards := float64(controller.NumShards()) + return int(totalRPS * numShards / totalShards) + } + limiterFactory := quotas.NewSimpleDynamicRateLimiterFactory(rps) + return &taskRateLimiterImpl{ + logger: logger, + metricsScope: metricsClient.Scope(metrics.TaskSchedulerRateLimiterScope), + domainCache: domainCache, + enabled: config.TaskSchedulerEnableRateLimiter, + enableShadowMode: config.TaskSchedulerEnableRateLimiterShadowMode, + limiters: quotas.NewCollection(limiterFactory), + } +} + +func (r *taskRateLimiterImpl) Allow(t Task) bool { + if !r.enabled() { + return true + } + limiter, scope, shadow := r.getLimiter(t) + allow := limiter.Allow() + if allow { + return true + } + scope.IncCounter(metrics.TaskSchedulerThrottledCounterPerDomain) + return shadow +} + +func (r *taskRateLimiterImpl) Wait(ctx context.Context, t Task) error { + if !r.enabled() { + return nil + } + limiter, scope, shadow := r.getLimiter(t) + rsv := limiter.Reserve() + allow := rsv.Allow() + if allow { + rsv.Used(true) + return nil + } + rsv.Used(false) + scope.IncCounter(metrics.TaskSchedulerThrottledCounterPerDomain) + if shadow { + return nil + } + return limiter.Wait(ctx) +} + +func (r *taskRateLimiterImpl) getLimiter(t Task) (quotas.Limiter, metrics.Scope, bool) { + domainID := t.GetDomainID() + domainName, err := r.domainCache.GetDomainName(domainID) + if err != nil { + r.logger.Warn("failed to get domain name from domain cache", tag.WorkflowDomainID(domainID), tag.Error(err)) + } + return r.limiters.For(domainName), r.metricsScope.Tagged(metrics.DomainTag(domainName)), r.enableShadowMode(domainName) +} diff --git a/service/history/task/task_rate_limiter_mock.go b/service/history/task/task_rate_limiter_mock.go new file mode 100644 index 00000000000..3f1fe38be4b --- /dev/null +++ b/service/history/task/task_rate_limiter_mock.go @@ -0,0 +1,91 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/uber/cadence/service/history/task (interfaces: RateLimiter) +// +// Generated by this command: +// +// mockgen -package task -destination task_rate_limiter_mock.go github.com/uber/cadence/service/history/task RateLimiter +// + +// Package task is a generated GoMock package. +package task + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockRateLimiter is a mock of RateLimiter interface. +type MockRateLimiter struct { + ctrl *gomock.Controller + recorder *MockRateLimiterMockRecorder + isgomock struct{} +} + +// MockRateLimiterMockRecorder is the mock recorder for MockRateLimiter. +type MockRateLimiterMockRecorder struct { + mock *MockRateLimiter +} + +// NewMockRateLimiter creates a new mock instance. +func NewMockRateLimiter(ctrl *gomock.Controller) *MockRateLimiter { + mock := &MockRateLimiter{ctrl: ctrl} + mock.recorder = &MockRateLimiterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRateLimiter) EXPECT() *MockRateLimiterMockRecorder { + return m.recorder +} + +// Allow mocks base method. +func (m *MockRateLimiter) Allow(arg0 Task) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Allow", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Allow indicates an expected call of Allow. +func (mr *MockRateLimiterMockRecorder) Allow(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Allow", reflect.TypeOf((*MockRateLimiter)(nil).Allow), arg0) +} + +// Wait mocks base method. +func (m *MockRateLimiter) Wait(arg0 context.Context, arg1 Task) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. +func (mr *MockRateLimiterMockRecorder) Wait(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockRateLimiter)(nil).Wait), arg0, arg1) +} diff --git a/service/history/task/task_rate_limiter_test.go b/service/history/task/task_rate_limiter_test.go new file mode 100644 index 00000000000..69c0aa7497d --- /dev/null +++ b/service/history/task/task_rate_limiter_test.go @@ -0,0 +1,418 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package task + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/quotas" + ctask "github.com/uber/cadence/common/task" + "github.com/uber/cadence/service/history/config" + "github.com/uber/cadence/service/history/shard" +) + +type ( + noopTask struct { + sync.RWMutex + *noopTaskInfo + queueType QueueType + shard shard.Context + attempt int + priority int + state ctask.State + } + + noopTaskInfo struct { + version int64 + taskID int64 + taskType int + visibilityTimestamp time.Time + workflowID string + runID string + domainID string + } +) + +func (s *noopTask) Execute() error { + return nil +} + +func (s *noopTask) HandleErr(err error) error { + return nil +} + +func (s *noopTask) RetryErr(err error) bool { + return false +} + +func (s *noopTask) Ack() { + s.Lock() + defer s.Unlock() + s.state = ctask.TaskStateAcked +} + +func (s *noopTask) Nack() { + s.Lock() + defer s.Unlock() + s.state = ctask.TaskStateNacked +} + +func (s *noopTask) State() ctask.State { + s.RLock() + defer s.RUnlock() + return s.state +} + +func (s *noopTask) Priority() int { + s.RLock() + defer s.RUnlock() + return s.priority +} + +func (s *noopTask) SetPriority(p int) { + s.priority = p +} + +func (s *noopTask) GetShard() shard.Context { + return s.shard +} + +func (s *noopTask) GetQueueType() QueueType { + return s.queueType +} + +func (s *noopTask) GetAttempt() int { + return s.attempt +} + +func (s *noopTask) GetInfo() Info { + return s.noopTaskInfo +} + +func (s *noopTaskInfo) GetVersion() int64 { + return s.version +} + +func (s *noopTaskInfo) GetTaskID() int64 { + return s.taskID +} + +func (s *noopTaskInfo) GetTaskType() int { + return s.taskType +} + +func (s *noopTaskInfo) GetVisibilityTimestamp() time.Time { + return s.visibilityTimestamp +} + +func (s *noopTaskInfo) GetWorkflowID() string { + return s.workflowID +} + +func (s *noopTaskInfo) GetRunID() string { + return s.runID +} + +func (s *noopTaskInfo) GetDomainID() string { + return s.domainID +} + +type taskRateLimiterMockDeps struct { + ctrl *gomock.Controller + mockDomainCache *cache.MockDomainCache + mockShardController *shard.MockController + mockICollection *quotas.MockICollection + dynamicClient dynamicconfig.Client +} + +func setupMocksForTaskRateLimiter(t *testing.T, mockQuotasCollection bool) (*taskRateLimiterImpl, *taskRateLimiterMockDeps) { + ctrl := gomock.NewController(t) + mockDomainCache := cache.NewMockDomainCache(ctrl) + mockShardController := shard.NewMockController(ctrl) + dynamicClient := dynamicconfig.NewInMemoryClient() + + deps := &taskRateLimiterMockDeps{ + ctrl: ctrl, + mockDomainCache: mockDomainCache, + mockShardController: mockShardController, + dynamicClient: dynamicClient, + } + + config := config.New( + dynamicconfig.NewCollection( + dynamicClient, + testlogger.New(t), + ), + 16, + 1024, + false, + "hostname", + ) + + rateLimiter := NewRateLimiter( + testlogger.New(t), + metrics.NewNoopMetricsClient(), + deps.mockDomainCache, + config, + deps.mockShardController, + ) + r, ok := rateLimiter.(*taskRateLimiterImpl) + require.True(t, ok, "rate limiter type assertion failure") + if mockQuotasCollection { + deps.mockICollection = quotas.NewMockICollection(ctrl) + r.limiters = deps.mockICollection + } + return r, deps +} + +func TestRateLimiterRPS(t *testing.T) { + r, deps := setupMocksForTaskRateLimiter(t, false) + + deps.mockShardController.EXPECT().NumShards().Return(8).AnyTimes() + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + + l := r.limiters.For("test-domain").Limit() + assert.Equal(t, 50, int(l)) + + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 200)) + l = r.limiters.For("test-domain").Limit() + assert.Equal(t, 100, int(l)) +} + +func TestRateLimiterAllow(t *testing.T) { + testCases := []struct { + name string + task Task + setupMocks func(*taskRateLimiterMockDeps) + expected bool + }{ + { + name: "Not rate limited", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Allow().Return(true) + }, + expected: true, + }, + { + name: "Not rate limited - domain cache error ignored", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("", errors.New("cache error")) + limiter := quotas.NewMockLimiter(deps.ctrl) + deps.mockICollection.EXPECT().For("").Return(limiter) + limiter.EXPECT().Allow().Return(true) + }, + expected: true, + }, + { + name: "Rate limited - shadow mode", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, true)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Allow().Return(false) + }, + expected: true, + }, + { + name: "Rate limited", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Allow().Return(false) + }, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r, deps := setupMocksForTaskRateLimiter(t, true) + + tc.setupMocks(deps) + + allow := r.Allow(tc.task) + assert.Equal(t, tc.expected, allow) + }) + } +} + +func TestRateLimiterWait(t *testing.T) { + testCases := []struct { + name string + task Task + setupMocks func(*taskRateLimiterMockDeps) + expectErr bool + expectedErr string + }{ + { + name: "Not rate limited", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + rsv := clock.NewMockReservation(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Reserve().Return(rsv) + rsv.EXPECT().Allow().Return(true) + rsv.EXPECT().Used(true) + }, + expectErr: false, + }, + { + name: "Not rate limited - domain cache error ignored", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("", errors.New("cache error")) + limiter := quotas.NewMockLimiter(deps.ctrl) + rsv := clock.NewMockReservation(deps.ctrl) + deps.mockICollection.EXPECT().For("").Return(limiter) + limiter.EXPECT().Reserve().Return(rsv) + rsv.EXPECT().Allow().Return(true) + rsv.EXPECT().Used(true) + }, + expectErr: false, + }, + { + name: "Rate limited - shadow mode", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, true)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + rsv := clock.NewMockReservation(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Reserve().Return(rsv) + rsv.EXPECT().Allow().Return(false) + rsv.EXPECT().Used(false) + }, + expectErr: false, + }, + { + name: "Rate limited - error", + task: &noopTask{ + noopTaskInfo: &noopTaskInfo{ + domainID: "test-domain-id", + }, + }, + setupMocks: func(deps *taskRateLimiterMockDeps) { + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerGlobalDomainRPS, 100)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiter, true)) + require.NoError(t, deps.dynamicClient.UpdateValue(dynamicconfig.TaskSchedulerEnableRateLimiterShadowMode, false)) + deps.mockDomainCache.EXPECT().GetDomainName("test-domain-id").Return("test-domain", nil) + limiter := quotas.NewMockLimiter(deps.ctrl) + rsv := clock.NewMockReservation(deps.ctrl) + deps.mockICollection.EXPECT().For("test-domain").Return(limiter) + limiter.EXPECT().Reserve().Return(rsv) + rsv.EXPECT().Allow().Return(false) + rsv.EXPECT().Used(false) + limiter.EXPECT().Wait(gomock.Any()).Return(errors.New("wait error")) + }, + expectErr: true, + expectedErr: "wait error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r, deps := setupMocksForTaskRateLimiter(t, true) + + tc.setupMocks(deps) + + err := r.Wait(context.Background(), tc.task) + if tc.expectErr { + assert.ErrorContains(t, err, tc.expectedErr) + } else { + assert.NoError(t, err) + } + }) + } +}