diff --git a/peers/app_request_network.go b/peers/app_request_network.go index f25a8518..df4e20e0 100644 --- a/peers/app_request_network.go +++ b/peers/app_request_network.go @@ -212,12 +212,12 @@ type ConnectedCanonicalValidators struct { ConnectedWeight uint64 TotalValidatorWeight uint64 ValidatorSet []*warp.Validator - nodeValidatorIndexMap map[ids.NodeID]int + NodeValidatorIndexMap map[ids.NodeID]int } // Returns the Warp Validator and its index in the canonical Validator ordering for a given nodeID func (c *ConnectedCanonicalValidators) GetValidator(nodeID ids.NodeID) (*warp.Validator, int) { - return c.ValidatorSet[c.nodeValidatorIndexMap[nodeID]], c.nodeValidatorIndexMap[nodeID] + return c.ValidatorSet[c.NodeValidatorIndexMap[nodeID]], c.NodeValidatorIndexMap[nodeID] } // ConnectToCanonicalValidators connects to the canonical validators of the given subnet and returns the connected @@ -258,7 +258,7 @@ func (n *appRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*Conn ConnectedWeight: connectedWeight, TotalValidatorWeight: totalValidatorWeight, ValidatorSet: validatorSet, - nodeValidatorIndexMap: nodeValidatorIndexMap, + NodeValidatorIndexMap: nodeValidatorIndexMap, }, nil } diff --git a/signature-aggregator/aggregator/aggregator.go b/signature-aggregator/aggregator/aggregator.go index a48ac728..756630b1 100644 --- a/signature-aggregator/aggregator/aggregator.go +++ b/signature-aggregator/aggregator/aggregator.go @@ -5,6 +5,7 @@ package aggregator import ( "bytes" + "encoding/hex" "errors" "fmt" "math/big" @@ -554,6 +555,7 @@ func (s *SignatureAggregator) isValidSignatureResponse( if !bls.Verify(pubKey, sig, unsignedMessage.Bytes()) { s.logger.Debug( "Failed verification for signature", + zap.String("pubKey", hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey))), ) return blsSignatureBuf{}, false } diff --git a/signature-aggregator/aggregator/aggregator_test.go b/signature-aggregator/aggregator/aggregator_test.go index 13a6952c..b8d99860 100644 --- a/signature-aggregator/aggregator/aggregator_test.go +++ b/signature-aggregator/aggregator/aggregator_test.go @@ -1,23 +1,37 @@ package aggregator import ( + "context" + "encoding/hex" + "fmt" + "os" "testing" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/message" + "github.com/ava-labs/avalanchego/snow/validators" + "github.com/ava-labs/avalanchego/subnets" + "github.com/ava-labs/avalanchego/utils" "github.com/ava-labs/avalanchego/utils/constants" + "github.com/ava-labs/avalanchego/utils/crypto/bls" "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/vms/platformvm/warp" "github.com/ava-labs/awm-relayer/peers" "github.com/ava-labs/awm-relayer/peers/mocks" "github.com/ava-labs/awm-relayer/signature-aggregator/metrics" + evmMsg "github.com/ava-labs/subnet-evm/plugin/evm/message" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) -var sigAggMetrics *metrics.SignatureAggregatorMetrics -var messageCreator message.Creator +var ( + sigAggMetrics *metrics.SignatureAggregatorMetrics + messageCreator message.Creator +) func instantiateAggregator(t *testing.T) ( *SignatureAggregator, @@ -39,7 +53,16 @@ func instantiateAggregator(t *testing.T) ( } aggregator, err := NewSignatureAggregator( mockNetwork, - logging.NoLog{}, + logging.NewLogger( + "aggregator_test", + logging.NewWrappedCore( + logging.Debug, + os.Stdout, + zapcore.NewConsoleEncoder( + zap.NewProductionEncoderConfig(), + ), + ), + ), 1024, sigAggMetrics, messageCreator, @@ -48,6 +71,54 @@ func instantiateAggregator(t *testing.T) ( return aggregator, mockNetwork } +func makeConnectedValidators(validatorCount uint64) (*peers.ConnectedCanonicalValidators, []*bls.SecretKey) { + var validatorSet []*warp.Validator + var validatorSecretKeys []*bls.SecretKey + for i := uint64(0); i < validatorCount; i++ { + secretKey, err := bls.NewSecretKey() + if err != nil { + panic(err) + } + validatorSecretKeys = append(validatorSecretKeys, secretKey) + + pubKey := bls.PublicFromSecretKey(secretKey) + + nodeID, err := ids.ToNodeID(utils.RandomBytes(20)) + if err != nil { + panic(err) + } + + fmt.Printf( + "validator with pubKey %s has nodeID %s\n", + hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey)), + nodeID.String(), + ) + + validatorSet = append(validatorSet, + &warp.Validator{ + PublicKey: pubKey, + PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pubKey), + Weight: 1, + NodeIDs: []ids.NodeID{nodeID}, + }, + ) + } + + nodeValidatorIndexMap := make(map[ids.NodeID]int) + for i, vdr := range validatorSet { + for _, node := range vdr.NodeIDs { + nodeValidatorIndexMap[node] = i + } + } + + return &peers.ConnectedCanonicalValidators{ + ConnectedWeight: validatorCount, + TotalValidatorWeight: validatorCount, + ValidatorSet: validatorSet, + NodeValidatorIndexMap: nodeValidatorIndexMap, + }, validatorSecretKeys +} + func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) { aggregator, mockNetwork := instantiateAggregator(t) msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{}) @@ -85,3 +156,224 @@ func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) { "failed to connect to a threshold of stake", ) } + +func makeAppRequests( + chainID ids.ID, + requestID uint32, + connectedValidators *peers.ConnectedCanonicalValidators, +) []ids.RequestID { + var appRequests []ids.RequestID + for _, validator := range connectedValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + appRequests = append( + appRequests, + ids.RequestID{ + NodeID: nodeID, + SourceChainID: chainID, + DestinationChainID: chainID, + RequestID: requestID, + Op: byte( + message.AppResponseOp, + ), + }, + ) + } + } + return appRequests +} + +func TestCreateSignedMessageRetriesAndFailsWithoutP2PResponses(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + + var ( + connectedValidators, _ = makeConnectedValidators(2) + requestID = aggregator.currentRequestID.Load() + 1 + ) + + chainID, err := ids.ToID(utils.RandomBytes(32)) + if err != nil { + panic(err) + } + + msg, err := warp.NewUnsignedMessage(0, chainID, []byte{}) + require.Equal(t, err, nil) + + subnetID, err := ids.ToID(utils.RandomBytes(32)) + require.Equal(t, err, nil) + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times( + maxRelayerQueryAttempts, + ) + } + + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return( + make(chan message.InboundMessage, len(appRequests)), + ).Times(maxRelayerQueryAttempts) + + var nodeIDs set.Set[ids.NodeID] + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + } + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(maxRelayerQueryAttempts) + + _, err = aggregator.CreateSignedMessage(msg, subnetID, 80) + require.ErrorContains( + t, + err, + "failed to collect a threshold of signatures", + ) +} + +func TestCreateSignedMessageSucceeds(t *testing.T) { + aggregator, mockNetwork := instantiateAggregator(t) + + var connectedValidators, validatorSecretKeys = makeConnectedValidators(5) + var requestID = aggregator.currentRequestID.Load() + 1 + + chainID, err := ids.ToID(utils.RandomBytes(32)) + if err != nil { + panic(err) + } + + subnetID, err := ids.ToID(utils.RandomBytes(32)) + require.Equal(t, err, nil) + mockNetwork.EXPECT().GetSubnetID(chainID).Return( + subnetID, + nil, + ) + + mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return( + connectedValidators, + nil, + ) + + appRequests := makeAppRequests(chainID, requestID, connectedValidators) + for _, appRequest := range appRequests { + mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times(1) + } + + networkID := uint32(0) + msg, err := warp.NewUnsignedMessage( + networkID, + chainID, + utils.RandomBytes(1234), + ) + require.Equal(t, err, nil) + + responseChan := make(chan message.InboundMessage, len(appRequests)) + for _, appRequest := range appRequests { + validatorSecretKey := validatorSecretKeys[connectedValidators.NodeValidatorIndexMap[appRequest.NodeID]] + responseBytes, err := evmMsg.Codec.Marshal( + 0, + &evmMsg.SignatureResponse{ + Signature: [bls.SignatureLen]byte( + bls.SignatureToBytes( + bls.Sign( + validatorSecretKey, + msg.Bytes(), + ), + ), + ), + }, + ) + require.Equal(t, err, nil) + responseChan <- message.InboundAppResponse( + chainID, + requestID, + responseBytes, + appRequest.NodeID, + ) + } + mockNetwork.EXPECT().RegisterRequestID( + requestID, + len(appRequests), + ).Return(responseChan).Times(1) + + var nodeIDs set.Set[ids.NodeID] + for _, appRequest := range appRequests { + nodeIDs.Add(appRequest.NodeID) + } + mockNetwork.EXPECT().Send( + gomock.Any(), + nodeIDs, + subnetID, + subnets.NoOpAllower, + ).Times(1).Return(nodeIDs) + + signedMessage, err := aggregator.CreateSignedMessage(msg, subnetID, 80) + require.Equal(t, err, nil) + + pChainState := newPChainStateStub(chainID, subnetID, connectedValidators) + err = signedMessage.Signature.Verify( + context.Background(), + msg, + networkID, + pChainState, + 1, + 80, + 100, + ) + require.Equal(t, err, nil) +} + +type pChainStateStub struct { + subnetIDByChainID map[ids.ID]ids.ID + connectedCanonicalValidators *peers.ConnectedCanonicalValidators +} + +func newPChainStateStub( + chainID, subnetID ids.ID, + connectedValidators *peers.ConnectedCanonicalValidators, +) *pChainStateStub { + subnetIDByChainID := make(map[ids.ID]ids.ID) + subnetIDByChainID[chainID] = subnetID + return &pChainStateStub{ + subnetIDByChainID: subnetIDByChainID, + connectedCanonicalValidators: connectedValidators, + } +} + +func (p pChainStateStub) GetSubnetID(ctx context.Context, chainID ids.ID) (ids.ID, error) { + return p.subnetIDByChainID[chainID], nil +} + +func (p pChainStateStub) GetMinimumHeight(context.Context) (uint64, error) { return 0, nil } + +func (p pChainStateStub) GetCurrentHeight(context.Context) (uint64, error) { return 0, nil } + +func (p pChainStateStub) GetValidatorSet( + ctx context.Context, + height uint64, + subnetID ids.ID, +) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + output := make(map[ids.NodeID]*validators.GetValidatorOutput) + for _, validator := range p.connectedCanonicalValidators.ValidatorSet { + for _, nodeID := range validator.NodeIDs { + output[nodeID] = &validators.GetValidatorOutput{ + NodeID: nodeID, + PublicKey: validator.PublicKey, + Weight: validator.Weight, + } + } + } + return output, nil +}