diff --git a/utils/iterator/slice.go b/utils/iterator/slice.go index d195a1adc14..a7b18189aab 100644 --- a/utils/iterator/slice.go +++ b/utils/iterator/slice.go @@ -5,6 +5,18 @@ package iterator var _ Iterator[any] = (*slice[any])(nil) +// ToSlice returns a slice that contains all of the elements from [it] in order. +// [it] will be released before returning. +func ToSlice[T any](it Iterator[T]) []T { + defer it.Release() + + var elements []T + for it.Next() { + elements = append(elements, it.Value()) + } + return elements +} + type slice[T any] struct { index int elements []T diff --git a/vms/platformvm/state/diff_test.go b/vms/platformvm/state/diff_test.go index a7eec42364b..c4246b8aaee 100644 --- a/vms/platformvm/state/diff_test.go +++ b/vms/platformvm/state/diff_test.go @@ -15,6 +15,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/iterator" "github.com/ava-labs/avalanchego/utils/iterator/iteratormock" "github.com/ava-labs/avalanchego/vms/components/avax" "github.com/ava-labs/avalanchego/vms/components/gas" @@ -530,14 +531,20 @@ func assertChainsEqual(t *testing.T, expected, actual Chain) { actualCurrentStakerIterator, actualErr := actual.GetCurrentStakerIterator() require.Equal(expectedErr, actualErr) if expectedErr == nil { - assertIteratorsEqual(t, expectedCurrentStakerIterator, actualCurrentStakerIterator) + require.Equal( + iterator.ToSlice(expectedCurrentStakerIterator), + iterator.ToSlice(actualCurrentStakerIterator), + ) } expectedPendingStakerIterator, expectedErr := expected.GetPendingStakerIterator() actualPendingStakerIterator, actualErr := actual.GetPendingStakerIterator() require.Equal(expectedErr, actualErr) if expectedErr == nil { - assertIteratorsEqual(t, expectedPendingStakerIterator, actualPendingStakerIterator) + require.Equal( + iterator.ToSlice(expectedPendingStakerIterator), + iterator.ToSlice(actualPendingStakerIterator), + ) } require.Equal(expected.GetTimestamp(), actual.GetTimestamp()) diff --git a/vms/platformvm/state/stakers_test.go b/vms/platformvm/state/stakers_test.go index d536b2a719d..8141959d80a 100644 --- a/vms/platformvm/state/stakers_test.go +++ b/vms/platformvm/state/stakers_test.go @@ -86,7 +86,10 @@ func TestBaseStakersValidator(t *testing.T) { require.ErrorIs(err, database.ErrNotFound) stakerIterator := v.GetStakerIterator() - assertIteratorsEqual(t, iterator.FromSlice(delegator), stakerIterator) + require.Equal( + []*Staker{delegator}, + iterator.ToSlice(stakerIterator), + ) v.PutValidator(staker) @@ -97,7 +100,10 @@ func TestBaseStakersValidator(t *testing.T) { v.DeleteDelegator(delegator) stakerIterator = v.GetStakerIterator() - assertIteratorsEqual(t, iterator.FromSlice(staker), stakerIterator) + require.Equal( + []*Staker{staker}, + iterator.ToSlice(stakerIterator), + ) v.DeleteValidator(staker) @@ -105,30 +111,42 @@ func TestBaseStakersValidator(t *testing.T) { require.ErrorIs(err, database.ErrNotFound) stakerIterator = v.GetStakerIterator() - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, stakerIterator) + require.Empty( + iterator.ToSlice(stakerIterator), + ) } func TestBaseStakersDelegator(t *testing.T) { + require := require.New(t) staker := newTestStaker() delegator := newTestStaker() v := newBaseStakers() delegatorIterator := v.GetDelegatorIterator(delegator.SubnetID, delegator.NodeID) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) v.PutDelegator(delegator) delegatorIterator = v.GetDelegatorIterator(delegator.SubnetID, ids.GenerateTestNodeID()) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) delegatorIterator = v.GetDelegatorIterator(delegator.SubnetID, delegator.NodeID) - assertIteratorsEqual(t, iterator.FromSlice(delegator), delegatorIterator) + require.Equal( + []*Staker{delegator}, + iterator.ToSlice(delegatorIterator), + ) v.DeleteDelegator(delegator) delegatorIterator = v.GetDelegatorIterator(delegator.SubnetID, delegator.NodeID) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) v.PutValidator(staker) @@ -136,7 +154,9 @@ func TestBaseStakersDelegator(t *testing.T) { v.DeleteDelegator(delegator) delegatorIterator = v.GetDelegatorIterator(staker.SubnetID, staker.NodeID) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) } func TestDiffStakersValidator(t *testing.T) { @@ -160,7 +180,10 @@ func TestDiffStakersValidator(t *testing.T) { require.Equal(unmodified, status) stakerIterator := v.GetStakerIterator(iterator.Empty[*Staker]{}) - assertIteratorsEqual(t, iterator.FromSlice(delegator), stakerIterator) + require.Equal( + []*Staker{delegator}, + iterator.ToSlice(stakerIterator), + ) require.NoError(v.PutValidator(staker)) @@ -177,7 +200,10 @@ func TestDiffStakersValidator(t *testing.T) { require.Equal(unmodified, status) stakerIterator = v.GetStakerIterator(iterator.Empty[*Staker]{}) - assertIteratorsEqual(t, iterator.FromSlice(delegator), stakerIterator) + require.Equal( + []*Staker{delegator}, + iterator.ToSlice(stakerIterator), + ) } func TestDiffStakersDeleteValidator(t *testing.T) { @@ -198,25 +224,33 @@ func TestDiffStakersDeleteValidator(t *testing.T) { } func TestDiffStakersDelegator(t *testing.T) { + require := require.New(t) staker := newTestStaker() delegator := newTestStaker() v := diffStakers{} - require.NoError(t, v.PutValidator(staker)) + require.NoError(v.PutValidator(staker)) delegatorIterator := v.GetDelegatorIterator(iterator.Empty[*Staker]{}, ids.GenerateTestID(), delegator.NodeID) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) v.PutDelegator(delegator) delegatorIterator = v.GetDelegatorIterator(iterator.Empty[*Staker]{}, delegator.SubnetID, delegator.NodeID) - assertIteratorsEqual(t, iterator.FromSlice(delegator), delegatorIterator) + require.Equal( + []*Staker{delegator}, + iterator.ToSlice(delegatorIterator), + ) v.DeleteDelegator(delegator) delegatorIterator = v.GetDelegatorIterator(iterator.Empty[*Staker]{}, ids.GenerateTestID(), delegator.NodeID) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) } func newTestStaker() *Staker { @@ -235,22 +269,3 @@ func newTestStaker() *Staker { Priority: txs.PrimaryNetworkDelegatorCurrentPriority, } } - -func assertIteratorsEqual(t *testing.T, expected, actual iterator.Iterator[*Staker]) { - require := require.New(t) - - t.Helper() - - for expected.Next() { - require.True(actual.Next()) - - expectedStaker := expected.Value() - actualStaker := actual.Value() - - require.Equal(expectedStaker, actualStaker) - } - require.False(actual.Next()) - - expected.Release() - actual.Release() -} diff --git a/vms/platformvm/state/state_test.go b/vms/platformvm/state/state_test.go index 3be526333e3..515794364d7 100644 --- a/vms/platformvm/state/state_test.go +++ b/vms/platformvm/state/state_test.go @@ -88,18 +88,25 @@ func TestStateSyncGenesis(t *testing.T) { delegatorIterator, err := state.GetCurrentDelegatorIterator(constants.PrimaryNetworkID, defaultValidatorNodeID) require.NoError(err) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) stakerIterator, err := state.GetCurrentStakerIterator() require.NoError(err) - assertIteratorsEqual(t, iterator.FromSlice(staker), stakerIterator) + require.Equal( + []*Staker{staker}, + iterator.ToSlice(stakerIterator), + ) _, err = state.GetPendingValidator(constants.PrimaryNetworkID, defaultValidatorNodeID) require.ErrorIs(err, database.ErrNotFound) delegatorIterator, err = state.GetPendingDelegatorIterator(constants.PrimaryNetworkID, defaultValidatorNodeID) require.NoError(err) - assertIteratorsEqual(t, iterator.Empty[*Staker]{}, delegatorIterator) + require.Empty( + iterator.ToSlice(delegatorIterator), + ) } // Whenever we store a staker, a whole bunch a data structures are updated diff --git a/vms/platformvm/state/subnet_only_validator.go b/vms/platformvm/state/subnet_only_validator.go new file mode 100644 index 00000000000..5af028314e6 --- /dev/null +++ b/vms/platformvm/state/subnet_only_validator.go @@ -0,0 +1,96 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "fmt" + + "github.com/google/btree" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/vms/platformvm/block" +) + +var _ btree.LessFunc[*SubnetOnlyValidator] = (*SubnetOnlyValidator).Less + +type SubnetOnlyValidator struct { + // ValidationID is not serialized because it is used as the key in the + // database, so it doesn't need to be stored in the value. + ValidationID ids.ID + + SubnetID ids.ID `serialize:"true"` + NodeID ids.NodeID `serialize:"true"` + + // PublicKey is the uncompressed BLS public key of the validator. It is + // guaranteed to be populated. + PublicKey []byte `serialize:"true"` + + // StartTime is the unix timestamp, in seconds, when this validator was + // added to the set. + StartTime uint64 `serialize:"true"` + + // Weight of this validator. It can be updated when the MinNonce is + // increased. If the weight is being set to 0, the validator is being + // removed. + Weight uint64 `serialize:"true"` + + // MinNonce is the smallest nonce that can be used to modify this + // validator's weight. It is initially set to 0 and is set to one higher + // than the last nonce used. It is not valid to use nonce MaxUint64 unless + // the weight is being set to 0, which removes the validator from the set. + MinNonce uint64 `serialize:"true"` + + // EndAccumulatedFee is the amount of globally accumulated fees that can + // accrue before this validator must be deactivated. It is equal to the + // amount of fees this validator is willing to pay plus the amount of + // globally accumulated fees when this validator started validating. + EndAccumulatedFee uint64 `serialize:"true"` +} + +// Less determines a canonical ordering of *SubnetOnlyValidators based on their +// EndAccumulatedFees and ValidationIDs. +// +// Returns true if: +// +// 1. This validator has a lower EndAccumulatedFee than the other. +// 2. This validator has an equal EndAccumulatedFee to the other and has a +// lexicographically lower ValidationID. +func (v *SubnetOnlyValidator) Less(o *SubnetOnlyValidator) bool { + switch { + case v.EndAccumulatedFee < o.EndAccumulatedFee: + return true + case o.EndAccumulatedFee < v.EndAccumulatedFee: + return false + default: + return v.ValidationID.Compare(o.ValidationID) == -1 + } +} + +func getSubnetOnlyValidator(db database.KeyValueReader, validationID ids.ID) (*SubnetOnlyValidator, error) { + bytes, err := db.Get(validationID[:]) + if err != nil { + return nil, err + } + + vdr := &SubnetOnlyValidator{ + ValidationID: validationID, + } + if _, err = block.GenesisCodec.Unmarshal(bytes, vdr); err != nil { + return nil, fmt.Errorf("failed to unmarshal SubnetOnlyValidator: %w", err) + } + return vdr, err +} + +func putSubnetOnlyValidator(db database.KeyValueWriter, vdr *SubnetOnlyValidator) error { + bytes, err := block.GenesisCodec.Marshal(block.CodecVersion, vdr) + if err != nil { + return fmt.Errorf("failed to marshal SubnetOnlyValidator: %w", err) + } + return db.Put(vdr.ValidationID[:], bytes) +} + +func deleteSubnetOnlyValidator(db database.KeyValueDeleter, validationID ids.ID) error { + return db.Delete(validationID[:]) +} diff --git a/vms/platformvm/state/subnet_only_validator_test.go b/vms/platformvm/state/subnet_only_validator_test.go new file mode 100644 index 00000000000..bcbb21e0027 --- /dev/null +++ b/vms/platformvm/state/subnet_only_validator_test.go @@ -0,0 +1,113 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package state + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/database" + "github.com/ava-labs/avalanchego/database/memdb" + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/crypto/bls" +) + +func TestSubnetOnlyValidator_Less(t *testing.T) { + tests := []struct { + name string + v *SubnetOnlyValidator + o *SubnetOnlyValidator + equal bool + }{ + { + name: "v.EndAccumulatedFee < o.EndAccumulatedFee", + v: &SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + EndAccumulatedFee: 1, + }, + o: &SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + EndAccumulatedFee: 2, + }, + equal: false, + }, + { + name: "v.EndAccumulatedFee = o.EndAccumulatedFee, v.ValidationID < o.ValidationID", + v: &SubnetOnlyValidator{ + ValidationID: ids.ID{0}, + EndAccumulatedFee: 1, + }, + o: &SubnetOnlyValidator{ + ValidationID: ids.ID{1}, + EndAccumulatedFee: 1, + }, + equal: false, + }, + { + name: "v.EndAccumulatedFee = o.EndAccumulatedFee, v.ValidationID = o.ValidationID", + v: &SubnetOnlyValidator{ + ValidationID: ids.ID{0}, + EndAccumulatedFee: 1, + }, + o: &SubnetOnlyValidator{ + ValidationID: ids.ID{0}, + EndAccumulatedFee: 1, + }, + equal: true, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + less := test.v.Less(test.o) + require.Equal(!test.equal, less) + + greater := test.o.Less(test.v) + require.False(greater) + }) + } +} + +func TestSubnetOnlyValidator_DatabaseHelpers(t *testing.T) { + require := require.New(t) + db := memdb.New() + + sk, err := bls.NewSecretKey() + require.NoError(err) + + vdr := &SubnetOnlyValidator{ + ValidationID: ids.GenerateTestID(), + SubnetID: ids.GenerateTestID(), + NodeID: ids.GenerateTestNodeID(), + PublicKey: bls.PublicKeyToUncompressedBytes(bls.PublicFromSecretKey(sk)), + StartTime: rand.Uint64(), // #nosec G404 + Weight: rand.Uint64(), // #nosec G404 + MinNonce: rand.Uint64(), // #nosec G404 + EndAccumulatedFee: rand.Uint64(), // #nosec G404 + } + + // Validator hasn't been put on disk yet + gotVdr, err := getSubnetOnlyValidator(db, vdr.ValidationID) + require.ErrorIs(err, database.ErrNotFound) + require.Nil(gotVdr) + + // Place the validator on disk + require.NoError(putSubnetOnlyValidator(db, vdr)) + + // Verify that the validator can be fetched from disk + gotVdr, err = getSubnetOnlyValidator(db, vdr.ValidationID) + require.NoError(err) + require.Equal(vdr, gotVdr) + + // Remove the validator from disk + require.NoError(deleteSubnetOnlyValidator(db, vdr.ValidationID)) + + // Verify that the validator has been removed from disk + gotVdr, err = getSubnetOnlyValidator(db, vdr.ValidationID) + require.ErrorIs(err, database.ErrNotFound) + require.Nil(gotVdr) +}