diff --git a/consensus/consortium/common/constants.go b/consensus/consortium/common/constants.go index 84da1c44f5..6363311d20 100644 --- a/consensus/consortium/common/constants.go +++ b/consensus/consortium/common/constants.go @@ -7,8 +7,9 @@ import ( ) const ( - ExtraSeal = crypto.SignatureLength - ExtraVanity = 32 + ExtraSeal = crypto.SignatureLength + ExtraVanity = 32 + MaxFinalityVotePercentage uint16 = 10_000 ) var ( diff --git a/consensus/consortium/common/contract.go b/consensus/consortium/common/contract.go index 9bd5adbc8f..c5cbe0f0be 100644 --- a/consensus/consortium/common/contract.go +++ b/consensus/consortium/common/contract.go @@ -19,6 +19,7 @@ import ( "github.com/ethereum/go-ethereum/consensus/consortium/generated_contracts/profile" roninValidatorSet "github.com/ethereum/go-ethereum/consensus/consortium/generated_contracts/ronin_validator_set" slashIndicator "github.com/ethereum/go-ethereum/consensus/consortium/generated_contracts/slash_indicator" + "github.com/ethereum/go-ethereum/consensus/consortium/generated_contracts/staking" "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" @@ -61,6 +62,7 @@ type ContractInteraction interface { Slash(opts *ApplyTransactOpts, spoiledValidator common.Address) error FinalityReward(opts *ApplyTransactOpts, votedValidators []common.Address) error GetBlsPublicKey(blockNumber *big.Int, validator common.Address) (blsCommon.PublicKey, error) + GetStakedAmount(blockNumber *big.Int, validators []common.Address) ([]*big.Int, error) } // ContractIntegrator is a contract facing to interact with smart contract that supports DPoS @@ -71,6 +73,7 @@ type ContractIntegrator struct { slashIndicatorSC *slashIndicator.SlashIndicator profileSC *profile.Profile finalityTrackingSC *finalityTracking.FinalityTracking + stakingSC *staking.Staking signTxFn SignerTxFn coinbase common.Address } @@ -101,12 +104,19 @@ func NewContractIntegrator(config *chainParams.ChainConfig, backend bind.Contrac return nil, err } + // Create Staking contract instance + stakingSC, err := staking.NewStaking(config.ConsortiumV2Contracts.StakingContract, backend) + if err != nil { + return nil, err + } + return &ContractIntegrator{ chainId: config.ChainID, roninValidatorSetSC: roninValidatorSetSC, slashIndicatorSC: slashIndicatorSC, profileSC: profileSC, finalityTrackingSC: finalityTrackingSC, + stakingSC: stakingSC, signTxFn: signTxFn, signer: types.LatestSignerForChainID(config.ChainID), coinbase: coinbase, @@ -265,6 +275,19 @@ func (c *ContractIntegrator) GetBlsPublicKey(blockNumber *big.Int, validator com return blsPublicKey, nil } +func (c *ContractIntegrator) GetStakedAmount(blockNumber *big.Int, validators []common.Address) ([]*big.Int, error) { + callOpts := bind.CallOpts{ + BlockNumber: blockNumber, + } + + stakedAmount, err := c.stakingSC.GetManyStakingTotals(&callOpts, validators) + if err != nil { + return nil, err + } + + return stakedAmount, nil +} + // ApplyMessageOpts is the collection of options to fine tune a contract call request. type ApplyMessageOpts struct { State *state.StateDB diff --git a/consensus/consortium/common/mock_contract.go b/consensus/consortium/common/mock_contract.go index 14426b67f2..5799a5cf68 100644 --- a/consensus/consortium/common/mock_contract.go +++ b/consensus/consortium/common/mock_contract.go @@ -14,7 +14,7 @@ import ( var Validators *MockValidators type MockValidators struct { - validators []common.Address + validators []common.Address blsPublicKeys map[common.Address]blsCommon.PublicKey } @@ -25,7 +25,7 @@ func SetMockValidators(validators, publicKeys string) error { return errors.New("mismatch length between mock validators and mock blsPubKey") } Validators = &MockValidators{ - validators: make([]common.Address, len(vals)), + validators: make([]common.Address, len(vals)), blsPublicKeys: make(map[common.Address]blsCommon.PublicKey), } for i, val := range vals { @@ -80,3 +80,7 @@ func (contract *MockContract) FinalityReward(*ApplyTransactOpts, []common.Addres func (contract *MockContract) GetBlsPublicKey(_ *big.Int, addr common.Address) (blsCommon.PublicKey, error) { return Validators.GetPublicKey(addr) } + +func (contract *MockContract) GetStakedAmount(_ *big.Int, _ []common.Address) ([]*big.Int, error) { + return nil, nil +} diff --git a/consensus/consortium/common/utils.go b/consensus/consortium/common/utils.go index dbbdf8cdf0..9740114c34 100644 --- a/consensus/consortium/common/utils.go +++ b/consensus/consortium/common/utils.go @@ -1,8 +1,10 @@ package common import ( - "github.com/ethereum/go-ethereum/common" + "math/big" "sort" + + "github.com/ethereum/go-ethereum/common" ) // ExtractAddressFromBytes extracts validators' address from extra data in header @@ -60,3 +62,50 @@ func RemoveOutdatedRecents(recents map[uint64]common.Address, currentBlock uint6 return newRecents } + +// 1. The vote weight of each validator is validator pool's staked amount / total staked of all validator's pools +// 2. If the vote weight of a validator is higher than 1 / n, then the vote weight is 1 / n with n is the number +// of validators +// 3. After the step 2, the total vote weight might be lower than 1. Normalize the vote weight to make total vote +// weight is 1 (new vote weight = current vote weight / current total vote weight) (after this step, the total vote +// weight might not be 1 due to precision problem, but it is neglectible with small n) +// +// For vote weight, we don't use floating pointer number but multiply the vote weight with MaxFinalityVotePercentage +// and store vote weight in integer type. The precision of calculation is based on MaxFinalityVotePercentage. +func NormalizeFinalityVoteWeight(stakedAmounts []*big.Int) []uint16 { + var ( + totalStakedAmount = big.NewInt(0) + finalityVoteWeight []uint16 + maxVoteWeight uint16 + totalVoteWeight uint + ) + + // Calculate the maximum vote weight of each validator for step 2 + // 1 * MaxFinalityVotePercentage / n + maxVoteWeight = MaxFinalityVotePercentage / uint16(len(stakedAmounts)) + + for _, stakedAmount := range stakedAmounts { + totalStakedAmount.Add(totalStakedAmount, stakedAmount) + } + + // Step 1, 2 + for _, stakedAmount := range stakedAmounts { + weight := new(big.Int).Mul(stakedAmount, big.NewInt(int64(MaxFinalityVotePercentage))) + weight.Div(weight, totalStakedAmount) + + w := uint16(weight.Uint64()) + if w > maxVoteWeight { + w = maxVoteWeight + } + totalVoteWeight += uint(w) + finalityVoteWeight = append(finalityVoteWeight, w) + } + + // Step 3 + for i, weight := range finalityVoteWeight { + normalizedWeight := uint16(uint(weight) * uint(MaxFinalityVotePercentage) / totalVoteWeight) + finalityVoteWeight[i] = normalizedWeight + } + + return finalityVoteWeight +} diff --git a/consensus/consortium/common/utils_test.go b/consensus/consortium/common/utils_test.go index 0c2880aef8..f057f592ac 100644 --- a/consensus/consortium/common/utils_test.go +++ b/consensus/consortium/common/utils_test.go @@ -1,6 +1,7 @@ package common import ( + "math/big" "reflect" "testing" @@ -65,3 +66,61 @@ func TestRemoveInvalidRecents(t *testing.T) { t.Errorf("Expect %v but got %v", expected, actual) } } + +func TestNormalizeFinalityVoteWeight(t *testing.T) { + // All staked amounts are equal + var stakedAmounts []*big.Int + for i := 0; i < 22; i++ { + stakedAmounts = append(stakedAmounts, big.NewInt(1_000_000)) + } + + voteWeights := NormalizeFinalityVoteWeight(stakedAmounts) + for _, voteWeight := range voteWeights { + if voteWeight != 454 { + t.Fatalf("Incorrect vote weight, expect %d got %d", 454, voteWeight) + } + } + + // All staked amount differs + for i := 0; i < 22; i++ { + stakedAmounts[i] = big.NewInt(int64(i) + 1) + } + voteWeights = NormalizeFinalityVoteWeight(stakedAmounts) + expectedVoteWeights := []uint16{51, 103, 155, 207, 259, 311, 363, 415, 467, 519, 571, 597, 597, 597, 597, 597, 597, 597, 597, 597, 597, 597} + + for i := range voteWeights { + if voteWeights[i] != expectedVoteWeights[i] { + t.Fatalf("Incorrect vote weight, expect %d got %d", expectedVoteWeights[i], voteWeights[i]) + } + } + + // Staked amount differences are small + for i := 0; i < 22; i++ { + stakedAmounts[i] = big.NewInt(int64(i) + 1_000_000) + } + voteWeights = NormalizeFinalityVoteWeight(stakedAmounts) + for i := range voteWeights { + if voteWeights[i] != 454 { + t.Fatalf("Incorrect vote weight, expect %d got %d", 454, voteWeights[i]) + } + } + + // Some staked amounts differ greatly + for i := 0; i < 20; i++ { + stakedAmounts[i] = big.NewInt(1_000_000) + } + stakedAmounts[20] = big.NewInt(1000) + stakedAmounts[21] = big.NewInt(2500) + voteWeights = NormalizeFinalityVoteWeight(stakedAmounts) + for i := 0; i < 20; i++ { + if voteWeights[i] != 499 { + t.Fatalf("Incorrect vote weight, expect %d got %d", 499, voteWeights[i]) + } + } + if voteWeights[20] != 0 { + t.Fatalf("Incorrect vote weight, expect %d got %d", 0, voteWeights[20]) + } + if voteWeights[21] != 1 { + t.Fatalf("Incorrect vote weight, expect %d got %d", 1, voteWeights[21]) + } +} diff --git a/consensus/consortium/generated_contracts/staking/staking.go b/consensus/consortium/generated_contracts/staking/staking.go new file mode 100644 index 0000000000..3fe18fca8d --- /dev/null +++ b/consensus/consortium/generated_contracts/staking/staking.go @@ -0,0 +1,211 @@ +// Code generated - DO NOT EDIT. +// This file is a generated binding and any manual changes will be lost. + +package staking + +import ( + "errors" + "math/big" + "strings" + + ethereum "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi/bind" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/event" +) + +// Reference imports to suppress errors if they are not otherwise used. +var ( + _ = errors.New + _ = big.NewInt + _ = strings.NewReader + _ = ethereum.NotFound + _ = bind.Bind + _ = common.Big1 + _ = types.BloomLookup + _ = event.NewSubscription +) + +// StakingMetaData contains all meta data concerning the Staking contract. +var StakingMetaData = &bind.MetaData{ + ABI: "[{\"inputs\":[{\"internalType\":\"TConsensus[]\",\"name\":\"consensusAddrs\",\"type\":\"address[]\"}],\"name\":\"getManyStakingTotals\",\"outputs\":[{\"internalType\":\"uint256[]\",\"name\":\"stakingAmounts_\",\"type\":\"uint256[]\"}],\"stateMutability\":\"view\",\"type\":\"function\"}]", +} + +// StakingABI is the input ABI used to generate the binding from. +// Deprecated: Use StakingMetaData.ABI instead. +var StakingABI = StakingMetaData.ABI + +// Staking is an auto generated Go binding around an Ethereum contract. +type Staking struct { + StakingCaller // Read-only binding to the contract + StakingTransactor // Write-only binding to the contract + StakingFilterer // Log filterer for contract events +} + +// StakingCaller is an auto generated read-only Go binding around an Ethereum contract. +type StakingCaller struct { + contract *bind.BoundContract // Generic contract wrapper for the low level calls +} + +// StakingTransactor is an auto generated write-only Go binding around an Ethereum contract. +type StakingTransactor struct { + contract *bind.BoundContract // Generic contract wrapper for the low level calls +} + +// StakingFilterer is an auto generated log filtering Go binding around an Ethereum contract events. +type StakingFilterer struct { + contract *bind.BoundContract // Generic contract wrapper for the low level calls +} + +// StakingSession is an auto generated Go binding around an Ethereum contract, +// with pre-set call and transact options. +type StakingSession struct { + Contract *Staking // Generic contract binding to set the session for + CallOpts bind.CallOpts // Call options to use throughout this session + TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session +} + +// StakingCallerSession is an auto generated read-only Go binding around an Ethereum contract, +// with pre-set call options. +type StakingCallerSession struct { + Contract *StakingCaller // Generic contract caller binding to set the session for + CallOpts bind.CallOpts // Call options to use throughout this session +} + +// StakingTransactorSession is an auto generated write-only Go binding around an Ethereum contract, +// with pre-set transact options. +type StakingTransactorSession struct { + Contract *StakingTransactor // Generic contract transactor binding to set the session for + TransactOpts bind.TransactOpts // Transaction auth options to use throughout this session +} + +// StakingRaw is an auto generated low-level Go binding around an Ethereum contract. +type StakingRaw struct { + Contract *Staking // Generic contract binding to access the raw methods on +} + +// StakingCallerRaw is an auto generated low-level read-only Go binding around an Ethereum contract. +type StakingCallerRaw struct { + Contract *StakingCaller // Generic read-only contract binding to access the raw methods on +} + +// StakingTransactorRaw is an auto generated low-level write-only Go binding around an Ethereum contract. +type StakingTransactorRaw struct { + Contract *StakingTransactor // Generic write-only contract binding to access the raw methods on +} + +// NewStaking creates a new instance of Staking, bound to a specific deployed contract. +func NewStaking(address common.Address, backend bind.ContractBackend) (*Staking, error) { + contract, err := bindStaking(address, backend, backend, backend) + if err != nil { + return nil, err + } + return &Staking{StakingCaller: StakingCaller{contract: contract}, StakingTransactor: StakingTransactor{contract: contract}, StakingFilterer: StakingFilterer{contract: contract}}, nil +} + +// NewStakingCaller creates a new read-only instance of Staking, bound to a specific deployed contract. +func NewStakingCaller(address common.Address, caller bind.ContractCaller) (*StakingCaller, error) { + contract, err := bindStaking(address, caller, nil, nil) + if err != nil { + return nil, err + } + return &StakingCaller{contract: contract}, nil +} + +// NewStakingTransactor creates a new write-only instance of Staking, bound to a specific deployed contract. +func NewStakingTransactor(address common.Address, transactor bind.ContractTransactor) (*StakingTransactor, error) { + contract, err := bindStaking(address, nil, transactor, nil) + if err != nil { + return nil, err + } + return &StakingTransactor{contract: contract}, nil +} + +// NewStakingFilterer creates a new log filterer instance of Staking, bound to a specific deployed contract. +func NewStakingFilterer(address common.Address, filterer bind.ContractFilterer) (*StakingFilterer, error) { + contract, err := bindStaking(address, nil, nil, filterer) + if err != nil { + return nil, err + } + return &StakingFilterer{contract: contract}, nil +} + +// bindStaking binds a generic wrapper to an already deployed contract. +func bindStaking(address common.Address, caller bind.ContractCaller, transactor bind.ContractTransactor, filterer bind.ContractFilterer) (*bind.BoundContract, error) { + parsed, err := abi.JSON(strings.NewReader(StakingABI)) + if err != nil { + return nil, err + } + return bind.NewBoundContract(address, parsed, caller, transactor, filterer), nil +} + +// Call invokes the (constant) contract method with params as input values and +// sets the output to result. The result type might be a single field for simple +// returns, a slice of interfaces for anonymous returns and a struct for named +// returns. +func (_Staking *StakingRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { + return _Staking.Contract.StakingCaller.contract.Call(opts, result, method, params...) +} + +// Transfer initiates a plain transaction to move funds to the contract, calling +// its default method if one is available. +func (_Staking *StakingRaw) Transfer(opts *bind.TransactOpts) (*types.Transaction, error) { + return _Staking.Contract.StakingTransactor.contract.Transfer(opts) +} + +// Transact invokes the (paid) contract method with params as input values. +func (_Staking *StakingRaw) Transact(opts *bind.TransactOpts, method string, params ...interface{}) (*types.Transaction, error) { + return _Staking.Contract.StakingTransactor.contract.Transact(opts, method, params...) +} + +// Call invokes the (constant) contract method with params as input values and +// sets the output to result. The result type might be a single field for simple +// returns, a slice of interfaces for anonymous returns and a struct for named +// returns. +func (_Staking *StakingCallerRaw) Call(opts *bind.CallOpts, result *[]interface{}, method string, params ...interface{}) error { + return _Staking.Contract.contract.Call(opts, result, method, params...) +} + +// Transfer initiates a plain transaction to move funds to the contract, calling +// its default method if one is available. +func (_Staking *StakingTransactorRaw) Transfer(opts *bind.TransactOpts) (*types.Transaction, error) { + return _Staking.Contract.contract.Transfer(opts) +} + +// Transact invokes the (paid) contract method with params as input values. +func (_Staking *StakingTransactorRaw) Transact(opts *bind.TransactOpts, method string, params ...interface{}) (*types.Transaction, error) { + return _Staking.Contract.contract.Transact(opts, method, params...) +} + +// GetManyStakingTotals is a free data retrieval call binding the contract method 0x91f8723f. +// +// Solidity: function getManyStakingTotals(address[] consensusAddrs) view returns(uint256[] stakingAmounts_) +func (_Staking *StakingCaller) GetManyStakingTotals(opts *bind.CallOpts, consensusAddrs []common.Address) ([]*big.Int, error) { + var out []interface{} + err := _Staking.contract.Call(opts, &out, "getManyStakingTotals", consensusAddrs) + + if err != nil { + return *new([]*big.Int), err + } + + out0 := *abi.ConvertType(out[0], new([]*big.Int)).(*[]*big.Int) + + return out0, err + +} + +// GetManyStakingTotals is a free data retrieval call binding the contract method 0x91f8723f. +// +// Solidity: function getManyStakingTotals(address[] consensusAddrs) view returns(uint256[] stakingAmounts_) +func (_Staking *StakingSession) GetManyStakingTotals(consensusAddrs []common.Address) ([]*big.Int, error) { + return _Staking.Contract.GetManyStakingTotals(&_Staking.CallOpts, consensusAddrs) +} + +// GetManyStakingTotals is a free data retrieval call binding the contract method 0x91f8723f. +// +// Solidity: function getManyStakingTotals(address[] consensusAddrs) view returns(uint256[] stakingAmounts_) +func (_Staking *StakingCallerSession) GetManyStakingTotals(consensusAddrs []common.Address) ([]*big.Int, error) { + return _Staking.Contract.GetManyStakingTotals(&_Staking.CallOpts, consensusAddrs) +} diff --git a/consensus/consortium/v2/consortium.go b/consensus/consortium/v2/consortium.go index 0b91785870..eb9ab3ffb8 100644 --- a/consensus/consortium/v2/consortium.go +++ b/consensus/consortium/v2/consortium.go @@ -273,16 +273,12 @@ func (c *Consortium) verifyFinalitySignatures( parentHash common.Hash, parents []*types.Header, ) error { + isTripp := c.chainConfig.IsTripp(new(big.Int).SetUint64(parentNumber + 1)) snap, err := c.snapshot(chain, parentNumber, parentHash, parents) if err != nil { return err } - votedValidatorPositions := finalityVotedValidators.Indices() - if len(votedValidatorPositions) < int(math.Floor(finalityRatio*float64(len(snap.ValidatorsWithBlsPub))))+1 { - return finality.ErrNotEnoughFinalityVote - } - voteData := types.VoteData{ TargetNumber: parentNumber, TargetHash: parentHash, @@ -290,13 +286,34 @@ func (c *Consortium) verifyFinalitySignatures( digest := voteData.Hash() // verify aggregated signature - var publicKeys []blsCommon.PublicKey + var ( + publicKeys []blsCommon.PublicKey + accumulatedVoteWeight int + finalityThreshold int + ) + votedValidatorPositions := finalityVotedValidators.Indices() for _, position := range votedValidatorPositions { if position >= len(snap.ValidatorsWithBlsPub) { return finality.ErrInvalidFinalityVotedBitSet } publicKeys = append(publicKeys, snap.ValidatorsWithBlsPub[position].BlsPublicKey) + if isTripp { + accumulatedVoteWeight += int(snap.FinalityVoteWeight[position]) + } else { + accumulatedVoteWeight += 1 + } } + + if isTripp { + finalityThreshold = int(math.Floor(finalityRatio*float64(consortiumCommon.MaxFinalityVotePercentage))) + 1 + } else { + finalityThreshold = int(math.Floor(finalityRatio*float64(len(snap.ValidatorsWithBlsPub)))) + 1 + } + + if accumulatedVoteWeight < finalityThreshold { + return finality.ErrNotEnoughFinalityVote + } + if !finalitySignatures.FastAggregateVerify(publicKeys, digest) { return finality.ErrFinalitySignatureVerificationFailed } @@ -681,7 +698,10 @@ func (c *Consortium) getCheckpointValidatorsFromContract( ) isShillin := c.chainConfig.IsShillin(header.Number) - if isShillin { + isTripp := c.chainConfig.IsTripp(header.Number) + // After Tripp, BLS key of validator is read at the start of new period, + // not the start of epoch anymore. + if isShillin && !isTripp { // The filteredValidators shares the same underlying array with newValidators // See more: https://github.com/golang/go/wiki/SliceTricks#filtering-without-allocating filteredValidators = filteredValidators[:0] @@ -698,7 +718,7 @@ func (c *Consortium) getCheckpointValidatorsFromContract( validatorWithBlsPub := finality.ValidatorWithBlsPub{ Address: filteredValidators[i], } - if isShillin { + if isShillin && !isTripp { validatorWithBlsPub.BlsPublicKey = blsPublicKeys[i] } @@ -735,6 +755,7 @@ func (c *Consortium) Prepare(chain consensus.ChainHeaderReader, header *types.He return err } extraData.CheckpointValidators = checkpointValidator + // TODO: if is Tripp and new period, read all validator candidates and their amounts, appends to header } // After Shillin, extraData.hasFinalityVote = 0 here as we does @@ -891,6 +912,7 @@ func (c *Consortium) Finalize(chain consensus.ChainHeaderReader, header *types.H if err != nil { return err } + // TODO: if is Tripp and new period, read all validator candidates and their amounts, check with stored data in header extraData, err := finality.DecodeExtra(header.Extra, isShillin) if err != nil { return err @@ -1228,14 +1250,24 @@ func (c *Consortium) assembleFinalityVote(header *types.Header, snap *Snapshot) var ( signatures []blsCommon.Signature finalityVotedValidators finality.FinalityVoteBitSet - finalityThreshold int = int(math.Floor(finalityRatio*float64(len(snap.ValidatorsWithBlsPub)))) + 1 + finalityThreshold int + accumulatedVoteWeight int ) + isTripp := c.chainConfig.IsTripp(header.Number) + if isTripp { + finalityThreshold = int(math.Floor(finalityRatio*float64(consortiumCommon.MaxFinalityVotePercentage))) + 1 + } else { + finalityThreshold = int(math.Floor(finalityRatio*float64(len(snap.ValidatorsWithBlsPub)))) + 1 + } + // We assume the signature has been verified in vote pool // so we do not verify signature here if c.votePool != nil { votes := c.votePool.FetchVoteByBlockHash(header.ParentHash) - if len(votes) >= finalityThreshold { + // Before Tripp (!isTripp), every vote has the same weight so if the number of votes + // is lower than threshold, we can short-circuit and skip all the checks + if isTripp || len(votes) >= finalityThreshold { for _, vote := range votes { publicKey, err := blst.PublicKeyFromBytes(vote.PublicKey[:]) if err != nil { @@ -1245,14 +1277,22 @@ func (c *Consortium) assembleFinalityVote(header *types.Header, snap *Snapshot) authorized := false for valPosition, validator := range snap.ValidatorsWithBlsPub { if publicKey.Equals(validator.BlsPublicKey) { + authorized = true signature, err := blst.SignatureFromBytes(vote.Signature[:]) if err != nil { log.Warn("Malformed signature from vote pool", "err", err) break } + if finalityVotedValidators.GetBit(valPosition) != 0 { + log.Warn("More than 1 vote from validator", "address", validator.Address.Hex(), + "blockHash", header.Hash(), "blockNumber", header.Number) + break + } signatures = append(signatures, signature) finalityVotedValidators.SetBit(valPosition) - authorized = true + if isTripp { + accumulatedVoteWeight += int(snap.FinalityVoteWeight[valPosition]) + } break } } @@ -1261,8 +1301,10 @@ func (c *Consortium) assembleFinalityVote(header *types.Header, snap *Snapshot) } } - bitSetCount := len(finalityVotedValidators.Indices()) - if bitSetCount >= finalityThreshold { + if !isTripp { + accumulatedVoteWeight = len(finalityVotedValidators.Indices()) + } + if accumulatedVoteWeight >= finalityThreshold { extraData, err := finality.DecodeExtra(header.Extra, true) if err != nil { // This should not happen @@ -1277,7 +1319,6 @@ func (c *Consortium) assembleFinalityVote(header *types.Header, snap *Snapshot) } } } - } // GetFinalizedBlock gets the fast finality finalized block diff --git a/consensus/consortium/v2/consortium_test.go b/consensus/consortium/v2/consortium_test.go index 59c10c0c87..b2a9784196 100644 --- a/consensus/consortium/v2/consortium_test.go +++ b/consensus/consortium/v2/consortium_test.go @@ -637,6 +637,112 @@ func TestVerifyFinalitySignature(t *testing.T) { } } +func TestVerifyFinalitySignatureTripp(t *testing.T) { + const numValidator = 3 + var err error + + secretKey := make([]blsCommon.SecretKey, numValidator) + for i := 0; i < len(secretKey); i++ { + secretKey[i], err = blst.RandKey() + if err != nil { + t.Fatalf("Failed to generate secret key, err %s", err) + } + } + + valWithBlsPub := make([]finality.ValidatorWithBlsPub, numValidator) + for i := 0; i < len(valWithBlsPub); i++ { + valWithBlsPub[i] = finality.ValidatorWithBlsPub{ + Address: common.BigToAddress(big.NewInt(int64(i))), + BlsPublicKey: secretKey[i].PublicKey(), + } + } + + blockNumber := uint64(0) + blockHash := common.Hash{0x1} + vote := types.VoteData{ + TargetNumber: blockNumber, + TargetHash: blockHash, + } + + digest := vote.Hash() + signature := make([]blsCommon.Signature, numValidator) + for i := 0; i < len(signature); i++ { + signature[i] = secretKey[i].Sign(digest[:]) + } + + snap := newSnapshot(nil, nil, nil, 10, common.Hash{}, nil, valWithBlsPub, nil) + snap.FinalityVoteWeight = make([]uint16, numValidator) + snap.FinalityVoteWeight[0] = 6666 + snap.FinalityVoteWeight[1] = 1 + snap.FinalityVoteWeight[2] = 3333 + + recents, _ := lru.NewARC(inmemorySnapshots) + c := Consortium{ + chainConfig: ¶ms.ChainConfig{ + ShillinBlock: big.NewInt(0), + TrippBlock: big.NewInt(0), + }, + config: ¶ms.ConsortiumConfig{ + EpochV2: 300, + }, + recents: recents, + } + snap.Hash = blockHash + c.recents.Add(snap.Hash, snap) + + // 1 voter with vote weight 6666 does not reach the threshold + votedBitSet := finality.FinalityVoteBitSet(0) + votedBitSet.SetBit(0) + aggregatedSignature := blst.AggregateSignatures([]blsCommon.Signature{ + signature[0], + }) + err = c.verifyFinalitySignatures(nil, votedBitSet, aggregatedSignature, 0, snap.Hash, nil) + if !errors.Is(err, finality.ErrNotEnoughFinalityVote) { + t.Errorf("Expect error %v have %v", finality.ErrNotEnoughFinalityVote, err) + } + + // 2 voters with total vote weight 3333 + 1 does not reach the threshold + votedBitSet = finality.FinalityVoteBitSet(0) + votedBitSet.SetBit(1) + votedBitSet.SetBit(2) + aggregatedSignature = blst.AggregateSignatures([]blsCommon.Signature{ + signature[1], + signature[2], + }) + err = c.verifyFinalitySignatures(nil, votedBitSet, aggregatedSignature, 0, snap.Hash, nil) + if !errors.Is(err, finality.ErrNotEnoughFinalityVote) { + t.Errorf("Expect error %v have %v", finality.ErrNotEnoughFinalityVote, err) + } + + // 2 voters with total vote weight 6666 + 1 reach the threshold + votedBitSet = finality.FinalityVoteBitSet(0) + votedBitSet.SetBit(0) + votedBitSet.SetBit(1) + aggregatedSignature = blst.AggregateSignatures([]blsCommon.Signature{ + signature[0], + signature[1], + }) + err = c.verifyFinalitySignatures(nil, votedBitSet, aggregatedSignature, 0, snap.Hash, nil) + if err != nil { + t.Errorf("Expect successful verification have %v", err) + } + + // All voters vote + votedBitSet = finality.FinalityVoteBitSet(0) + votedBitSet.SetBit(0) + votedBitSet.SetBit(1) + votedBitSet.SetBit(2) + aggregatedSignature = blst.AggregateSignatures([]blsCommon.Signature{ + signature[0], + signature[1], + signature[2], + }) + err = c.verifyFinalitySignatures(nil, votedBitSet, aggregatedSignature, 0, snap.Hash, nil) + if err != nil { + t.Errorf("Expect successful verification have %v", err) + } +} + func TestSnapshotValidatorWithBlsKey(t *testing.T) { secretKey, err := blst.RandKey() if err != nil { @@ -729,6 +835,10 @@ func (contract *mockContract) GetBlsPublicKey(_ *big.Int, address common.Address } } +func (contract *mockContract) GetStakedAmount(_ *big.Int, _ []common.Address) ([]*big.Int, error) { + return nil, nil +} + func TestGetCheckpointValidatorFromContract(t *testing.T) { var err error secretKeys := make([]blsCommon.SecretKey, 3) @@ -881,6 +991,11 @@ func TestAssembleFinalityVote(t *testing.T) { } } +// TODO: Add AssembleFinalityVoteTripp test +func TestAssembleFinalityVoteTripp(t *testing.T) { + +} + func TestVerifyVote(t *testing.T) { const numValidator = 3 var err error diff --git a/consensus/consortium/v2/finality/consortium_header.go b/consensus/consortium/v2/finality/consortium_header.go index e19e85ed15..9ef53ca12a 100644 --- a/consensus/consortium/v2/finality/consortium_header.go +++ b/consensus/consortium/v2/finality/consortium_header.go @@ -139,6 +139,14 @@ func (bitSet *FinalityVoteBitSet) Indices() []int { return votedValidatorPositions } +func (bitSet *FinalityVoteBitSet) GetBit(index int) int { + if index >= finalityVoteBitSetByteLength*8 { + return 0 + } + + return int((uint64(*bitSet) >> index) & 1) +} + func (bitSet *FinalityVoteBitSet) SetBit(index int) { if index >= finalityVoteBitSetByteLength*8 { return diff --git a/consensus/consortium/v2/finality/consortium_header_test.go b/consensus/consortium/v2/finality/consortium_header_test.go new file mode 100644 index 0000000000..2872f32119 --- /dev/null +++ b/consensus/consortium/v2/finality/consortium_header_test.go @@ -0,0 +1,39 @@ +package finality + +import "testing" + +func TestFinalityVoteBitSet(t *testing.T) { + var bitSet FinalityVoteBitSet + + bitSet.SetBit(0) + bitSet.SetBit(40) + // index >= 64 has no effect + bitSet.SetBit(64) + + // 2 ** 40 + 2 ** 0 + if uint64(bitSet) != 1099511627777 { + t.Fatalf("Wrong bitset value, exp %d got %d", 1099511627777, uint64(bitSet)) + } + + indices := bitSet.Indices() + if len(indices) != 2 { + t.Fatalf("Wrong indices length, exp %d got %d", 2, len(indices)) + } + if indices[0] != 0 { + t.Fatalf("Wrong index, exp %d got %d", 0, indices[0]) + } + if indices[1] != 40 { + t.Fatalf("Wrong index, exp %d got %d", 40, indices[1]) + } + + if bitSet.GetBit(40) != 1 { + t.Fatalf("Wrong bit, exp %d got %d", 1, bitSet.GetBit(40)) + } + if bitSet.GetBit(50) != 0 { + t.Fatalf("Wrong bit, exp %d got %d", 1, bitSet.GetBit(50)) + } + // index >= 64 returns 0 + if bitSet.GetBit(70) != 0 { + t.Fatalf("Wrong bit, exp %d got %d", 0, bitSet.GetBit(70)) + } +} diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 5931f0713a..e1e4c765b4 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -21,6 +21,7 @@ import ( // Snapshot is the state of the authorization validators at a given point in time. type Snapshot struct { + // private fields are not json.Marshalled chainConfig *params.ChainConfig config *params.ConsortiumConfig // Consensus engine parameters to fine tune behavior ethAPI *ethapi.PublicBlockChainAPI @@ -32,7 +33,8 @@ type Snapshot struct { Recents map[uint64]common.Address `json:"recents"` // Set of recent validators for spam protections // Finality additional fields - ValidatorsWithBlsPub []finality.ValidatorWithBlsPub `json:"validatorWithBlsPub,omitempty"` // Array of sorted authorized validators and BLS public keys after Shillin + ValidatorsWithBlsPub []finality.ValidatorWithBlsPub `json:"validatorWithBlsPub,omitempty"` // Array of sorted authorized validators and BLS public keys after Shillin + FinalityVoteWeight []uint16 `json:"finalityVoteWeight,omitempty"` JustifiedBlockNumber uint64 `json:"justifiedBlockNumber,omitempty"` // The justified block number JustifiedBlockHash common.Hash `json:"justifiedBlockHash,omitempty"` // The justified block hash } @@ -239,6 +241,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea snap.ValidatorsWithBlsPub = nil } else { isShillin := chain.Config().IsShillin(checkpointHeader.Number) + isTripp := chain.Config().IsTripp(checkpointHeader.Number) // Get validator set from headers and use that for new validator set extraData, err := finality.DecodeExtra(checkpointHeader.Extra, isShillin) if err != nil { @@ -253,7 +256,11 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea } } - if isShillin { + if isTripp { + // TODO: if at the start of period, read BLS key, consensus and staked amount from header + + // TODO: if at the start of epoch, read block producer's consensus address + } else if isShillin { // The validator information in checkpoint header is already sorted, // we don't need to sort here snap.ValidatorsWithBlsPub = make([]finality.ValidatorWithBlsPub, len(extraData.CheckpointValidators))