From 841a085c161d1558ecae4fb2d44f486040a0a198 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Fri, 26 Jan 2024 12:46:32 -0500 Subject: [PATCH] Add cache validation for platform.GetValidatorsAt --- .../gvalidators/validator_state_client.go | 4 + snow/validators/mock_state.go | 14 ++++ snow/validators/state.go | 10 +++ snow/validators/test_state.go | 4 + snow/validators/traced_state.go | 4 + vms/platformvm/service.go | 20 ++++- vms/platformvm/validators/manager.go | 76 ++++++++++++++++--- vms/platformvm/validators/test_manager.go | 4 + 8 files changed, 126 insertions(+), 10 deletions(-) diff --git a/snow/validators/gvalidators/validator_state_client.go b/snow/validators/gvalidators/validator_state_client.go index 49fa1e64141..c1c5da3da56 100644 --- a/snow/validators/gvalidators/validator_state_client.go +++ b/snow/validators/gvalidators/validator_state_client.go @@ -93,3 +93,7 @@ func (c *Client) GetValidatorSet( } return vdrs, nil } + +func (*Client) ValidateCachedGetValidatorSet(context.Context, uint64, ids.ID) error { + return nil +} diff --git a/snow/validators/mock_state.go b/snow/validators/mock_state.go index 6bed638becd..ee202553df0 100644 --- a/snow/validators/mock_state.go +++ b/snow/validators/mock_state.go @@ -99,3 +99,17 @@ func (mr *MockStateMockRecorder) GetValidatorSet(arg0, arg1, arg2 any) *gomock.C mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatorSet", reflect.TypeOf((*MockState)(nil).GetValidatorSet), arg0, arg1, arg2) } + +// ValidateCachedGetValidatorSet mocks base method. +func (m *MockState) ValidateCachedGetValidatorSet(arg0 context.Context, arg1 uint64, arg2 ids.ID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateCachedGetValidatorSet", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateCachedGetValidatorSet indicates an expected call of ValidateCachedGetValidatorSet. +func (mr *MockStateMockRecorder) ValidateCachedGetValidatorSet(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCachedGetValidatorSet", reflect.TypeOf((*MockState)(nil).ValidateCachedGetValidatorSet), arg0, arg1, arg2) +} diff --git a/snow/validators/state.go b/snow/validators/state.go index 3f92df35231..088b8f0f335 100644 --- a/snow/validators/state.go +++ b/snow/validators/state.go @@ -32,6 +32,12 @@ type State interface { height uint64, subnetID ids.ID, ) (map[ids.NodeID]*GetValidatorOutput, error) + + ValidateCachedGetValidatorSet( + ctx context.Context, + targetHeight uint64, + subnetID ids.ID, + ) error } type lockedState struct { @@ -78,6 +84,10 @@ func (s *lockedState) GetValidatorSet( return s.s.GetValidatorSet(ctx, height, subnetID) } +func (*lockedState) ValidateCachedGetValidatorSet(context.Context, uint64, ids.ID) error { + return nil +} + type noValidators struct { State } diff --git a/snow/validators/test_state.go b/snow/validators/test_state.go index ee4102cf719..48b9ff4752c 100644 --- a/snow/validators/test_state.go +++ b/snow/validators/test_state.go @@ -79,3 +79,7 @@ func (vm *TestState) GetValidatorSet( } return nil, errGetValidatorSet } + +func (*TestState) ValidateCachedGetValidatorSet(context.Context, uint64, ids.ID) error { + return nil +} diff --git a/snow/validators/traced_state.go b/snow/validators/traced_state.go index 126a2b009eb..fedb39117ba 100644 --- a/snow/validators/traced_state.go +++ b/snow/validators/traced_state.go @@ -73,3 +73,7 @@ func (s *tracedState) GetValidatorSet( return s.s.GetValidatorSet(ctx, height, subnetID) } + +func (*tracedState) ValidateCachedGetValidatorSet(context.Context, uint64, ids.ID) error { + return nil +} diff --git a/vms/platformvm/service.go b/vms/platformvm/service.go index 16e5b16844c..4f3355b642d 100644 --- a/vms/platformvm/service.go +++ b/vms/platformvm/service.go @@ -2671,6 +2671,15 @@ func (v *GetValidatorsAtReply) MarshalJSON() ([]byte, error) { m[vdr.NodeID] = vdrJSON } + if v.ErrorString != "" { + return stdjson.Marshal(struct { + Validators map[ids.NodeID]*jsonGetValidatorOutput + ErrorString string + }{ + Validators: m, + ErrorString: v.ErrorString, + }) + } return stdjson.Marshal(m) } @@ -2710,7 +2719,8 @@ func (v *GetValidatorsAtReply) UnmarshalJSON(b []byte) error { // GetValidatorsAtReply is the response from GetValidatorsAt type GetValidatorsAtReply struct { - Validators map[ids.NodeID]*validators.GetValidatorOutput + Validators map[ids.NodeID]*validators.GetValidatorOutput + ErrorString string } // GetValidatorsAt returns the weights of the validator set of a provided subnet @@ -2733,6 +2743,14 @@ func (s *Service) GetValidatorsAt(r *http.Request, args *GetValidatorsAtArgs, re if err != nil { return fmt.Errorf("failed to get validator set: %w", err) } + if err := s.vm.ValidateCachedGetValidatorSet(ctx, height, args.SubnetID); err != nil { + s.vm.ctx.Log.Error("invalid validator set", + zap.Stringer("subnetID", args.SubnetID), + zap.Uint64("height", height), + zap.Error(err), + ) + reply.ErrorString = err.Error() + } return nil } diff --git a/vms/platformvm/validators/manager.go b/vms/platformvm/validators/manager.go index 2c8b025a128..f014dc951ce 100644 --- a/vms/platformvm/validators/manager.go +++ b/vms/platformvm/validators/manager.go @@ -5,7 +5,9 @@ package validators import ( "context" + "errors" "fmt" + "reflect" "time" "github.com/ava-labs/avalanchego/cache" @@ -30,13 +32,23 @@ const ( recentlyAcceptedWindowTTL = 2 * time.Minute ) -var _ validators.State = (*manager)(nil) +var ( + _ validators.State = (*manager)(nil) + + errInconsistentValidatorSet = errors.New("inconsistent validator set") +) // Manager adds the ability to introduce newly accepted blocks IDs to the State // interface. type Manager interface { validators.State + ValidateCachedGetValidatorSet( + ctx context.Context, + targetHeight uint64, + subnetID ids.ID, + ) error + // OnAcceptedBlockID registers the ID of the latest accepted block. // It is used to update the [recentlyAccepted] sliding window. OnAcceptedBlockID(blkID ids.ID) @@ -97,7 +109,7 @@ func NewManager( state: state, metrics: metrics, clk: clk, - caches: make(map[ids.ID]cache.Cacher[uint64, map[ids.NodeID]*validators.GetValidatorOutput]), + caches: make(map[ids.ID]cache.Cacher[uint64, *cachedValidatorSet]), recentlyAccepted: window.New[ids.ID]( window.Config{ Clock: clk, @@ -121,12 +133,17 @@ type manager struct { // Maps caches for each subnet that is currently tracked. // Key: Subnet ID // Value: cache mapping height -> validator set map - caches map[ids.ID]cache.Cacher[uint64, map[ids.NodeID]*validators.GetValidatorOutput] + caches map[ids.ID]cache.Cacher[uint64, *cachedValidatorSet] // sliding window of blocks that were recently accepted recentlyAccepted window.Window[ids.ID] } +type cachedValidatorSet struct { + validatorSet map[ids.NodeID]*validators.GetValidatorOutput + calculatedHeight uint64 +} + // GetMinimumHeight returns the height of the most recent block beyond the // horizon of our recentlyAccepted window. // @@ -187,10 +204,9 @@ func (m *manager) GetValidatorSet( subnetID ids.ID, ) (map[ids.NodeID]*validators.GetValidatorOutput, error) { validatorSetsCache := m.getValidatorSetCache(subnetID) - if validatorSet, ok := validatorSetsCache.Get(targetHeight); ok { m.metrics.IncValidatorSetsCached() - return validatorSet, nil + return validatorSet.validatorSet, nil } // get the start time to track metrics @@ -211,7 +227,10 @@ func (m *manager) GetValidatorSet( } // cache the validator set - validatorSetsCache.Put(targetHeight, validatorSet) + validatorSetsCache.Put(targetHeight, &cachedValidatorSet{ + validatorSet: validatorSet, + calculatedHeight: currentHeight, + }) duration := m.clk.Time().Sub(startTime) m.metrics.IncValidatorSetsCreated() @@ -220,10 +239,49 @@ func (m *manager) GetValidatorSet( return validatorSet, nil } -func (m *manager) getValidatorSetCache(subnetID ids.ID) cache.Cacher[uint64, map[ids.NodeID]*validators.GetValidatorOutput] { +func (m *manager) ValidateCachedGetValidatorSet( + ctx context.Context, + targetHeight uint64, + subnetID ids.ID, +) error { + validatorSetsCache := m.getValidatorSetCache(subnetID) + cachedValidatorSet, ok := validatorSetsCache.Get(targetHeight) + if !ok { + // If the validator set isn't cached, then there is nothing to check. + return nil + } + + var ( + validatorSet map[ids.NodeID]*validators.GetValidatorOutput + currentHeight uint64 + err error + ) + if subnetID == constants.PrimaryNetworkID { + validatorSet, currentHeight, err = m.makePrimaryNetworkValidatorSet(ctx, targetHeight) + } else { + validatorSet, currentHeight, err = m.makeSubnetValidatorSet(ctx, targetHeight, subnetID) + } + if err != nil { + return err + } + + if reflect.DeepEqual(cachedValidatorSet.validatorSet, validatorSet) { + return nil + } + + return fmt.Errorf("%w calculated for %s:%d at %d and %d", + errInconsistentValidatorSet, + subnetID, + targetHeight, + cachedValidatorSet.calculatedHeight, + currentHeight, + ) +} + +func (m *manager) getValidatorSetCache(subnetID ids.ID) cache.Cacher[uint64, *cachedValidatorSet] { // Only cache tracked subnets if subnetID != constants.PrimaryNetworkID && !m.cfg.TrackedSubnets.Contains(subnetID) { - return &cache.Empty[uint64, map[ids.NodeID]*validators.GetValidatorOutput]{} + return &cache.Empty[uint64, *cachedValidatorSet]{} } validatorSetsCache, exists := m.caches[subnetID] @@ -231,7 +289,7 @@ func (m *manager) getValidatorSetCache(subnetID ids.ID) cache.Cacher[uint64, map return validatorSetsCache } - validatorSetsCache = &cache.LRU[uint64, map[ids.NodeID]*validators.GetValidatorOutput]{ + validatorSetsCache = &cache.LRU[uint64, *cachedValidatorSet]{ Size: validatorSetsCacheSize, } m.caches[subnetID] = validatorSetsCache diff --git a/vms/platformvm/validators/test_manager.go b/vms/platformvm/validators/test_manager.go index e04742f265c..bb6bb6e6bc2 100644 --- a/vms/platformvm/validators/test_manager.go +++ b/vms/platformvm/validators/test_manager.go @@ -30,4 +30,8 @@ func (testManager) GetValidatorSet(context.Context, uint64, ids.ID) (map[ids.Nod return nil, nil } +func (testManager) ValidateCachedGetValidatorSet(context.Context, uint64, ids.ID) error { + return nil +} + func (testManager) OnAcceptedBlockID(ids.ID) {}