diff --git a/protocol/blockfetch/client.go b/protocol/blockfetch/client.go index b2930939..39881ff4 100644 --- a/protocol/blockfetch/client.go +++ b/protocol/blockfetch/client.go @@ -66,6 +66,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -80,6 +81,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Start goroutine to cleanup resources on protocol shutdown go func() { @@ -93,6 +96,8 @@ func (c *Client) Start() { func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("stopping protocol: %s", ProtocolName)) msg := NewMsgClientDone() err = c.SendMessage(msg) }) @@ -101,6 +106,8 @@ func (c *Client) Stop() error { // GetBlockRange starts an async process to fetch all blocks in the specified range (inclusive) func (c *Client) GetBlockRange(start common.Point, end common.Point) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetBlockRange(start: %+v, end: %+v)", ProtocolName, start, end)) c.busyMutex.Lock() c.blockUseCallback = true msg := NewMsgRequestRange(start, end) @@ -121,6 +128,8 @@ func (c *Client) GetBlockRange(start common.Point, end common.Point) error { // GetBlock requests and returns a single block specified by the provided point func (c *Client) GetBlock(point common.Point) (ledger.Block, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetBlock(point: %+v)", ProtocolName, point)) c.busyMutex.Lock() c.blockUseCallback = false msg := NewMsgRequestRange(point, point) @@ -144,6 +153,8 @@ func (c *Client) GetBlock(point common.Point) (ledger.Block, error) { } func (c *Client) messageHandler(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeStartBatch: @@ -165,17 +176,23 @@ func (c *Client) messageHandler(msg protocol.Message) error { } func (c *Client) handleStartBatch() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client start batch for %s", ProtocolName)) c.startBatchResultChan <- nil return nil } func (c *Client) handleNoBlocks() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client no blocks found for %s", ProtocolName)) err := fmt.Errorf("block(s) not found") c.startBatchResultChan <- err return nil } func (c *Client) handleBlock(msgGeneric protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client block found for %s", ProtocolName)) msg := msgGeneric.(*MsgBlock) // Decode only enough to get the block type value var wrappedBlock WrappedBlock @@ -201,6 +218,8 @@ func (c *Client) handleBlock(msgGeneric protocol.Message) error { } func (c *Client) handleBatchDone() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client batch done for %s", ProtocolName)) c.busyMutex.Unlock() return nil } diff --git a/protocol/blockfetch/server.go b/protocol/blockfetch/server.go index e48de64d..b00f7770 100644 --- a/protocol/blockfetch/server.go +++ b/protocol/blockfetch/server.go @@ -47,6 +47,7 @@ func (s *Server) initProtocol() { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: s.protoOptions.Muxer, + Logger: s.protoOptions.Logger, ErrorChan: s.protoOptions.ErrorChan, Mode: s.protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -59,16 +60,22 @@ func (s *Server) initProtocol() { } func (s *Server) NoBlocks() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s NoBlocks()", ProtocolName)) msg := NewMsgNoBlocks() return s.SendMessage(msg) } func (s *Server) StartBatch() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s StartBatch()", ProtocolName)) msg := NewMsgStartBatch() return s.SendMessage(msg) } func (s *Server) Block(blockType uint, blockData []byte) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s Block(blockType: %+v, blockData: %x)", ProtocolName, blockType, blockData)) wrappedBlock := WrappedBlock{ Type: blockType, RawBlock: blockData, @@ -82,11 +89,15 @@ func (s *Server) Block(blockType uint, blockData []byte) error { } func (s *Server) BatchDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s BatchDone()", ProtocolName)) msg := NewMsgBatchDone() return s.SendMessage(msg) } func (s *Server) messageHandler(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeRequestRange: @@ -104,6 +115,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { } func (s *Server) handleRequestRange(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server request range for %s", ProtocolName)) if s.config == nil || s.config.RequestRangeFunc == nil { return fmt.Errorf( "received block-fetch RequestRange message but no callback function is defined", @@ -118,6 +131,8 @@ func (s *Server) handleRequestRange(msg protocol.Message) error { } func (s *Server) handleClientDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server client done for %s", ProtocolName)) // Restart protocol s.Protocol.Stop() s.initProtocol() diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index 85a11b8f..77e3b205 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -100,6 +100,7 @@ func NewClient( Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -115,6 +116,8 @@ func NewClient( func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Start goroutine to cleanup resources on protocol shutdown go func() { @@ -124,33 +127,12 @@ func (c *Client) Start() { }) } -func (c *Client) messageHandler(msg protocol.Message) error { - var err error - switch msg.Type() { - case MessageTypeAwaitReply: - err = c.handleAwaitReply() - case MessageTypeRollForward: - err = c.handleRollForward(msg) - case MessageTypeRollBackward: - err = c.handleRollBackward(msg) - case MessageTypeIntersectFound: - err = c.handleIntersectFound(msg) - case MessageTypeIntersectNotFound: - err = c.handleIntersectNotFound(msg) - default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) - } - return err -} - // Stop transitions the protocol to the Done state. No more protocol operations will be possible afterward func (c *Client) Stop() error { var err error c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("stopping protocol: %s", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() msg := NewMsgDone() @@ -163,6 +145,8 @@ func (c *Client) Stop() error { // GetCurrentTip returns the current chain tip func (c *Client) GetCurrentTip() (*Tip, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetCurrentTip()", ProtocolName)) done := atomic.Bool{} requestResultChan := make(chan Tip, 1) requestErrorChan := make(chan error, 1) @@ -220,6 +204,8 @@ func (c *Client) GetCurrentTip() (*Tip, error) { func (c *Client) GetAvailableBlockRange( intersectPoints []common.Point, ) (common.Point, common.Point, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetAvailableBlockRange(intersectPoints: %+v)", ProtocolName, intersectPoints)) c.busyMutex.Lock() defer c.busyMutex.Unlock() @@ -293,6 +279,8 @@ func (c *Client) GetAvailableBlockRange( // Sync begins a chain-sync operation using the provided intersect point(s). Incoming blocks will be delivered // via the RollForward callback function specified in the protocol config func (c *Client) Sync(intersectPoints []common.Point) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s Sync(intersectPoints: %+v)", ProtocolName, intersectPoints)) c.busyMutex.Lock() defer c.busyMutex.Unlock() // Use origin if no intersect points were specified @@ -441,11 +429,40 @@ func (c *Client) requestFindIntersect( } } +func (c *Client) messageHandler(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) + var err error + switch msg.Type() { + case MessageTypeAwaitReply: + err = c.handleAwaitReply() + case MessageTypeRollForward: + err = c.handleRollForward(msg) + case MessageTypeRollBackward: + err = c.handleRollBackward(msg) + case MessageTypeIntersectFound: + err = c.handleIntersectFound(msg) + case MessageTypeIntersectNotFound: + err = c.handleIntersectNotFound(msg) + default: + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) + } + return err +} + func (c *Client) handleAwaitReply() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client await reply for %s", ProtocolName)) return nil } func (c *Client) handleRollForward(msgGeneric protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client roll forward for %s", ProtocolName)) firstBlockChan := func() chan<- clientPointResult { select { case ch := <-c.wantFirstBlockChan: @@ -554,6 +571,8 @@ func (c *Client) handleRollForward(msgGeneric protocol.Message) error { } func (c *Client) handleRollBackward(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client roll backward for %s", ProtocolName)) msgRollBackward := msg.(*MsgRollBackward) c.sendCurrentTip(msgRollBackward.Tip) if len(c.wantFirstBlockChan) == 0 { @@ -579,6 +598,8 @@ func (c *Client) handleRollBackward(msg protocol.Message) error { } func (c *Client) handleIntersectFound(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client intersect found for %s", ProtocolName)) msgIntersectFound := msg.(*MsgIntersectFound) c.sendCurrentTip(msgIntersectFound.Tip) @@ -591,6 +612,8 @@ func (c *Client) handleIntersectFound(msg protocol.Message) error { } func (c *Client) handleIntersectNotFound(msgGeneric protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client intersect not found for %s", ProtocolName)) msgIntersectNotFound := msgGeneric.(*MsgIntersectNotFound) c.sendCurrentTip(msgIntersectNotFound.Tip) diff --git a/protocol/chainsync/server.go b/protocol/chainsync/server.go index e912de2a..b9e81ebe 100644 --- a/protocol/chainsync/server.go +++ b/protocol/chainsync/server.go @@ -64,6 +64,7 @@ func (s *Server) initProtocol() { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: s.protoOptions.Muxer, + Logger: s.protoOptions.Logger, ErrorChan: s.protoOptions.ErrorChan, Mode: s.protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -77,16 +78,22 @@ func (s *Server) initProtocol() { } func (s *Server) RollBackward(point common.Point, tip Tip) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s RollBackward(point: %+v, tip: %+v)", ProtocolName, point, tip)) msg := NewMsgRollBackward(point, tip) return s.SendMessage(msg) } func (s *Server) AwaitReply() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s AwaitReply()", ProtocolName)) msg := NewMsgAwaitReply() return s.SendMessage(msg) } func (s *Server) RollForward(blockType uint, blockData []byte, tip Tip) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("server called %s Rollforward(blockType: %+v, blockData: %x, tip: %+v)", ProtocolName, blockType, blockData, tip)) if s.Mode() == protocol.ProtocolModeNodeToNode { eraId := ledger.BlockToBlockHeaderTypeMap[blockType] msg := NewMsgRollForwardNtN( @@ -107,6 +114,8 @@ func (s *Server) RollForward(blockType uint, blockData []byte, tip Tip) error { } func (s *Server) messageHandler(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeRequestNext: @@ -126,6 +135,10 @@ func (s *Server) messageHandler(msg protocol.Message) error { } func (s *Server) handleRequestNext() error { + // TODO: figure out why this one log message causes a panic (and only this one) + // during tests + // s.Protocol.Logger(). + // Debug(fmt.Sprintf("handling server request next for %s", ProtocolName)) if s.config == nil || s.config.RequestNextFunc == nil { return fmt.Errorf( "received chain-sync RequestNext message but no callback function is defined", @@ -135,6 +148,8 @@ func (s *Server) handleRequestNext() error { } func (s *Server) handleFindIntersect(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server find intersect for %s", ProtocolName)) if s.config == nil || s.config.FindIntersectFunc == nil { return fmt.Errorf( "received chain-sync FindIntersect message but no callback function is defined", @@ -163,6 +178,8 @@ func (s *Server) handleFindIntersect(msg protocol.Message) error { } func (s *Server) handleDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server done for %s", ProtocolName)) // Restart protocol s.Protocol.Stop() s.initProtocol() diff --git a/protocol/handshake/client.go b/protocol/handshake/client.go index 5c949e2b..ca0eaa44 100644 --- a/protocol/handshake/client.go +++ b/protocol/handshake/client.go @@ -53,6 +53,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -68,6 +69,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { // Start begins the handshake process func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Send our ProposeVersions message msg := NewMsgProposeVersions(c.config.ProtocolVersionMap) @@ -76,6 +79,8 @@ func (c *Client) Start() { } func (c *Client) handleMessage(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeAcceptVersion: @@ -93,6 +98,8 @@ func (c *Client) handleMessage(msg protocol.Message) error { } func (c *Client) handleAcceptVersion(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client accept version for %s", ProtocolName)) if c.config.FinishedFunc == nil { return fmt.Errorf( "received handshake AcceptVersion message but no callback function is defined", @@ -114,6 +121,8 @@ func (c *Client) handleAcceptVersion(msg protocol.Message) error { } func (c *Client) handleRefuse(msgGeneric protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client refuse for %s", ProtocolName)) msg := msgGeneric.(*MsgRefuse) var err error switch msg.Reason[0].(uint64) { diff --git a/protocol/handshake/server.go b/protocol/handshake/server.go index 866b6477..cc26fcab 100644 --- a/protocol/handshake/server.go +++ b/protocol/handshake/server.go @@ -47,6 +47,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -60,6 +61,8 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { } func (s *Server) handleMessage(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeProposeVersions: @@ -75,6 +78,8 @@ func (s *Server) handleMessage(msg protocol.Message) error { } func (s *Server) handleProposeVersions(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server propose versions for %s", ProtocolName)) if s.config.FinishedFunc == nil { return fmt.Errorf( "received handshake ProposeVersions message but no callback function is defined", diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index 9578c747..076da965 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -72,6 +72,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -94,6 +95,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Start goroutine to cleanup resources on protocol shutdown go func() { @@ -104,123 +107,10 @@ func (c *Client) Start() { }) } -func (c *Client) messageHandler(msg protocol.Message) error { - var err error - switch msg.Type() { - case MessageTypeAcquired: - err = c.handleAcquired() - case MessageTypeFailure: - err = c.handleFailure(msg) - case MessageTypeResult: - err = c.handleResult(msg) - default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) - } - return err -} - -func (c *Client) handleAcquired() error { - c.acquired = true - c.acquireResultChan <- nil - c.currentEra = -1 - return nil -} - -func (c *Client) handleFailure(msg protocol.Message) error { - msgFailure := msg.(*MsgFailure) - switch msgFailure.Failure { - case AcquireFailurePointTooOld: - c.acquireResultChan <- AcquireFailurePointTooOldError{} - case AcquireFailurePointNotOnChain: - c.acquireResultChan <- AcquireFailurePointNotOnChainError{} - default: - return fmt.Errorf("unknown failure type: %d", msgFailure.Failure) - } - return nil -} - -func (c *Client) handleResult(msg protocol.Message) error { - msgResult := msg.(*MsgResult) - c.queryResultChan <- msgResult.Result - return nil -} - -func (c *Client) acquire(point *common.Point) error { - var msg protocol.Message - if c.acquired { - if point != nil { - msg = NewMsgReAcquire(*point) - } else { - msg = NewMsgReAcquireNoPoint() - } - } else { - if point != nil { - msg = NewMsgAcquire(*point) - } else { - msg = NewMsgAcquireNoPoint() - } - } - if err := c.SendMessage(msg); err != nil { - return err - } - err, ok := <-c.acquireResultChan - if !ok { - return protocol.ProtocolShuttingDownError - } - return err -} - -func (c *Client) release() error { - msg := NewMsgRelease() - if err := c.SendMessage(msg); err != nil { - return err - } - c.acquired = false - c.currentEra = -1 - return nil -} - -func (c *Client) runQuery(query interface{}, result interface{}) error { - msg := NewMsgQuery(query) - if !c.acquired { - if err := c.acquire(nil); err != nil { - return err - } - } - if err := c.SendMessage(msg); err != nil { - return err - } - resultCbor, ok := <-c.queryResultChan - if !ok { - return protocol.ProtocolShuttingDownError - } - if _, err := cbor.Decode(resultCbor, result); err != nil { - return err - } - return nil -} - -// Helper function for getting the current era -// The current era is needed for many other queries -func (c *Client) getCurrentEra() (int, error) { - // Return cached era, if available - if c.currentEra > -1 { - return c.currentEra, nil - } - query := buildHardForkQuery(QueryTypeHardForkCurrentEra) - var result int - if err := c.runQuery(query, &result); err != nil { - return -1, err - } - return result, nil -} - // Acquire starts the acquire process for the specified chain point func (c *Client) Acquire(point *common.Point) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s Acquire(point: %+v)", ProtocolName, point)) c.busyMutex.Lock() defer c.busyMutex.Unlock() return c.acquire(point) @@ -228,6 +118,8 @@ func (c *Client) Acquire(point *common.Point) error { // Release releases the previously acquired chain point func (c *Client) Release() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s Release()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() return c.release() @@ -235,6 +127,8 @@ func (c *Client) Release() error { // GetCurrentEra returns the current era ID func (c *Client) GetCurrentEra() (int, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetCurrentEra()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() return c.getCurrentEra() @@ -242,6 +136,8 @@ func (c *Client) GetCurrentEra() (int, error) { // GetSystemStart returns the SystemStart value func (c *Client) GetSystemStart() (*SystemStartResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetSystemStart()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() query := buildQuery( @@ -256,6 +152,8 @@ func (c *Client) GetSystemStart() (*SystemStartResult, error) { // GetChainBlockNo returns the latest block number func (c *Client) GetChainBlockNo() (int64, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetChainBlockNo()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() query := buildQuery( @@ -270,6 +168,8 @@ func (c *Client) GetChainBlockNo() (int64, error) { // GetChainPoint returns the current chain tip func (c *Client) GetChainPoint() (*common.Point, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetChainPoint()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() query := buildQuery( @@ -284,6 +184,8 @@ func (c *Client) GetChainPoint() (*common.Point, error) { // GetEraHistory returns the era history func (c *Client) GetEraHistory() ([]EraHistoryResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetEraHistory()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() query := buildHardForkQuery(QueryTypeHardForkEraHistory) @@ -296,6 +198,8 @@ func (c *Client) GetEraHistory() ([]EraHistoryResult, error) { // GetEpochNo returns the current epoch number func (c *Client) GetEpochNo() (int, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetEpochNo()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -318,6 +222,8 @@ func (c *Client) GetEpochNo() (int, error) { query [2 #6.258([*[0 int]]) int is the stake the user intends to delegate, the array must be sorted */ func (c *Client) GetNonMyopicMemberRewards() (*NonMyopicMemberRewardsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetNonMyopicMemberRewards()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -337,6 +243,8 @@ func (c *Client) GetNonMyopicMemberRewards() (*NonMyopicMemberRewardsResult, err // GetCurrentProtocolParams returns the set of protocol params that are currently in effect func (c *Client) GetCurrentProtocolParams() (CurrentProtocolParamsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetCurrentProtocolParams()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -389,7 +297,10 @@ func (c *Client) GetCurrentProtocolParams() (CurrentProtocolParamsResult, error) } } +// GetProposedProtocolParamsUpdates returns the set of proposed protocol params updates func (c *Client) GetProposedProtocolParamsUpdates() (*ProposedProtocolParamsUpdatesResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetProposedProtocolParamsUpdates()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -410,6 +321,8 @@ func (c *Client) GetProposedProtocolParamsUpdates() (*ProposedProtocolParamsUpda // GetStakeDistribution returns the stake distribution func (c *Client) GetStakeDistribution() (*StakeDistributionResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetStakeDistribution()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -427,9 +340,12 @@ func (c *Client) GetStakeDistribution() (*StakeDistributionResult, error) { return &result, nil } +// GetUTxOByAddress returns the UTxOs for a given list of ledger.Address structs func (c *Client) GetUTxOByAddress( addrs []ledger.Address, ) (*UTxOByAddressResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetUTxOByAddress(addrs: %+v)", ProtocolName, addrs)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -448,7 +364,10 @@ func (c *Client) GetUTxOByAddress( return &result, nil } +// GetUTxOWhole returns the current UTxO set func (c *Client) GetUTxOWhole() (*UTxOWholeResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetUTxOWhole()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -468,6 +387,8 @@ func (c *Client) GetUTxOWhole() (*UTxOWholeResult, error) { // TODO func (c *Client) DebugEpochState() (*DebugEpochStateResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s DebugEpochState()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -492,6 +413,8 @@ query [10 #6.258([ *rwdr ])] func (c *Client) GetFilteredDelegationsAndRewardAccounts( creds []interface{}, ) (*FilteredDelegationsAndRewardAccountsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetFilteredDelegationsAndRewardAccounts(creds: %+v)", ProtocolName, creds)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -511,6 +434,8 @@ func (c *Client) GetFilteredDelegationsAndRewardAccounts( } func (c *Client) GetGenesisConfig() (*GenesisConfigResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetGenesisConfig()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -530,6 +455,8 @@ func (c *Client) GetGenesisConfig() (*GenesisConfigResult, error) { // TODO func (c *Client) DebugNewEpochState() (*DebugNewEpochStateResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s DebugNewEpochState()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -549,6 +476,8 @@ func (c *Client) DebugNewEpochState() (*DebugNewEpochStateResult, error) { // TODO func (c *Client) DebugChainDepState() (*DebugChainDepStateResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s DebugChainDepState()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -567,6 +496,8 @@ func (c *Client) DebugChainDepState() (*DebugChainDepStateResult, error) { } func (c *Client) GetRewardProvenance() (*RewardProvenanceResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetRewardProvenance()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -587,6 +518,8 @@ func (c *Client) GetRewardProvenance() (*RewardProvenanceResult, error) { func (c *Client) GetUTxOByTxIn( txIns []ledger.TransactionInput, ) (*UTxOByTxInResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetUTxOByTxIn(txIns: %+v)", ProtocolName, txIns)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -606,6 +539,8 @@ func (c *Client) GetUTxOByTxIn( } func (c *Client) GetStakePools() (*StakePoolsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetStakePools()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -626,6 +561,8 @@ func (c *Client) GetStakePools() (*StakePoolsResult, error) { func (c *Client) GetStakePoolParams( poolIds []ledger.PoolId, ) (*StakePoolParamsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetStakePoolParams(poolIds: %+v)", ProtocolName, poolIds)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -649,6 +586,8 @@ func (c *Client) GetStakePoolParams( // TODO func (c *Client) GetRewardInfoPools() (*RewardInfoPoolsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetRewardInfoPools()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -668,6 +607,8 @@ func (c *Client) GetRewardInfoPools() (*RewardInfoPoolsResult, error) { // TODO func (c *Client) GetPoolState(poolIds []interface{}) (*PoolStateResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetPoolState(poolIds: %+v)", ProtocolName, poolIds)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -689,6 +630,8 @@ func (c *Client) GetPoolState(poolIds []interface{}) (*PoolStateResult, error) { func (c *Client) GetStakeSnapshots( poolId interface{}, ) (*StakeSnapshotsResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetStakeSnapshots(poolId: %+v)", ProtocolName, poolId)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -708,6 +651,8 @@ func (c *Client) GetStakeSnapshots( // TODO func (c *Client) GetPoolDistr(poolIds []interface{}) (*PoolDistrResult, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetPoolDistr(poolIds: %+v)", ProtocolName, poolIds)) c.busyMutex.Lock() defer c.busyMutex.Unlock() currentEra, err := c.getCurrentEra() @@ -724,3 +669,126 @@ func (c *Client) GetPoolDistr(poolIds []interface{}) (*PoolDistrResult, error) { } return &result, nil } + +func (c *Client) messageHandler(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) + var err error + switch msg.Type() { + case MessageTypeAcquired: + err = c.handleAcquired() + case MessageTypeFailure: + err = c.handleFailure(msg) + case MessageTypeResult: + err = c.handleResult(msg) + default: + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) + } + return err +} + +func (c *Client) handleAcquired() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client acquired for %s", ProtocolName)) + c.acquired = true + c.acquireResultChan <- nil + c.currentEra = -1 + return nil +} + +func (c *Client) handleFailure(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client failure for %s", ProtocolName)) + msgFailure := msg.(*MsgFailure) + switch msgFailure.Failure { + case AcquireFailurePointTooOld: + c.acquireResultChan <- AcquireFailurePointTooOldError{} + case AcquireFailurePointNotOnChain: + c.acquireResultChan <- AcquireFailurePointNotOnChainError{} + default: + return fmt.Errorf("unknown failure type: %d", msgFailure.Failure) + } + return nil +} + +func (c *Client) handleResult(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client result for %s", ProtocolName)) + msgResult := msg.(*MsgResult) + c.queryResultChan <- msgResult.Result + return nil +} + +func (c *Client) acquire(point *common.Point) error { + var msg protocol.Message + if c.acquired { + if point != nil { + msg = NewMsgReAcquire(*point) + } else { + msg = NewMsgReAcquireNoPoint() + } + } else { + if point != nil { + msg = NewMsgAcquire(*point) + } else { + msg = NewMsgAcquireNoPoint() + } + } + if err := c.SendMessage(msg); err != nil { + return err + } + err, ok := <-c.acquireResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } + return err +} + +func (c *Client) release() error { + msg := NewMsgRelease() + if err := c.SendMessage(msg); err != nil { + return err + } + c.acquired = false + c.currentEra = -1 + return nil +} + +func (c *Client) runQuery(query interface{}, result interface{}) error { + msg := NewMsgQuery(query) + if !c.acquired { + if err := c.acquire(nil); err != nil { + return err + } + } + if err := c.SendMessage(msg); err != nil { + return err + } + resultCbor, ok := <-c.queryResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } + if _, err := cbor.Decode(resultCbor, result); err != nil { + return err + } + return nil +} + +// Helper function for getting the current era +// The current era is needed for many other queries +func (c *Client) getCurrentEra() (int, error) { + // Return cached era, if available + if c.currentEra > -1 { + return c.currentEra, nil + } + query := buildHardForkQuery(QueryTypeHardForkCurrentEra) + var result int + if err := c.runQuery(query, &result); err != nil { + return -1, err + } + return result, nil +} diff --git a/protocol/localstatequery/server.go b/protocol/localstatequery/server.go index 23c0ea05..27d92635 100644 --- a/protocol/localstatequery/server.go +++ b/protocol/localstatequery/server.go @@ -43,6 +43,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -64,6 +65,8 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { } func (s *Server) messageHandler(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeAcquire: @@ -91,6 +94,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { } func (s *Server) handleAcquire(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server acquire for %s", ProtocolName)) if s.config.AcquireFunc == nil { return fmt.Errorf( "received local-state-query Acquire message but no callback function is defined", @@ -108,6 +113,8 @@ func (s *Server) handleAcquire(msg protocol.Message) error { } func (s *Server) handleQuery(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server query for %s", ProtocolName)) if s.config.QueryFunc == nil { return fmt.Errorf( "received local-state-query Query message but no callback function is defined", @@ -119,6 +126,8 @@ func (s *Server) handleQuery(msg protocol.Message) error { } func (s *Server) handleRelease() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server release for %s", ProtocolName)) if s.config.ReleaseFunc == nil { return fmt.Errorf( "received local-state-query Release message but no callback function is defined", @@ -129,6 +138,8 @@ func (s *Server) handleRelease() error { } func (s *Server) handleReAcquire(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server reacquire for %s", ProtocolName)) if s.config.ReAcquireFunc == nil { return fmt.Errorf( "received local-state-query ReAcquire message but no callback function is defined", @@ -146,6 +157,8 @@ func (s *Server) handleReAcquire(msg protocol.Message) error { } func (s *Server) handleDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server done for %s", ProtocolName)) if s.config.DoneFunc == nil { return fmt.Errorf( "received local-state-query Done message but no callback function is defined", diff --git a/protocol/localtxmonitor/client.go b/protocol/localtxmonitor/client.go index 9b9df33a..c46b8918 100644 --- a/protocol/localtxmonitor/client.go +++ b/protocol/localtxmonitor/client.go @@ -69,6 +69,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -83,6 +84,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Start goroutine to cleanup resources on protocol shutdown go func() { @@ -95,51 +98,26 @@ func (c *Client) Start() { }) } -func (c *Client) messageHandler(msg protocol.Message) error { +// Stop transitions the protocol to the Done state. No more operations will be possible +func (c *Client) Stop() error { var err error - switch msg.Type() { - case MessageTypeAcquired: - err = c.handleAcquired(msg) - case MessageTypeReplyHasTx: - err = c.handleReplyHasTx(msg) - case MessageTypeReplyNextTx: - err = c.handleReplyNextTx(msg) - case MessageTypeReplyGetSizes: - err = c.handleReplyGetSizes(msg) - default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) - } + c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("stopping protocol: %s", ProtocolName)) + c.busyMutex.Lock() + defer c.busyMutex.Unlock() + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } + }) return err } -func (c *Client) acquire() error { - msg := NewMsgAcquire() - if err := c.SendMessage(msg); err != nil { - return err - } - // Wait for reply - _, ok := <-c.acquireResultChan - if !ok { - return protocol.ProtocolShuttingDownError - } - return nil -} - -func (c *Client) release() error { - msg := NewMsgRelease() - if err := c.SendMessage(msg); err != nil { - return err - } - c.acquired = false - return nil -} - // Acquire starts the acquire process for a current mempool snapshot func (c *Client) Acquire() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s Acquire()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() return c.acquire() @@ -147,27 +125,17 @@ func (c *Client) Acquire() error { // Release releases the previously acquired mempool snapshot func (c *Client) Release() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s Release()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() return c.release() } -// Stop transitions the protocol to the Done state. No more operations will be possible -func (c *Client) Stop() error { - var err error - c.onceStop.Do(func() { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return - } - }) - return err -} - // HasTx returns whether or not the specified transaction ID exists in the mempool snapshot func (c *Client) HasTx(txId []byte) (bool, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s HasTx(txId: %x)", ProtocolName, txId)) c.busyMutex.Lock() defer c.busyMutex.Unlock() if !c.acquired { @@ -188,6 +156,8 @@ func (c *Client) HasTx(txId []byte) (bool, error) { // NextTx returns the next transaction in the mempool snapshot func (c *Client) NextTx() ([]byte, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s NextTx()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() if !c.acquired { @@ -208,6 +178,8 @@ func (c *Client) NextTx() ([]byte, error) { // GetSizes returns the capacity (in bytes), size (in bytes), and number of transactions in the mempool snapshot func (c *Client) GetSizes() (uint32, uint32, uint32, error) { + c.Protocol.Logger(). + Debug(fmt.Sprintf("client called %s GetSizes()", ProtocolName)) c.busyMutex.Lock() defer c.busyMutex.Unlock() if !c.acquired { @@ -226,7 +198,32 @@ func (c *Client) GetSizes() (uint32, uint32, uint32, error) { return result.Capacity, result.Size, result.NumberOfTxs, nil } +func (c *Client) messageHandler(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) + var err error + switch msg.Type() { + case MessageTypeAcquired: + err = c.handleAcquired(msg) + case MessageTypeReplyHasTx: + err = c.handleReplyHasTx(msg) + case MessageTypeReplyNextTx: + err = c.handleReplyNextTx(msg) + case MessageTypeReplyGetSizes: + err = c.handleReplyGetSizes(msg) + default: + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) + } + return err +} + func (c *Client) handleAcquired(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client acquired for %s", ProtocolName)) msgAcquired := msg.(*MsgAcquired) c.acquired = true c.acquiredSlot = msgAcquired.SlotNo @@ -235,19 +232,47 @@ func (c *Client) handleAcquired(msg protocol.Message) error { } func (c *Client) handleReplyHasTx(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client reply has tx for %s", ProtocolName)) msgReplyHasTx := msg.(*MsgReplyHasTx) c.hasTxResultChan <- msgReplyHasTx.Result return nil } func (c *Client) handleReplyNextTx(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client reply next tx for %s", ProtocolName)) msgReplyNextTx := msg.(*MsgReplyNextTx) c.nextTxResultChan <- msgReplyNextTx.Transaction.Tx return nil } func (c *Client) handleReplyGetSizes(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client reply get sizes for %s", ProtocolName)) msgReplyGetSizes := msg.(*MsgReplyGetSizes) c.getSizesResultChan <- msgReplyGetSizes.Result return nil } + +func (c *Client) acquire() error { + msg := NewMsgAcquire() + if err := c.SendMessage(msg); err != nil { + return err + } + // Wait for reply + _, ok := <-c.acquireResultChan + if !ok { + return protocol.ProtocolShuttingDownError + } + return nil +} + +func (c *Client) release() error { + msg := NewMsgRelease() + if err := c.SendMessage(msg); err != nil { + return err + } + c.acquired = false + return nil +} diff --git a/protocol/localtxmonitor/server.go b/protocol/localtxmonitor/server.go index 6d4cd5bf..9b7b58d8 100644 --- a/protocol/localtxmonitor/server.go +++ b/protocol/localtxmonitor/server.go @@ -45,6 +45,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -58,6 +59,8 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { } func (s *Server) messageHandler(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeAcquire: @@ -83,6 +86,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { } func (s *Server) handleAcquire() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server acquire for %s", ProtocolName)) if s.config.GetMempoolFunc == nil { return fmt.Errorf( "received local-tx-monitor Acquire message but no GetMempool callback function is defined", @@ -122,16 +127,22 @@ func (s *Server) handleAcquire() error { } func (s *Server) handleDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server done for %s", ProtocolName)) return nil } func (s *Server) handleRelease() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server release for %s", ProtocolName)) s.mempoolCapacity = 0 s.mempoolTxs = nil return nil } func (s *Server) handleHasTx(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server has tx for %s", ProtocolName)) msgHasTx := msg.(*MsgHasTx) txId := hex.EncodeToString(msgHasTx.TxId) hasTx := false @@ -149,6 +160,8 @@ func (s *Server) handleHasTx(msg protocol.Message) error { } func (s *Server) handleNextTx() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server next tx for %s", ProtocolName)) if s.mempoolNextTxIdx > len(s.mempoolTxs) { newMsg := NewMsgReplyNextTx(0, nil) if err := s.SendMessage(newMsg); err != nil { @@ -166,6 +179,8 @@ func (s *Server) handleNextTx() error { } func (s *Server) handleGetSizes() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server get sizes for %s", ProtocolName)) totalTxSize := 0 for _, tx := range s.mempoolTxs { totalTxSize += len(tx.Tx) diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index 118364c7..b747bae8 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -58,6 +58,7 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleClient, @@ -72,6 +73,8 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { func (c *Client) Start() { c.onceStart.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("starting protocol: %s", ProtocolName)) c.Protocol.Start() // Start goroutine to cleanup resources on protocol shutdown go func() { @@ -81,20 +84,19 @@ func (c *Client) Start() { }) } -func (c *Client) messageHandler(msg protocol.Message) error { +// Stop transitions the protocol to the Done state. No more operations will be possible +func (c *Client) Stop() error { var err error - switch msg.Type() { - case MessageTypeAcceptTx: - err = c.handleAcceptTx() - case MessageTypeRejectTx: - err = c.handleRejectTx(msg) - default: - err = fmt.Errorf( - "%s: received unexpected message type %d", - ProtocolName, - msg.Type(), - ) - } + c.onceStop.Do(func() { + c.Protocol.Logger(). + Debug(fmt.Sprintf("stopping protocol: %s", ProtocolName)) + c.busyMutex.Lock() + defer c.busyMutex.Unlock() + msg := NewMsgDone() + if err = c.SendMessage(msg); err != nil { + return + } + }) return err } @@ -113,26 +115,35 @@ func (c *Client) SubmitTx(eraId uint16, tx []byte) error { return err } -// Stop transitions the protocol to the Done state. No more operations will be possible -func (c *Client) Stop() error { +func (c *Client) messageHandler(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client message for %s", ProtocolName)) var err error - c.onceStop.Do(func() { - c.busyMutex.Lock() - defer c.busyMutex.Unlock() - msg := NewMsgDone() - if err = c.SendMessage(msg); err != nil { - return - } - }) + switch msg.Type() { + case MessageTypeAcceptTx: + err = c.handleAcceptTx() + case MessageTypeRejectTx: + err = c.handleRejectTx(msg) + default: + err = fmt.Errorf( + "%s: received unexpected message type %d", + ProtocolName, + msg.Type(), + ) + } return err } func (c *Client) handleAcceptTx() error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client accept tx for %s", ProtocolName)) c.submitResultChan <- nil return nil } func (c *Client) handleRejectTx(msg protocol.Message) error { + c.Protocol.Logger(). + Debug(fmt.Sprintf("handling client reject tx for %s", ProtocolName)) msgRejectTx := msg.(*MsgRejectTx) rejectErr, err := ledger.NewTxSubmitErrorFromCbor(msgRejectTx.Reason) if err != nil { diff --git a/protocol/localtxsubmission/server.go b/protocol/localtxsubmission/server.go index 98c181fb..40357ea2 100644 --- a/protocol/localtxsubmission/server.go +++ b/protocol/localtxsubmission/server.go @@ -41,6 +41,7 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { Name: ProtocolName, ProtocolId: ProtocolId, Muxer: protoOptions.Muxer, + Logger: protoOptions.Logger, ErrorChan: protoOptions.ErrorChan, Mode: protoOptions.Mode, Role: protocol.ProtocolRoleServer, @@ -54,6 +55,8 @@ func NewServer(protoOptions protocol.ProtocolOptions, cfg *Config) *Server { } func (s *Server) messageHandler(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server message for %s", ProtocolName)) var err error switch msg.Type() { case MessageTypeSubmitTx: @@ -71,6 +74,8 @@ func (s *Server) messageHandler(msg protocol.Message) error { } func (s *Server) handleSubmitTx(msg protocol.Message) error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server submit tx for %s", ProtocolName)) if s.config.SubmitTxFunc == nil { return fmt.Errorf( "received local-tx-submission SubmitTx message but no callback function is defined", @@ -98,5 +103,7 @@ func (s *Server) handleSubmitTx(msg protocol.Message) error { } func (s *Server) handleDone() error { + s.Protocol.Logger(). + Debug(fmt.Sprintf("handling server done for %s", ProtocolName)) return nil }