diff --git a/pkg/agent/multicast/mcast_route.go b/pkg/agent/multicast/mcast_route.go index 279c11c4f19..95fb5cee4db 100644 --- a/pkg/agent/multicast/mcast_route.go +++ b/pkg/agent/multicast/mcast_route.go @@ -17,6 +17,7 @@ import ( "fmt" "net" "strings" + "time" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/tools/cache" @@ -38,6 +39,7 @@ func newRouteClient(nodeconfig *config.NodeConfig, groupCache cache.Indexer, mul nodeConfig: nodeconfig, groupCache: groupCache, inboundRouteCache: cache.NewIndexer(getMulticastInboundEntryKey, cache.Indexers{GroupNameIndexName: inboundGroupIndexFunc}), + outboundRouteCache: cache.NewIndexer(getMulticastOutboundEntryKey, cache.Indexers{}), multicastInterfaces: multicastInterfaces.List(), socket: multicastSocket, } @@ -72,6 +74,7 @@ type MRouteClient struct { nodeConfig *config.NodeConfig multicastInterfaces []string inboundRouteCache cache.Indexer + outboundRouteCache cache.Indexer groupCache cache.Indexer socket RouteInterface multicastInterfaceConfigs []multicastInterfaceConfig @@ -165,6 +168,12 @@ func (c *MRouteClient) addOutboundMrouteEntry(src net.IP, group net.IP) (err err if err != nil { return err } + routeEntry := &outboundMulticastRouteEntry{ + group: group.String(), + src: src.String(), + createdTime: time.Now(), + } + c.outboundRouteCache.Add(routeEntry) return nil } @@ -200,11 +209,33 @@ type inboundMulticastRouteEntry struct { vif uint16 } +// outboundMulticastRouteEntry encodes the outbound multicast routing entry. +// For example, +// type inboundMulticastRouteEntry struct { +// group "226.94.9.9" +// src "10.0.0.55" +// } encodes the multicast route entry from Antrea gateway to multicast interfaces +// (10.0.0.55,226.94.9.9) Iif: antrea-gw0 Oifs: list of multicastInterfaces. +// The iif is always Antrea gateway and oifs are always outbound interfaces +// so we do not put them in the struct. +// Field pktCount and createdTime are used for removing staled multicast routes. +type outboundMulticastRouteEntry struct { + group string + src string + pktCount uint32 + createdTime time.Time +} + func getMulticastInboundEntryKey(obj interface{}) (string, error) { entry := obj.(*inboundMulticastRouteEntry) return entry.group + "/" + entry.src + "/" + fmt.Sprint(entry.vif), nil } +func getMulticastOutboundEntryKey(obj interface{}) (string, error) { + entry := obj.(*outboundMulticastRouteEntry) + return entry.group + "/" + entry.src, nil +} + func inboundGroupIndexFunc(obj interface{}) ([]string, error) { entry, ok := obj.(*inboundMulticastRouteEntry) if !ok { @@ -272,6 +303,8 @@ type RouteInterface interface { // AddMrouteEntry adds multicast route with specified source(src), multicast group IP(group), // inbound multicast interface(iif) and outbound multicast interfaces(oifs). AddMrouteEntry(src net.IP, group net.IP, iif uint16, oifs []uint16) (err error) + // GetoutboundMroutePacketCount returns number of routed by outboundRoute entry. + GetoutboundMroutePacketCount(src net.IP, group net.IP) (pktCount uint32, err error) // DelMrouteEntry deletes multicast route with specified source(src), multicast group IP(group), // inbound multicast interface(iif). DelMrouteEntry(src net.IP, group net.IP, iif uint16) (err error) diff --git a/pkg/agent/multicast/mcast_route_linux.go b/pkg/agent/multicast/mcast_route_linux.go index 9495a007647..ace64af922e 100644 --- a/pkg/agent/multicast/mcast_route_linux.go +++ b/pkg/agent/multicast/mcast_route_linux.go @@ -21,12 +21,18 @@ import ( "fmt" "net" "syscall" + "time" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" "antrea.io/antrea/pkg/util/runtime" ) +const ( + outboundMRouteTimeout = time.Minute * 1 +) + // parseIGMPMsg parses the kernel version into parsedIGMPMsg. Note we need to consider the change // after linux 5.9 in the igmpmsg struct when parsing vif. Please check // https://github.com/torvalds/linux/commit/c8715a8e9f38906e73d6d78764216742db13ba0e. @@ -71,6 +77,8 @@ func (c *MRouteClient) run(stopCh <-chan struct{}) { } }() + go wait.NonSlidingUntil(c.updateOutboundMrouteStats, outboundMRouteTimeout, stopCh) + for i := 0; i < int(workerCount); i++ { go c.worker(stopCh) } @@ -78,3 +86,34 @@ func (c *MRouteClient) run(stopCh <-chan struct{}) { c.socket.FlushMRoute() syscall.Close(c.socket.GetFD()) } + +func (c *MRouteClient) updateOutboundMrouteStats() { + klog.V(2).InfoS("Updating outbound multicast route statistics and removing staled routes") + deletedOutboundRoutes := make([]*outboundMulticastRouteEntry, 0) + now := time.Now() + for _, obj := range c.outboundRouteCache.List() { + outboundRoute, _ := obj.(*outboundMulticastRouteEntry) + packetCount, err := c.socket.GetoutboundMroutePacketCount(net.ParseIP(outboundRoute.src).To4(), net.ParseIP(outboundRoute.group).To4()) + if err != nil { + klog.ErrorS(err, "Failed to getpacket count for outbound multicast route", "outboundRoute", outboundRoute) + return + } + packetCountDiff := packetCount - outboundRoute.pktCount + klog.V(4).Infof("Outbound multicast route %v routes %d packets in last %s", outboundRoute, packetCountDiff, outboundMRouteTimeout.String()) + if packetCountDiff == uint32(0) && now.Sub(outboundRoute.createdTime) > outboundMRouteTimeout { + deletedOutboundRoutes = append(deletedOutboundRoutes, outboundRoute) + } else { + outboundRoute.pktCount = packetCount + c.outboundRouteCache.Update(outboundRoute) + } + } + for _, outboundRoute := range deletedOutboundRoutes { + klog.InfoS("Deleting staled outbound multicast route", "group", outboundRoute.group, "source", outboundRoute.src) + err := c.socket.DelMrouteEntry(net.ParseIP(outboundRoute.src).To4(), net.ParseIP(outboundRoute.group).To4(), c.internalInterfaceVIF) + if err != nil { + klog.ErrorS(err, "Failed to delete outbound multicast route", "group", outboundRoute.group, "source", outboundRoute.src) + return + } + c.outboundRouteCache.Delete(outboundRoute) + } +} diff --git a/pkg/agent/multicast/mcast_route_test.go b/pkg/agent/multicast/mcast_route_test.go index 081171e7c00..f4fcea3ce7d 100644 --- a/pkg/agent/multicast/mcast_route_test.go +++ b/pkg/agent/multicast/mcast_route_test.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -64,6 +65,55 @@ func TestParseIGMPMsg(t *testing.T) { } } +func TestUpdateOutboundMrouteStats(t *testing.T) { + mRoute := newMockMulticastRouteClient(t) + err := mRoute.initialize(t) + assert.Nil(t, err) + now := time.Now() + for _, m := range []struct { + outboundMrouteEntry *outboundMulticastRouteEntry + isStaled bool + currStats uint32 + }{ + { + outboundMrouteEntry: &outboundMulticastRouteEntry{ + group: "224.3.5.7", + src: "10.1.2.3", + createdTime: now, + }, + isStaled: false, + currStats: 0, + }, + { + outboundMrouteEntry: &outboundMulticastRouteEntry{ + group: "224.3.5.8", + src: "10.1.2.3", + createdTime: now.Add(time.Duration(-outboundMRouteTimeout)), + pktCount: 10, + }, + isStaled: false, + currStats: 9, + }, + { + outboundMrouteEntry: &outboundMulticastRouteEntry{ + group: "224.3.5.9", + src: "10.1.2.3", + createdTime: now.Add(time.Duration(-outboundMRouteTimeout)), + pktCount: 0, + }, + isStaled: true, + currStats: 0, + }, + } { + mRoute.outboundRouteCache.Add(m.outboundMrouteEntry) + mockMulticastSocket.EXPECT().GetoutboundMroutePacketCount(net.ParseIP(m.outboundMrouteEntry.src).To4(), net.ParseIP(m.outboundMrouteEntry.group).To4()).Times(1).Return(m.currStats, nil) + if m.isStaled { + mockMulticastSocket.EXPECT().DelMrouteEntry(net.ParseIP(m.outboundMrouteEntry.src).To4(), net.ParseIP(m.outboundMrouteEntry.group).To4(), uint16(0)).Times(1) + } + } + mRoute.updateOutboundMrouteStats() +} + func TestProcessIGMPNocacheMsg(t *testing.T) { mRoute := newMockMulticastRouteClient(t) err := mRoute.initialize(t) diff --git a/pkg/agent/multicast/mcast_socket_linux.go b/pkg/agent/multicast/mcast_socket_linux.go index f8c946b06cb..87495e9ae8a 100644 --- a/pkg/agent/multicast/mcast_socket_linux.go +++ b/pkg/agent/multicast/mcast_socket_linux.go @@ -60,6 +60,18 @@ func (s *Socket) AddMrouteEntry(src net.IP, group net.IP, iif uint16, oifVIFs [] return multicastsyscall.SetsockoptMfcctl(s.GetFD(), syscall.IPPROTO_IP, multicastsyscall.MRT_ADD_MFC, mc) } +func (s *Socket) GetoutboundMroutePacketCount(src net.IP, group net.IP) (pktCount uint32, err error) { + siocSgReq := multicastsyscall.SiocSgReq{ + Src: [4]byte{src[0], src[1], src[2], src[3]}, + Grp: [4]byte{group[0], group[1], group[2], group[3]}, + } + stats, err := multicastsyscall.IoctlGetSiocSgReq(s.GetFD(), &siocSgReq) + if err != nil { + return 0, err + } + return stats.Pktcnt, nil +} + func (s *Socket) DelMrouteEntry(src net.IP, group net.IP, iif uint16) (err error) { mc := &multicastsyscall.Mfcctl{} origin := src.To4() diff --git a/pkg/agent/multicast/mcast_socket_others.go b/pkg/agent/multicast/mcast_socket_others.go index f60848c1f8a..cecf9781bce 100644 --- a/pkg/agent/multicast/mcast_socket_others.go +++ b/pkg/agent/multicast/mcast_socket_others.go @@ -31,6 +31,10 @@ func (s *Socket) AddMrouteEntry(src net.IP, group net.IP, iif uint16, oifVIFs [] return nil } +func (s *Socket) GetoutboundMroutePacketCount(src net.IP, group net.IP) (pktCount uint32, err error) { + return 0, nil +} + func (s *Socket) DelMrouteEntry(src net.IP, group net.IP, iif uint16) (err error) { return nil } diff --git a/pkg/agent/multicast/testing/mock_multicast.go b/pkg/agent/multicast/testing/mock_multicast.go index 79a2701ebe3..274d7a58d89 100644 --- a/pkg/agent/multicast/testing/mock_multicast.go +++ b/pkg/agent/multicast/testing/mock_multicast.go @@ -117,6 +117,21 @@ func (mr *MockRouteInterfaceMockRecorder) GetFD() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFD", reflect.TypeOf((*MockRouteInterface)(nil).GetFD)) } +// GetoutboundMroutePacketCount mocks base method +func (m *MockRouteInterface) GetoutboundMroutePacketCount(arg0, arg1 net.IP) (uint32, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetoutboundMroutePacketCount", arg0, arg1) + ret0, _ := ret[0].(uint32) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetoutboundMroutePacketCount indicates an expected call of GetoutboundMroutePacketCount +func (mr *MockRouteInterfaceMockRecorder) GetoutboundMroutePacketCount(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetoutboundMroutePacketCount", reflect.TypeOf((*MockRouteInterface)(nil).GetoutboundMroutePacketCount), arg0, arg1) +} + // MulticastInterfaceJoinMgroup mocks base method func (m *MockRouteInterface) MulticastInterfaceJoinMgroup(arg0, arg1 net.IP, arg2 string) error { m.ctrl.T.Helper() diff --git a/pkg/agent/util/syscall/linux/types.go b/pkg/agent/util/syscall/linux/types.go index fc1fc877855..a81f5c2b16c 100644 --- a/pkg/agent/util/syscall/linux/types.go +++ b/pkg/agent/util/syscall/linux/types.go @@ -47,6 +47,7 @@ const ( type Mfcctl C.struct_mfcctl type Vifctl C.struct_vifctl_with_ifindex +type SiocSgReq C.struct_siocsgreq const SizeofMfcctl = C.sizeof_struct_mfcctl const SizeofVifctl = C.sizeof_struct_vifctl_with_ifindex diff --git a/pkg/agent/util/syscall/syscall_unix.go b/pkg/agent/util/syscall/syscall_unix.go index 6a1a494f9cd..6f85281d1a5 100644 --- a/pkg/agent/util/syscall/syscall_unix.go +++ b/pkg/agent/util/syscall/syscall_unix.go @@ -19,6 +19,7 @@ package syscall import ( + "runtime" "syscall" "unsafe" ) @@ -34,7 +35,15 @@ func setsockopt(s int, level int, name int, val unsafe.Pointer, vallen uintptr) return } -// Please add your wrapped syscall functions below +func ioctl(fd int, req uint, arg uintptr) (err error) { + _, _, e1 := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(arg)) + if e1 != 0 { + return e1 + } + return +} + +// Please add your wrapped syscall functions below. func SetsockoptMfcctl(fd, level, opt int, mfcctl *Mfcctl) error { return setsockopt(fd, level, opt, unsafe.Pointer(mfcctl), SizeofMfcctl) @@ -43,3 +52,9 @@ func SetsockoptMfcctl(fd, level, opt int, mfcctl *Mfcctl) error { func SetsockoptVifctl(fd, level, opt int, vifctl *Vifctl) error { return setsockopt(fd, level, opt, unsafe.Pointer(vifctl), SizeofVifctl) } + +func IoctlGetSiocSgReq(fd int, siocsgreq *SiocSgReq) (*SiocSgReq, error) { + err := ioctl(fd, SIOCGETSGCNT, uintptr(unsafe.Pointer(siocsgreq))) + runtime.KeepAlive(siocsgreq) + return siocsgreq, err +} diff --git a/pkg/agent/util/syscall/ztypes_linux.go b/pkg/agent/util/syscall/ztypes_linux.go index 2d064ccfe3c..2664ca9ba38 100644 --- a/pkg/agent/util/syscall/ztypes_linux.go +++ b/pkg/agent/util/syscall/ztypes_linux.go @@ -26,6 +26,7 @@ const ( MRT_INIT = 0xc8 MRT_FLUSH = 0xd4 MAXVIFS = 0x20 + SIOCGETSGCNT = 0x89e1 ) type Mfcctl struct { @@ -48,6 +49,14 @@ type Vifctl struct { Rmt_addr [4]byte /* in_addr */ } +type SiocSgReq = struct { + Src [4]byte /* in_addr */ + Grp [4]byte /* in_addr */ + Pktcnt uint32 + Bytecnt uint32 + If uint32 +} + const SizeofMfcctl = 0x3c const SizeofVifctl = 0x10 const SizeofIgmpmsg = 0x14