diff --git a/controller/handler_peer_ctrl/transfer_leadership.go b/controller/handler_peer_ctrl/transfer_leadership.go index ce466ed83..3b0eb8f23 100644 --- a/controller/handler_peer_ctrl/transfer_leadership.go +++ b/controller/handler_peer_ctrl/transfer_leadership.go @@ -20,9 +20,9 @@ import ( raft2 "github.com/hashicorp/raft" "github.com/michaelquigley/pfxlog" "github.com/openziti/channel/v2" + "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/openziti/ziti/controller/peermsg" "github.com/openziti/ziti/controller/raft" - "github.com/openziti/ziti/common/pb/cmd_pb" "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" diff --git a/controller/raft/mesh/mesh.go b/controller/raft/mesh/mesh.go index 9bab19268..3c015aad4 100644 --- a/controller/raft/mesh/mesh.go +++ b/controller/raft/mesh/mesh.go @@ -46,8 +46,8 @@ const ( RaftDataType = 2049 SigningCertHeader = 2050 ApiAddressesHeader = 2051 - - ChannelTypeMesh = "ctrl.mesh" + RaftDisconnectType = 2052 + ChannelTypeMesh = "ctrl.mesh" ) type Peer struct { @@ -55,13 +55,31 @@ type Peer struct { Id raft.ServerID Address string Channel channel.Channel - RaftConn *raftPeerConn + RaftConn atomic.Pointer[raftPeerConn] Version *versions.VersionInfo SigningCerts []*x509.Certificate ApiAddresses map[string][]event.ApiAddress } +func (self *Peer) initRaftConn() *raftPeerConn { + self.mesh.lock.Lock() + defer self.mesh.lock.Unlock() + conn := self.RaftConn.Load() + if conn == nil { + conn = newRaftPeerConn(self, self.mesh.netAddr) + self.RaftConn.Store(conn) + } + return conn +} + func (self *Peer) HandleClose(channel.Channel) { + self.mesh.lock.Lock() + conn := self.RaftConn.Swap(nil) + if conn != nil { + conn.close() + } + self.mesh.lock.Unlock() + self.mesh.PeerDisconnected(self) } @@ -77,17 +95,68 @@ func (self *Peer) HandleReceive(m *channel.Message, _ channel.Channel) { if err := response.WithTimeout(5 * time.Second).Send(self.Channel); err != nil { logrus.WithError(err).Error("failed to send connect response") } else { + conn := self.initRaftConn() select { - case self.mesh.raftAccepts <- self.RaftConn: + case self.mesh.raftAccepts <- conn: case <-self.mesh.closeNotify: } } }() } -func (self *Peer) Connect(timeout time.Duration) error { +func (self *Peer) handleReceiveDisconnect(m *channel.Message, _ channel.Channel) { + go func() { + self.mesh.lock.Lock() + conn := self.RaftConn.Swap(nil) + self.mesh.lock.Unlock() + + if conn != nil { + conn.close() + } + + response := channel.NewResult(true, "") + response.ReplyTo(m) + + if err := response.WithTimeout(5 * time.Second).Send(self.Channel); err != nil { + logrus.WithError(err).Error("failed to send close response") + } + }() +} + +func (self *Peer) handleReceiveData(m *channel.Message, ch channel.Channel) { + if conn := self.RaftConn.Load(); conn != nil { + conn.HandleReceive(m, ch) + } +} + +func (self *Peer) Connect(timeout time.Duration) (net.Conn, error) { msg := channel.NewMessage(RaftConnectType, nil) response, err := msg.WithTimeout(timeout).SendForReply(self.Channel) + if err != nil { + return nil, err + } + result := channel.UnmarshalResult(response) + if !result.Success { + return nil, errors.Errorf("connect failed: %v", result.Message) + } + + logrus.Infof("connected peer %v at %v", self.Id, self.Address) + + return self.initRaftConn(), nil +} + +func (self *Peer) closeRaftConn(timeout time.Duration) error { + self.mesh.lock.Lock() + conn := self.RaftConn.Swap(nil) + defer self.mesh.lock.Unlock() + if conn == nil { + return nil + } + + conn.close() + + msg := channel.NewMessage(RaftDisconnectType, nil) + response, err := msg.WithTimeout(timeout).SendForReply(self.Channel) if err != nil { return err } @@ -96,7 +165,7 @@ func (self *Peer) Connect(timeout time.Duration) error { return errors.Errorf("connect failed: %v", result.Message) } - logrus.Infof("connected peer %v at %v", self.Id, self.Address) + logrus.Infof("disconnected peer %v at %v", self.Id, self.Address) return nil } @@ -227,10 +296,12 @@ func (self *impl) Dial(address raft.ServerAddress, timeout time.Duration) (net.C if err != nil { return nil, err } - if err := peer.Connect(timeout); err != nil { - return nil, err + + if peerConn := peer.RaftConn.Load(); peerConn != nil { + return peerConn, nil } - return peer.RaftConn, nil + + return peer.Connect(timeout) } func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer, error) { @@ -310,11 +381,11 @@ func (self *impl) GetOrConnectPeer(address string, timeout time.Duration) (*Peer } peer.Version = versionInfo - peer.RaftConn = newRaftPeerConn(peer, self.netAddr) peer.SigningCerts = []*x509.Certificate{underlay.Certificates()[0]} binding.AddTypedReceiveHandler(peer) - binding.AddTypedReceiveHandler(peer.RaftConn) + binding.AddReceiveHandlerF(RaftDataType, peer.handleReceiveData) + binding.AddReceiveHandlerF(RaftDisconnectType, peer.handleReceiveDisconnect) binding.AddCloseHandler(peer) return self.PeerConnected(peer) @@ -552,9 +623,9 @@ func (self *impl) AcceptUnderlay(underlay channel.Underlay) error { peer.Version = versionInfo - peer.RaftConn = newRaftPeerConn(peer, self.netAddr) binding.AddTypedReceiveHandler(peer) - binding.AddTypedReceiveHandler(peer.RaftConn) + binding.AddReceiveHandlerF(RaftDataType, peer.handleReceiveData) + binding.AddReceiveHandlerF(RaftDisconnectType, peer.handleReceiveDisconnect) binding.AddCloseHandler(peer) return self.PeerConnected(peer) }) diff --git a/controller/raft/mesh/peerconn.go b/controller/raft/mesh/peerconn.go index 9ce9cf448..aa56249e4 100644 --- a/controller/raft/mesh/peerconn.go +++ b/controller/raft/mesh/peerconn.go @@ -129,11 +129,15 @@ func (self *raftPeerConn) Write(b []byte) (n int, err error) { } func (self *raftPeerConn) Close() error { + return self.peer.closeRaftConn(5 * time.Second) +} + +func (self *raftPeerConn) close() bool { if self.closed.CompareAndSwap(false, true) { close(self.closeNotify) - return self.peer.Channel.Close() + return true } - return nil + return false } func (self *raftPeerConn) LocalAddr() net.Addr { diff --git a/zititest/models/smoke/smoketest.go b/zititest/models/smoke/smoketest.go index bed6d25dc..8a86c25f0 100644 --- a/zititest/models/smoke/smoketest.go +++ b/zititest/models/smoke/smoketest.go @@ -226,7 +226,7 @@ var Model = &model.Model{ }, }, "iperf-server-ert": { - Scope: model.Scope{Tags: model.Tags{"iperf", "service"}}, + Scope: model.Scope{Tags: model.Tags{"iperf", "service", "ert"}}, Type: &zitilab.IPerfServerType{}, }, "caddy-ert": { @@ -247,7 +247,7 @@ var Model = &model.Model{ }, }, "iperf-server-zet": { - Scope: model.Scope{Tags: model.Tags{"iperf", "service"}}, + Scope: model.Scope{Tags: model.Tags{"iperf", "service", "zet"}}, Type: &zitilab.IPerfServerType{}, }, "caddy-zet": { @@ -267,7 +267,7 @@ var Model = &model.Model{ }, }, "iperf-server-zt": { - Scope: model.Scope{Tags: model.Tags{"iperf", "service"}}, + Scope: model.Scope{Tags: model.Tags{"iperf", "service", "ziti-tunnel"}}, Type: &zitilab.IPerfServerType{}, }, "caddy-zt": { diff --git a/zititest/models/smoke/tests.go b/zititest/models/smoke/tests.go index 60ea72ce9..f1f03c0a2 100644 --- a/zititest/models/smoke/tests.go +++ b/zititest/models/smoke/tests.go @@ -57,8 +57,20 @@ func TestFileDownload(hostSelector string, client HttpClient, hostType string, e return host.ExecLoggedWithTimeout(timeout, cmds...) } -func TestIperf(hostSelector, hostType string, encrypted, reversed bool) (string, error) { - host, err := model.GetModel().SelectHost("." + hostSelector + "-client") +func TestIperf(clientHostSelector, hostType string, encrypted, reversed bool, run model.Run) (string, error) { + c, err := model.GetModel().SelectComponent(".iperf." + hostType) + if err != nil { + return "", err + } + if err = c.Type.Stop(run, c); err != nil { + return "", err + } + iperfServer := c.Type.(model.ServerComponent) + if err = iperfServer.Start(run, c); err != nil { + return "", err + } + + host, err := model.GetModel().SelectHost("." + clientHostSelector + "-client") if err != nil { return "", err } diff --git a/zititest/tests/iperf_test.go b/zititest/tests/iperf_test.go index 31a438f60..80863d741 100644 --- a/zititest/tests/iperf_test.go +++ b/zititest/tests/iperf_test.go @@ -77,7 +77,7 @@ func testIPerf(t *testing.T, hostSelector string, hostType string, encrypted boo success := false t.Run(fmt.Sprintf("(%s%s%s)-%v", hostSelector, direction, hostType, encDesk), func(t *testing.T) { - o, err := smoke.TestIperf(hostSelector, hostType, encrypted, reversed) + o, err := smoke.TestIperf(hostSelector, hostType, encrypted, reversed, run) if hostType == "zet" && err != nil { t.Skipf("zet hosted iperf test failed [%v]", err.Error()) return