Skip to content

Commit

Permalink
Fix transfer leadership. Fixes #2343
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed Aug 21, 2024
1 parent fc56324 commit 74656fb
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 22 deletions.
2 changes: 1 addition & 1 deletion controller/handler_peer_ctrl/transfer_leadership.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
raft2 "github.com/hashicorp/raft"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/channel/v2"
"github.com/openziti/ziti/common/pb/cmd_pb"
"github.com/openziti/ziti/controller/peermsg"
"github.com/openziti/ziti/controller/raft"
"github.com/openziti/ziti/common/pb/cmd_pb"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
Expand Down
97 changes: 84 additions & 13 deletions controller/raft/mesh/mesh.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,40 @@ const (
RaftDataType = 2049
SigningCertHeader = 2050
ApiAddressesHeader = 2051

ChannelTypeMesh = "ctrl.mesh"
RaftDisconnectType = 2052
ChannelTypeMesh = "ctrl.mesh"
)

type Peer struct {
mesh *impl
Id raft.ServerID
Address string
Channel channel.Channel
RaftConn *raftPeerConn
RaftConn atomic.Pointer[raftPeerConn]
Version *versions.VersionInfo
SigningCerts []*x509.Certificate
ApiAddresses map[string][]event.ApiAddress
}

func (self *Peer) initRaftConn() *raftPeerConn {
self.mesh.lock.Lock()
defer self.mesh.lock.Unlock()
conn := self.RaftConn.Load()
if conn == nil {
conn = newRaftPeerConn(self, self.mesh.netAddr)
self.RaftConn.Store(conn)
}
return conn
}

func (self *Peer) HandleClose(channel.Channel) {
self.mesh.lock.Lock()
conn := self.RaftConn.Swap(nil)
if conn != nil {
conn.close()
}
self.mesh.lock.Unlock()

self.mesh.PeerDisconnected(self)
}

Expand All @@ -77,17 +95,68 @@ func (self *Peer) HandleReceive(m *channel.Message, _ channel.Channel) {
if err := response.WithTimeout(5 * time.Second).Send(self.Channel); err != nil {
logrus.WithError(err).Error("failed to send connect response")
} else {
conn := self.initRaftConn()
select {
case self.mesh.raftAccepts <- self.RaftConn:
case self.mesh.raftAccepts <- conn:
case <-self.mesh.closeNotify:
}
}
}()
}

func (self *Peer) Connect(timeout time.Duration) error {
func (self *Peer) handleReceiveDisconnect(m *channel.Message, _ channel.Channel) {
go func() {
self.mesh.lock.Lock()
conn := self.RaftConn.Swap(nil)
self.mesh.lock.Unlock()

if conn != nil {
conn.close()
}

response := channel.NewResult(true, "")
response.ReplyTo(m)

if err := response.WithTimeout(5 * time.Second).Send(self.Channel); err != nil {
logrus.WithError(err).Error("failed to send close response")
}
}()
}

func (self *Peer) handleReceiveData(m *channel.Message, ch channel.Channel) {
if conn := self.RaftConn.Load(); conn != nil {
conn.HandleReceive(m, ch)
}
}

func (self *Peer) Connect(timeout time.Duration) (net.Conn, error) {
msg := channel.NewMessage(RaftConnectType, nil)
response, err := msg.WithTimeout(timeout).SendForReply(self.Channel)
if err != nil {
return nil, err
}
result := channel.UnmarshalResult(response)
if !result.Success {
return nil, errors.Errorf("connect failed: %v", result.Message)
}

logrus.Infof("connected peer %v at %v", self.Id, self.Address)

return self.initRaftConn(), nil
}

func (self *Peer) closeRaftConn(timeout time.Duration) error {
self.mesh.lock.Lock()
conn := self.RaftConn.Swap(nil)
defer self.mesh.lock.Unlock()
if conn == nil {
return nil
}

conn.close()

msg := channel.NewMessage(RaftDisconnectType, nil)
response, err := msg.WithTimeout(timeout).SendForReply(self.Channel)
if err != nil {
return err
}
Expand All @@ -96,7 +165,7 @@ func (self *Peer) Connect(timeout time.Duration) error {
return errors.Errorf("connect failed: %v", result.Message)
}

logrus.Infof("connected peer %v at %v", self.Id, self.Address)
logrus.Infof("disconnected peer %v at %v", self.Id, self.Address)

return nil
}
Expand Down Expand Up @@ -227,10 +296,12 @@ func (self *impl) Dial(address raft.ServerAddress, timeout time.Duration) (net.C
if err != nil {
return nil, err
}
if err := peer.Connect(timeout); err != nil {
return nil, err

if peerConn := peer.RaftConn.Load(); peerConn != nil {
return peerConn, nil
}
return peer.RaftConn, nil

return peer.Connect(timeout)
}

func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer, error) {
Expand Down Expand Up @@ -310,11 +381,11 @@ func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer
}

peer.Version = versionInfo
peer.RaftConn = newRaftPeerConn(peer, self.netAddr)
peer.SigningCerts = []*x509.Certificate{underlay.Certificates()[0]}

binding.AddTypedReceiveHandler(peer)
binding.AddTypedReceiveHandler(peer.RaftConn)
binding.AddReceiveHandlerF(RaftDataType, peer.handleReceiveData)
binding.AddReceiveHandlerF(RaftDisconnectType, peer.handleReceiveDisconnect)
binding.AddCloseHandler(peer)

return self.PeerConnected(peer)
Expand Down Expand Up @@ -552,9 +623,9 @@ func (self *impl) AcceptUnderlay(underlay channel.Underlay) error {

peer.Version = versionInfo

peer.RaftConn = newRaftPeerConn(peer, self.netAddr)
binding.AddTypedReceiveHandler(peer)
binding.AddTypedReceiveHandler(peer.RaftConn)
binding.AddReceiveHandlerF(RaftDataType, peer.handleReceiveData)
binding.AddReceiveHandlerF(RaftDisconnectType, peer.handleReceiveDisconnect)
binding.AddCloseHandler(peer)
return self.PeerConnected(peer)
})
Expand Down
8 changes: 6 additions & 2 deletions controller/raft/mesh/peerconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,15 @@ func (self *raftPeerConn) Write(b []byte) (n int, err error) {
}

func (self *raftPeerConn) Close() error {
return self.peer.closeRaftConn(5 * time.Second)
}

func (self *raftPeerConn) close() bool {
if self.closed.CompareAndSwap(false, true) {
close(self.closeNotify)
return self.peer.Channel.Close()
return true
}
return nil
return false
}

func (self *raftPeerConn) LocalAddr() net.Addr {
Expand Down
6 changes: 3 additions & 3 deletions zititest/models/smoke/smoketest.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ var Model = &model.Model{
},
},
"iperf-server-ert": {
Scope: model.Scope{Tags: model.Tags{"iperf", "service"}},
Scope: model.Scope{Tags: model.Tags{"iperf", "service", "ert"}},
Type: &zitilab.IPerfServerType{},
},
"caddy-ert": {
Expand All @@ -247,7 +247,7 @@ var Model = &model.Model{
},
},
"iperf-server-zet": {
Scope: model.Scope{Tags: model.Tags{"iperf", "service"}},
Scope: model.Scope{Tags: model.Tags{"iperf", "service", "zet"}},
Type: &zitilab.IPerfServerType{},
},
"caddy-zet": {
Expand All @@ -267,7 +267,7 @@ var Model = &model.Model{
},
},
"iperf-server-zt": {
Scope: model.Scope{Tags: model.Tags{"iperf", "service"}},
Scope: model.Scope{Tags: model.Tags{"iperf", "service", "ziti-tunnel"}},
Type: &zitilab.IPerfServerType{},
},
"caddy-zt": {
Expand Down
16 changes: 14 additions & 2 deletions zititest/models/smoke/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,20 @@ func TestFileDownload(hostSelector string, client HttpClient, hostType string, e
return host.ExecLoggedWithTimeout(timeout, cmds...)
}

func TestIperf(hostSelector, hostType string, encrypted, reversed bool) (string, error) {
host, err := model.GetModel().SelectHost("." + hostSelector + "-client")
func TestIperf(clientHostSelector, hostType string, encrypted, reversed bool, run model.Run) (string, error) {
c, err := model.GetModel().SelectComponent(".iperf." + hostType)
if err != nil {
return "", err
}
if err = c.Type.Stop(run, c); err != nil {
return "", err
}
iperfServer := c.Type.(model.ServerComponent)
if err = iperfServer.Start(run, c); err != nil {
return "", err
}

host, err := model.GetModel().SelectHost("." + clientHostSelector + "-client")
if err != nil {
return "", err
}
Expand Down
2 changes: 1 addition & 1 deletion zititest/tests/iperf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func testIPerf(t *testing.T, hostSelector string, hostType string, encrypted boo
success := false

t.Run(fmt.Sprintf("(%s%s%s)-%v", hostSelector, direction, hostType, encDesk), func(t *testing.T) {
o, err := smoke.TestIperf(hostSelector, hostType, encrypted, reversed)
o, err := smoke.TestIperf(hostSelector, hostType, encrypted, reversed, run)
if hostType == "zet" && err != nil {
t.Skipf("zet hosted iperf test failed [%v]", err.Error())
return
Expand Down

0 comments on commit 74656fb

Please sign in to comment.