Skip to content

Commit

Permalink
happy path test; aggregating but not yet verifying
Browse files Browse the repository at this point in the history
addresses review comment #457 (comment)
  • Loading branch information
feuGeneA committed Aug 30, 2024
1 parent ad87ee7 commit a7c907e
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 6 deletions.
6 changes: 3 additions & 3 deletions peers/app_request_network.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ type ConnectedCanonicalValidators struct {
ConnectedWeight uint64
TotalValidatorWeight uint64
ValidatorSet []*warp.Validator
nodeValidatorIndexMap map[ids.NodeID]int
NodeValidatorIndexMap map[ids.NodeID]int
}

// Returns the Warp Validator and its index in the canonical Validator ordering for a given nodeID
func (c *ConnectedCanonicalValidators) GetValidator(nodeID ids.NodeID) (*warp.Validator, int) {
return c.ValidatorSet[c.nodeValidatorIndexMap[nodeID]], c.nodeValidatorIndexMap[nodeID]
return c.ValidatorSet[c.NodeValidatorIndexMap[nodeID]], c.NodeValidatorIndexMap[nodeID]
}

// ConnectToCanonicalValidators connects to the canonical validators of the given subnet and returns the connected
Expand Down Expand Up @@ -258,7 +258,7 @@ func (n *appRequestNetwork) ConnectToCanonicalValidators(subnetID ids.ID) (*Conn
ConnectedWeight: connectedWeight,
TotalValidatorWeight: totalValidatorWeight,
ValidatorSet: validatorSet,
nodeValidatorIndexMap: nodeValidatorIndexMap,
NodeValidatorIndexMap: nodeValidatorIndexMap,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions signature-aggregator/aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package aggregator

import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/big"
Expand Down Expand Up @@ -554,6 +555,7 @@ func (s *SignatureAggregator) isValidSignatureResponse(
if !bls.Verify(pubKey, sig, unsignedMessage.Bytes()) {
s.logger.Debug(
"Failed verification for signature",
zap.String("pubKey", hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey))),
)
return blsSignatureBuf{}, false
}
Expand Down
241 changes: 238 additions & 3 deletions signature-aggregator/aggregator/aggregator_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
package aggregator

import (
"encoding/hex"
"fmt"
"os"
"testing"

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/message"
"github.com/ava-labs/avalanchego/subnets"
"github.com/ava-labs/avalanchego/utils"
"github.com/ava-labs/avalanchego/utils/constants"
"github.com/ava-labs/avalanchego/utils/crypto/bls"
"github.com/ava-labs/avalanchego/utils/logging"
"github.com/ava-labs/avalanchego/utils/set"
"github.com/ava-labs/avalanchego/vms/platformvm/warp"
"github.com/ava-labs/awm-relayer/peers"
"github.com/ava-labs/awm-relayer/peers/mocks"
"github.com/ava-labs/awm-relayer/signature-aggregator/metrics"
evmMsg "github.com/ava-labs/subnet-evm/plugin/evm/message"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

var sigAggMetrics *metrics.SignatureAggregatorMetrics
var messageCreator message.Creator
var (
sigAggMetrics *metrics.SignatureAggregatorMetrics
messageCreator message.Creator
)

func instantiateAggregator(t *testing.T) (
*SignatureAggregator,
Expand All @@ -39,7 +51,16 @@ func instantiateAggregator(t *testing.T) (
}
aggregator, err := NewSignatureAggregator(
mockNetwork,
logging.NoLog{},
logging.NewLogger(
"aggregator_test",
logging.NewWrappedCore(
logging.Debug,
os.Stdout,
zapcore.NewConsoleEncoder(
zap.NewProductionEncoderConfig(),
),
),
),
1024,
sigAggMetrics,
messageCreator,
Expand All @@ -48,6 +69,54 @@ func instantiateAggregator(t *testing.T) (
return aggregator, mockNetwork
}

func makeConnectedValidators(validatorCount uint64) (*peers.ConnectedCanonicalValidators, []*bls.SecretKey) {
var validatorSet []*warp.Validator
var validatorSecretKeys []*bls.SecretKey
for i := uint64(0); i < validatorCount; i++ {
secretKey, err := bls.NewSecretKey()
if err != nil {
panic(err)
}
validatorSecretKeys = append(validatorSecretKeys, secretKey)

pubKey := bls.PublicFromSecretKey(secretKey)

nodeID, err := ids.ToNodeID(utils.RandomBytes(20))
if err != nil {
panic(err)
}

fmt.Printf(
"validator with pubKey %s has nodeID %s\n",
hex.EncodeToString(bls.PublicKeyToUncompressedBytes(pubKey)),
nodeID.String(),
)

validatorSet = append(validatorSet,
&warp.Validator{
PublicKey: pubKey,
PublicKeyBytes: bls.PublicKeyToUncompressedBytes(pubKey),
Weight: 1,
NodeIDs: []ids.NodeID{nodeID},
},
)
}

nodeValidatorIndexMap := make(map[ids.NodeID]int)
for i, vdr := range validatorSet {
for _, node := range vdr.NodeIDs {
nodeValidatorIndexMap[node] = i
}
}

return &peers.ConnectedCanonicalValidators{
ConnectedWeight: validatorCount,
TotalValidatorWeight: validatorCount,
ValidatorSet: validatorSet,
NodeValidatorIndexMap: nodeValidatorIndexMap,
}, validatorSecretKeys
}

func TestCreateSignedMessageFailsWithNoValidators(t *testing.T) {
aggregator, mockNetwork := instantiateAggregator(t)
msg, err := warp.NewUnsignedMessage(0, ids.Empty, []byte{})
Expand Down Expand Up @@ -85,3 +154,169 @@ func TestCreateSignedMessageFailsWithoutSufficientConnectedStake(t *testing.T) {
"failed to connect to a threshold of stake",
)
}

func makeAppRequests(
chainID ids.ID,
requestID uint32,
connectedValidators *peers.ConnectedCanonicalValidators,
) []ids.RequestID {
var appRequests []ids.RequestID
for _, validator := range connectedValidators.ValidatorSet {
for _, nodeID := range validator.NodeIDs {
appRequests = append(
appRequests,
ids.RequestID{
NodeID: nodeID,
SourceChainID: chainID,
DestinationChainID: chainID,
RequestID: requestID,
Op: byte(
message.AppResponseOp,
),
},
)
}
}
return appRequests
}

func TestCreateSignedMessageRetriesAndFailsWithoutP2PResponses(t *testing.T) {
aggregator, mockNetwork := instantiateAggregator(t)

var (
connectedValidators, _ = makeConnectedValidators(2)
requestID = aggregator.currentRequestID.Load() + 1
)

chainID, err := ids.ToID(utils.RandomBytes(32))
if err != nil {
panic(err)
}

msg, err := warp.NewUnsignedMessage(0, chainID, []byte{})
require.Equal(t, err, nil)

subnetID, err := ids.ToID(utils.RandomBytes(32))
require.Equal(t, err, nil)
mockNetwork.EXPECT().GetSubnetID(chainID).Return(
subnetID,
nil,
)

mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return(
connectedValidators,
nil,
)

appRequests := makeAppRequests(chainID, requestID, connectedValidators)
for _, appRequest := range appRequests {
mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times(
maxRelayerQueryAttempts,
)
}

mockNetwork.EXPECT().RegisterRequestID(
requestID,
len(appRequests),
).Return(
make(chan message.InboundMessage, len(appRequests)),
).Times(maxRelayerQueryAttempts)

var nodeIDs set.Set[ids.NodeID]
for _, appRequest := range appRequests {
nodeIDs.Add(appRequest.NodeID)
}
mockNetwork.EXPECT().Send(
gomock.Any(),
nodeIDs,
subnetID,
subnets.NoOpAllower,
).Times(maxRelayerQueryAttempts)

_, err = aggregator.CreateSignedMessage(msg, subnetID, 80)
require.ErrorContains(
t,
err,
"failed to collect a threshold of signatures",
)
}

func TestCreateSignedMessageSucceeds(t *testing.T) {
aggregator, mockNetwork := instantiateAggregator(t)

var connectedValidators, validatorSecretKeys = makeConnectedValidators(5)
var requestID = aggregator.currentRequestID.Load() + 1

chainID, err := ids.ToID(utils.RandomBytes(32))
if err != nil {
panic(err)
}

subnetID, err := ids.ToID(utils.RandomBytes(32))
require.Equal(t, err, nil)
mockNetwork.EXPECT().GetSubnetID(chainID).Return(
subnetID,
nil,
)

mockNetwork.EXPECT().ConnectToCanonicalValidators(subnetID).Return(
connectedValidators,
nil,
)

appRequests := makeAppRequests(chainID, requestID, connectedValidators)
for _, appRequest := range appRequests {
mockNetwork.EXPECT().RegisterAppRequest(appRequest).Times(1)
}

networkID := uint32(0)
msg, err := warp.NewUnsignedMessage(
networkID,
chainID,
utils.RandomBytes(1234),
)
require.Equal(t, err, nil)

responseChan := make(chan message.InboundMessage, len(appRequests))
for _, appRequest := range appRequests {
validatorSecretKey := validatorSecretKeys[connectedValidators.NodeValidatorIndexMap[appRequest.NodeID]]
responseBytes, err := evmMsg.Codec.Marshal(
0,
&evmMsg.SignatureResponse{
Signature: [bls.SignatureLen]byte(
bls.SignatureToBytes(
bls.Sign(
validatorSecretKey,
msg.Bytes(),
),
),
),
},
)
require.Equal(t, err, nil)
responseChan <- message.InboundAppResponse(
chainID,
requestID,
responseBytes,
appRequest.NodeID,
)
}
mockNetwork.EXPECT().RegisterRequestID(
requestID,
len(appRequests),
).Return(responseChan).Times(1)

var nodeIDs set.Set[ids.NodeID]
for _, appRequest := range appRequests {
nodeIDs.Add(appRequest.NodeID)
}
mockNetwork.EXPECT().Send(
gomock.Any(),
nodeIDs,
subnetID,
subnets.NoOpAllower,
).Times(1).Return(nodeIDs)

_, err = aggregator.CreateSignedMessage(msg, subnetID, 80)
require.Equal(t, err, nil)
}

0 comments on commit a7c907e

Please sign in to comment.