diff --git a/docs/design/windows-design.md b/docs/design/windows-design.md index 1fb1cc0847a..4143f546ac9 100644 --- a/docs/design/windows-design.md +++ b/docs/design/windows-design.md @@ -209,7 +209,7 @@ It is processed and forwarded by OVS, and controlled with OpenFlow entries. ### Service Traffic -Kube-proxy userspace mode is configured to provide NodePort Service function. A specific Network Adapter named +Kube-proxy userspace mode is configured to provide NodePort Service function. A specific Network adapter named "HNS Internal NIC" is provided to kube-proxy to configure Service addresses. The OpenFlow entries for the NodePort Service traffic on Windows are the same as those on Linux. diff --git a/pkg/agent/agent_windows.go b/pkg/agent/agent_windows.go index beb50c9ef1f..00034fe18cb 100644 --- a/pkg/agent/agent_windows.go +++ b/pkg/agent/agent_windows.go @@ -92,12 +92,10 @@ func (i *Initializer) prepareHNSNetworkAndOVSExtension() error { i.nodeConfig.UplinkNetConfig.Index = adapter.Index defaultGW, err := util.GetDefaultGatewayByInterfaceIndex(adapter.Index) if err != nil { - if strings.Contains(err.Error(), "No matching MSFT_NetRoute objects found") { - klog.InfoS("No default gateway found on interface", "interface", adapter.Name) - defaultGW = "" - } else { - return err - } + return err + } + if defaultGW == "" { + klog.InfoS("No default gateway found on interface", "interface", adapter.Name) } i.nodeConfig.UplinkNetConfig.Gateway = defaultGW dnsServers, err := util.GetDNServersByInterfaceIndex(adapter.Index) diff --git a/pkg/agent/cniserver/interface_configuration_windows.go b/pkg/agent/cniserver/interface_configuration_windows.go index d33ac9f2363..a083354369d 100644 --- a/pkg/agent/cniserver/interface_configuration_windows.go +++ b/pkg/agent/cniserver/interface_configuration_windows.go @@ -505,7 +505,7 @@ func (ic *ifConfigurator) addPostInterfaceCreateHook(containerID, endpointName s go func() { ifaceName := fmt.Sprintf("vEthernet (%s)", endpointName) var err error - pollErr := wait.PollImmediate(time.Second, 60*time.Second, func() (bool, error) { + pollErr := wait.PollImmediate(100*time.Millisecond, 60*time.Second, func() (bool, error) { containerAccess.lockContainer(containerID) defer containerAccess.unlockContainer(containerID) currentEP, ok := ic.getEndpoint(endpointName) @@ -518,7 +518,7 @@ func (ic *ifConfigurator) addPostInterfaceCreateHook(containerID, endpointName s return true, nil } if !hostInterfaceExistsFunc(ifaceName) { - klog.InfoS("Waiting for interface to be created", "interface", ifaceName) + klog.V(2).InfoS("Waiting for interface to be created", "interface", ifaceName) return false, nil } if err = hook(); err != nil { diff --git a/pkg/agent/route/route_windows.go b/pkg/agent/route/route_windows.go index 4a757045197..399ac389567 100644 --- a/pkg/agent/route/route_windows.go +++ b/pkg/agent/route/route_windows.go @@ -22,7 +22,6 @@ import ( "fmt" "net" "reflect" - "strings" "sync" "k8s.io/apimachinery/pkg/util/sets" @@ -312,11 +311,7 @@ func (c *Client) addServiceCIDRRoute(serviceCIDR *net.IPNet) error { // Remove stale routes. for _, rt := range staleRoutes { if err := util.RemoveNetRoute(rt); err != nil { - if strings.Contains(err.Error(), "No matching MSFT_NetRoute objects") { - klog.InfoS("Failed to delete stale Service CIDR route since the route has been deleted", "route", rt) - } else { - return fmt.Errorf("failed to delete stale Service CIDR route %s: %w", rt.String(), err) - } + return fmt.Errorf("failed to delete stale Service CIDR route %s: %w", rt.String(), err) } else { klog.V(4).InfoS("Deleted stale Service CIDR route successfully", "route", rt) } @@ -547,11 +542,7 @@ func (c *Client) DeleteExternalIPRoute(externalIP net.IP) error { return nil } if err := util.RemoveNetRoute(route.(*util.Route)); err != nil { - if strings.Contains(err.Error(), "No matching MSFT_NetRoute objects") { - klog.InfoS("Failed to delete route for external IP since it doesn't exist", "IP", externalIPStr) - } else { - return fmt.Errorf("failed to delete route for external IP %s: %w", externalIPStr, err) - } + return fmt.Errorf("failed to delete route for external IP %s: %w", externalIPStr, err) } c.serviceRoutes.Delete(externalIPStr) klog.V(4).InfoS("Deleted route for external IP", "IP", externalIPStr) diff --git a/pkg/agent/route/route_windows_test.go b/pkg/agent/route/route_windows_test.go index ebd7154df55..f5e731e599b 100644 --- a/pkg/agent/route/route_windows_test.go +++ b/pkg/agent/route/route_windows_test.go @@ -84,10 +84,6 @@ func TestRouteOperation(t *testing.T) { err = client.Reconcile([]string{dest2}) require.Nil(t, err) - routes5, err := util.GetNetRoutes(gwLink, destCIDR1) - require.Nil(t, err) - assert.Equal(t, 0, len(routes5)) - err = client.DeleteRoutes(destCIDR2) require.Nil(t, err) routes7, err := util.GetNetRoutes(gwLink, destCIDR2) diff --git a/pkg/agent/util/net_linux_test.go b/pkg/agent/util/net_linux_test.go index 97555766608..4320a2f4c05 100644 --- a/pkg/agent/util/net_linux_test.go +++ b/pkg/agent/util/net_linux_test.go @@ -407,7 +407,7 @@ func TestSetAdapterMACAddress(t *testing.T) { wantErr error }{ { - name: "Set Adapter MAC", + name: "Set adapter MAC", expectedCalls: func(mockNetlink *netlinktest.MockInterfaceMockRecorder) { mockNetlink.LinkByName("test-en0").Return(testLink, nil) mockNetlink.LinkSetHardwareAddr(testLink, testMACAddr).Return(nil) diff --git a/pkg/agent/util/net_windows.go b/pkg/agent/util/net_windows.go index db6464baedf..f15e02665e6 100644 --- a/pkg/agent/util/net_windows.go +++ b/pkg/agent/util/net_windows.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -//Copyright 2020 Antrea Authors +// Copyright 2020 Antrea Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,19 +21,26 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "net" + "os" + "runtime" "strconv" "strings" + "syscall" "time" + "unsafe" "github.com/Microsoft/go-winio" "github.com/Microsoft/hcsshim" "github.com/containernetworking/plugins/pkg/ip" + "golang.org/x/sys/windows" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" ps "antrea.io/antrea/pkg/agent/util/powershell" + antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" binding "antrea.io/antrea/pkg/ovs/openflow" ) @@ -56,11 +63,13 @@ const ( var ( // Declared variables which are meant to be overridden for testing. - runCommand = ps.RunCommand - getHNSNetworkByName = hcsshim.GetHNSNetworkByName - hnsNetworkRequest = hcsshim.HNSNetworkRequest - hnsNetworkCreate = (*hcsshim.HNSNetwork).Create - hnsNetworkDelete = (*hcsshim.HNSNetwork).Delete + antreaNetIO = antreasyscall.NewNetIO() + getAdaptersAddresses = windows.GetAdaptersAddresses + runCommand = ps.RunCommand + getHNSNetworkByName = hcsshim.GetHNSNetworkByName + hnsNetworkRequest = hcsshim.HNSNetworkRequest + hnsNetworkCreate = (*hcsshim.HNSNetwork).Create + hnsNetworkDelete = (*hcsshim.HNSNetwork).Delete ) type Route struct { @@ -70,19 +79,39 @@ type Route struct { RouteMetric int } -func (r Route) String() string { +func (r *Route) String() string { return fmt.Sprintf("LinkIndex: %d, DestinationSubnet: %s, GatewayAddress: %s, RouteMetric: %d", r.LinkIndex, r.DestinationSubnet, r.GatewayAddress, r.RouteMetric) } -func (r Route) Equal(x Route) bool { +func (r *Route) Equal(x Route) bool { return x.LinkIndex == r.LinkIndex && x.DestinationSubnet != nil && r.DestinationSubnet != nil && - x.DestinationSubnet.IP.Equal(r.DestinationSubnet.IP) && + x.DestinationSubnet.String() == r.DestinationSubnet.String() && x.GatewayAddress.Equal(r.GatewayAddress) } +func (r *Route) toMibIPForwardRow() *antreasyscall.MibIPForwardRow { + row := antreasyscall.NewIPForwardRow() + row.DestinationPrefix = *antreasyscall.NewAddressPrefixFromIPNet(r.DestinationSubnet) + row.NextHop = *antreasyscall.NewRawSockAddrInetFromIP(r.GatewayAddress) + row.Metric = uint32(r.RouteMetric) + row.Index = uint32(r.LinkIndex) + return row +} + +func routeFromIPForwardRow(row *antreasyscall.MibIPForwardRow) *Route { + destination := row.DestinationPrefix.IPNet() + gatewayAddr := row.NextHop.IP() + return &Route{ + DestinationSubnet: destination, + GatewayAddress: gatewayAddr, + LinkIndex: int(row.Index), + RouteMetric: int(row.Metric), + } +} + type Neighbor struct { LinkIndex int IPAddress net.IP @@ -202,9 +231,11 @@ func ConfigureInterfaceAddressWithDefaultGateway(ifaceName string, ipConfig *net // EnableIPForwarding enables the IP interface to forward packets that arrive at this interface to other interfaces. func EnableIPForwarding(ifaceName string) error { - cmd := fmt.Sprintf(`Set-NetIPInterface -InterfaceAlias "%s" -Forwarding Enabled`, ifaceName) - _, err := runCommand(cmd) - return err + adapter, err := getAdapterInAllCompartmentsByName(ifaceName) + if err != nil { + return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", ifaceName, err) + } + return adapter.setForwarding(true, antreasyscall.AF_INET) } func RenameVMNetworkAdapter(networkName string, macStr, newName string, renameNetAdapter bool) error { @@ -544,13 +575,18 @@ func EnableRSCOnVSwitch(vSwitch string) error { // GetDefaultGatewayByInterfaceIndex returns the default gateway configured on the specified interface. func GetDefaultGatewayByInterfaceIndex(ifIndex int) (string, error) { - cmd := fmt.Sprintf("$(Get-NetRoute -InterfaceIndex %d -DestinationPrefix 0.0.0.0/0 ).NextHop", ifIndex) - defaultGW, err := runCommand(cmd) + ip, defaultDestination, _ := net.ParseCIDR("0.0.0.0/0") + family := addressFamilyByIP(ip) + routes, err := listRoutes(family, func(row antreasyscall.MibIPForwardRow) bool { + return row.DestinationPrefix.EqualsTo(defaultDestination) && row.Index == uint32(ifIndex) + }) if err != nil { return "", err } - defaultGW = strings.ReplaceAll(defaultGW, "\r\n", "") - return defaultGW, nil + if len(routes) == 0 { + return "", nil + } + return routes[0].GatewayAddress.String(), nil } // GetDNServersByInterfaceIndex returns the DNS servers configured on the specified interface. @@ -595,15 +631,8 @@ func DialLocalSocket(address string) (net.Conn, error) { } func HostInterfaceExists(ifaceName string) bool { - if _, err := netInterfaceByName(ifaceName); err == nil { - return true - } - // Some kinds of interfaces cannot be retrieved by "net.InterfaceByName" such as - // container vnic. - // So if an interface cannot be found by above function, use powershell command - // "Get-NetAdapter" to check if it exists. - cmd := fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s"`, ifaceName) - if _, err := runCommand(cmd); err != nil { + _, err := getAdapterInAllCompartmentsByName(ifaceName) + if err != nil { return false } return true @@ -613,92 +642,104 @@ func HostInterfaceExists(ifaceName string) bool { // there's no MTU field in HNSEndpoint: // https://github.com/Microsoft/hcsshim/blob/4a468a6f7ae547974bc32911395c51fb1862b7df/internal/hns/hnsendpoint.go#L12 func SetInterfaceMTU(ifaceName string, mtu int) error { - cmd := fmt.Sprintf("Set-NetIPInterface -IncludeAllCompartments -InterfaceAlias \"%s\" -NlMtuBytes %d", - ifaceName, mtu) - _, err := runCommand(cmd) - return err + adapter, err := getAdapterInAllCompartmentsByName(ifaceName) + if err != nil { + return fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", ifaceName, err) + } + return adapter.setMTU(mtu, antreasyscall.AF_INET) } func NewNetRoute(route *Route) error { - cmd := fmt.Sprintf("New-NetRoute -InterfaceIndex %v -DestinationPrefix %v -NextHop %v -RouteMetric %d -Verbose", - route.LinkIndex, route.DestinationSubnet.String(), route.GatewayAddress.String(), route.RouteMetric) - _, err := runCommand(cmd) - return err + if route == nil { + return nil + } + row := route.toMibIPForwardRow() + if err := antreaNetIO.CreateIPForwardEntry(row); err != nil { + return fmt.Errorf("failed to create new IPForward row: %v", err) + } + return nil } func RemoveNetRoute(route *Route) error { - cmd := fmt.Sprintf("Remove-NetRoute -InterfaceIndex %v -DestinationPrefix %v -Verbose -Confirm:$false", - route.LinkIndex, route.DestinationSubnet.String()) - _, err := runCommand(cmd) - return err -} - -func ReplaceNetRoute(route *Route) error { - rs, err := GetNetRoutes(route.LinkIndex, route.DestinationSubnet) - if err != nil { - return err - } - - if len(rs) == 0 { - if err := NewNetRoute(route); err != nil { - return err - } + if route == nil || route.DestinationSubnet == nil { return nil } - - for _, r := range rs { - if r.GatewayAddress.Equal(route.GatewayAddress) { - return nil + family := addressFamilyByIP(route.DestinationSubnet.IP) + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return fmt.Errorf("unable to list Windows IPForward rows: %v", err) + } + for i := range rows { + row := rows[i] + if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) && row.NextHop.IP().Equal(route.GatewayAddress) { + if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { + return fmt.Errorf("failed to delete existing route %s: %v", route.String(), err) + } } } + return nil +} - if err := RemoveNetRoute(route); err != nil { - return err +func ReplaceNetRoute(route *Route) error { + if route == nil || route.DestinationSubnet == nil { + return nil } - if err := NewNetRoute(route); err != nil { - return err + family := addressFamilyByIP(route.DestinationSubnet.IP) + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return fmt.Errorf("unable to list Windows IPForward rows: %v", err) + } + for i := range rows { + row := rows[i] + if row.DestinationPrefix.EqualsTo(route.DestinationSubnet) && row.Index == uint32(route.LinkIndex) { + if row.NextHop.IP().Equal(route.GatewayAddress) { + return nil + } else { + if err := antreaNetIO.DeleteIPForwardEntry(&row); err != nil { + return fmt.Errorf("failed to delete existing route with nextHop %s: %v", route.GatewayAddress, err) + } + } + } } - return nil + return NewNetRoute(route) } func GetNetRoutes(linkIndex int, dstSubnet *net.IPNet) ([]Route, error) { - cmd := fmt.Sprintf("Get-NetRoute -InterfaceIndex %d -DestinationPrefix %s -ErrorAction Ignore | Format-Table -HideTableHeaders", - linkIndex, dstSubnet.String()) - return getNetRoutes(cmd) + if dstSubnet == nil { + return nil, fmt.Errorf("unable to get net routes for %d", linkIndex) + } + family := addressFamilyByIP(dstSubnet.IP) + return listRoutes(family, func(row antreasyscall.MibIPForwardRow) bool { + return row.DestinationPrefix.EqualsTo(dstSubnet) && row.Index == uint32(linkIndex) + }) } func GetNetRoutesAll() ([]Route, error) { - cmd := "Get-NetRoute -ErrorAction Ignore | Format-Table -HideTableHeaders" - return getNetRoutes(cmd) + return listRoutes(antreasyscall.AF_UNSPEC, nil) } -func getNetRoutes(cmd string) ([]Route, error) { - routesStr, _ := runCommand(cmd) - parsed := parseGetNetCmdResult(routesStr, 6) - var routes []Route - for _, items := range parsed { - idx, err := strconv.Atoi(items[0]) - if err != nil { - return nil, fmt.Errorf("failed to parse the LinkIndex '%s': %v", items[0], err) - } - _, dstSubnet, err := net.ParseCIDR(items[1]) - if err != nil { - return nil, fmt.Errorf("failed to parse the DestinationSubnet '%s': %v", items[1], err) - } - gw := net.ParseIP(items[2]) - metric, err := strconv.Atoi(items[3]) - if err != nil { - return nil, fmt.Errorf("failed to parse the RouteMetric '%s': %v", items[3], err) - } - route := Route{ - LinkIndex: idx, - DestinationSubnet: dstSubnet, - GatewayAddress: gw, - RouteMetric: metric, +type routeFilter func(row antreasyscall.MibIPForwardRow) bool + +func listRoutes(family uint16, filter routeFilter) ([]Route, error) { + rows, err := antreaNetIO.ListIPForwardRows(family) + if err != nil { + return nil, fmt.Errorf("unable to list Windows IPForward rows: %v", err) + } + rts := make([]Route, 0, len(rows)) + for i, r := range rows { + if filter == nil || filter(r) { + route := routeFromIPForwardRow(&rows[i]) + rts = append(rts, *route) } - routes = append(routes, route) } - return routes, nil + return rts, nil +} + +func addressFamilyByIP(ip net.IP) uint16 { + if ip.To4() != nil { + return antreasyscall.AF_INET + } + return antreasyscall.AF_INET6 } func parseGetNetCmdResult(result string, itemNum int) [][]string { @@ -860,7 +901,7 @@ func GetNetNeighbor(neighbor *Neighbor) ([]Neighbor, error) { if err != nil { return nil, fmt.Errorf("failed to parse the DestinationIP '%s': %v", items[1], err) } - // Get-NetRoute returns LinkLayerAddress like "AA-BB-CC-DD-EE-FF". + // Get-NetNeighbor returns LinkLayerAddress like "AA-BB-CC-DD-EE-FF". mac, err := net.ParseMAC(strings.ReplaceAll(items[2], "-", ":")) if err != nil { return nil, fmt.Errorf("failed to parse the Gateway MAC '%s': %v", items[2], err) @@ -1026,6 +1067,75 @@ func GenHostInterfaceName(upLinkIfName string) string { return strings.TrimSuffix(upLinkIfName, bridgedUplinkSuffix) } +type updateIPInterfaceFunc func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow + +type adapter struct { + net.Interface + compartmentID uint32 +} + +func (a *adapter) setMTU(mtu int, family uint16) error { + if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { + newEntry := *entry + newEntry.NlMtu = uint32(mtu) + return &newEntry + }); err != nil { + return fmt.Errorf("unable to set IPInterface with MTU %d: %v", mtu, err) + } + return nil +} + +func (a *adapter) setForwarding(enabledForwarding bool, family uint16) error { + if err := a.setIPInterfaceEntry(family, func(entry *antreasyscall.MibIPInterfaceRow) *antreasyscall.MibIPInterfaceRow { + newEntry := *entry + newEntry.ForwardingEnabled = enabledForwarding + return &newEntry + }); err != nil { + return fmt.Errorf("unable to enable IPForwarding on net adapter: %v", err) + } + return nil +} + +func (a *adapter) setIPInterfaceEntry(family uint16, updateFunc updateIPInterfaceFunc) error { + if a.compartmentID > 1 { + runtime.LockOSThread() + defer func() { + hcsshim.SetCurrentThreadCompartmentId(0) + runtime.UnlockOSThread() + }() + if err := hcsshim.SetCurrentThreadCompartmentId(a.compartmentID); err != nil { + klog.ErrorS(err, "Failed to change current thread's compartment", "compartment", a.compartmentID) + return err + } + } + ipInterfaceRow := &antreasyscall.MibIPInterfaceRow{Family: family, Index: uint32(a.Index)} + if err := antreaNetIO.GetIPInterfaceEntry(ipInterfaceRow); err != nil { + return fmt.Errorf("unable to get IPInterface entry with Index %d: %v", a.Index, err) + } + updatedRow := updateFunc(ipInterfaceRow) + updatedRow.SitePrefixLength = 0 + return antreaNetIO.SetIPInterfaceEntry(updatedRow) +} + +var ( + errInvalidInterfaceName = errors.New("invalid network interface name") + errNoSuchInterface = errors.New("no such network interface") +) + +func getAdapterInAllCompartmentsByName(name string) (*adapter, error) { + if name == "" { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName} + } + adapters, err := getAdaptersByName(name) + if err != nil { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: err} + } + if len(adapters) == 0 { + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface} + } + return &adapters[0], nil +} + // createVMSwitchWithTeaming creates VMSwitch and enables OVS extension. // Connection to VM is lost for few seconds func createVMSwitchWithTeaming(switchName, ifName string) error { @@ -1047,13 +1157,14 @@ func enableOVSExtension() error { } func getRoutesOnInterface(linkIndex int) ([]interface{}, error) { - cmd := fmt.Sprintf("Get-NetRoute -InterfaceIndex %d -ErrorAction Ignore | Format-Table -HideTableHeaders", linkIndex) - rs, err := getNetRoutes(cmd) + rts, err := listRoutes(antreasyscall.AF_UNSPEC, func(row antreasyscall.MibIPForwardRow) bool { + return row.Index == uint32(linkIndex) + }) if err != nil { return nil, fmt.Errorf("failed to get routes: %v", err) } var routes []interface{} - for _, r := range rs { + for _, r := range rts { // Skip the routes automatically generated by Windows host when adding IP address on the network adapter. if r.GatewayAddress != nil && r.GatewayAddress.IsUnspecified() { continue @@ -1097,3 +1208,91 @@ func renameHostInterface(oriName string, newName string) error { _, err := runCommand(cmd) return err } + +func getAdaptersByName(name string) ([]adapter, error) { + aas, err := adapterAddresses() + if err != nil { + return nil, err + } + var adapters []adapter + for _, aa := range aas { + ifName := windows.UTF16PtrToString(aa.FriendlyName) + if ifName != name { + continue + } + index := aa.IfIndex + if index == 0 { // ipv6IfIndex is a substitute for ifIndex + index = aa.Ipv6IfIndex + } + ifi := net.Interface{ + Index: int(index), + Name: ifName, + } + if aa.OperStatus == windows.IfOperStatusUp { + ifi.Flags |= net.FlagUp + } + // For now we need to infer link-layer service capabilities from media types. + // TODO: use MIB_IF_ROW2.AccessType now that we no longer support Windows XP. + switch aa.IfType { + case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394: + ifi.Flags |= net.FlagBroadcast | net.FlagMulticast + case windows.IF_TYPE_PPP, windows.IF_TYPE_TUNNEL: + ifi.Flags |= net.FlagPointToPoint | net.FlagMulticast + case windows.IF_TYPE_SOFTWARE_LOOPBACK: + ifi.Flags |= net.FlagLoopback | net.FlagMulticast + case windows.IF_TYPE_ATM: + ifi.Flags |= net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast // assume all services available; LANE, point-to-point and point-to-multipoint + } + if aa.Mtu == 0xffffffff { + ifi.MTU = -1 + } else { + ifi.MTU = int(aa.Mtu) + } + if aa.PhysicalAddressLength > 0 { + ifi.HardwareAddr = make(net.HardwareAddr, aa.PhysicalAddressLength) + copy(ifi.HardwareAddr, aa.PhysicalAddress[:]) + } + adapter := adapter{ + Interface: ifi, + compartmentID: aa.CompartmentId, + } + adapters = append(adapters, adapter) + } + return adapters, nil +} + +// GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is used in windows.GetAdapterAddresses parameter +// flags to return addresses in all routing compartments. +const GAA_FLAG_INCLUDE_ALL_COMPARTMENTS = 0x00000200 + +// adapterAddresses returns a list of IpAdapterAddresses structures. The structure +// contains an IP adapter and flattened multiple IP addresses including unicast, anycast +// and multicast addresses. +// This function is copied from go/src/net/interface_windows.go, with a change that flag +// GAA_FLAG_INCLUDE_ALL_COMPARTMENTS is introduced to query interfaces in all compartments. +func adapterAddresses() ([]*windows.IpAdapterAddresses, error) { + flags := uint32(windows.GAA_FLAG_INCLUDE_PREFIX | GAA_FLAG_INCLUDE_ALL_COMPARTMENTS) + var b []byte + l := uint32(15000) // recommended initial size + for { + b = make([]byte, l) + err := getAdaptersAddresses(syscall.AF_UNSPEC, flags, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) + if err == nil { + if l == 0 { + return nil, nil + } + break + } + if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + if l <= uint32(len(b)) { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + } + var aas []*windows.IpAdapterAddresses + for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { + aas = append(aas, aa) + } + return aas, nil +} diff --git a/pkg/agent/util/net_windows_test.go b/pkg/agent/util/net_windows_test.go index 8245ec94b48..9eec87ebf72 100644 --- a/pkg/agent/util/net_windows_test.go +++ b/pkg/agent/util/net_windows_test.go @@ -20,13 +20,18 @@ package util import ( "fmt" "net" + "os" "strings" "testing" + antreasyscalltest "antrea.io/antrea/pkg/agent/util/syscall/testing" + "github.com/Microsoft/hcsshim" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sys/windows" + antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" "antrea.io/antrea/pkg/ovs/openflow" ) @@ -42,6 +47,19 @@ func TestRouteString(t *testing.T) { assert.Equal(t, "LinkIndex: 1, DestinationSubnet: 192.168.2.0/24, GatewayAddress: 192.168.2.0, RouteMetric: 256", gotRoute) } +func TestRouteTranslation(t *testing.T) { + _, subnet, _ := net.ParseCIDR("1.1.1.0/28") + oriRoute := &Route{ + LinkIndex: 27, + RouteMetric: 35, + DestinationSubnet: subnet, + GatewayAddress: net.ParseIP("1.1.1.254"), + } + row := oriRoute.toMibIPForwardRow() + newRoute := routeFromIPForwardRow(row) + assert.Equal(t, oriRoute, newRoute) +} + func TestNeighborString(t *testing.T) { testNeighbor := Neighbor{ LinkIndex: 1, @@ -70,12 +88,12 @@ func TestIsVirtualAdapter(t *testing.T) { wantIsVirtual bool }{ { - name: "Virtual Adapter", + name: "Virtual adapter", commandOut: " true ", wantIsVirtual: true, }, { - name: "Virtual Adapter Err", + name: "Virtual adapter Err", commandErr: testInvalidErr, wantIsVirtual: false, }, @@ -230,7 +248,7 @@ func TestSetAdapterMACAddress(t *testing.T) { wantErr error }{ { - name: "Set Adapter MAC", + name: "Set adapter MAC", commandOut: "success", }, { @@ -262,7 +280,7 @@ func TestPrepareHNSNetwork(t *testing.T) { GatewayAddress: gw, RouteMetric: MetricDefault, }} - testRoutes := createTestRoutes(routes) + testRoutes := convertTestRoutes(routes) testSubnetCIDR := &net.IPNet{ IP: net.ParseIP("8.8.8.7"), Mask: net.CIDRMask(32, 32), @@ -282,8 +300,6 @@ func TestPrepareHNSNetwork(t *testing.T) { } newIPCmd := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress %s -PrefixLength %s -DefaultGateway %s`, VirtualAdapterName("0"), nodeZeroIPNetStr[0], nodeZeroIPNetStr[1], "testGateway") setServerCmd := fmt.Sprintf(`Set-DnsClientServerAddress -InterfaceAlias "%s" -ServerAddresses "%s"`, VirtualAdapterName("0"), testDNSServer) - newRouteCmd := fmt.Sprintf("New-NetRoute -InterfaceIndex %v -DestinationPrefix %v -NextHop %v -RouteMetric %d -Verbose", - routes[0].LinkIndex, routes[0].DestinationSubnet.String(), routes[0].GatewayAddress.String(), routes[0].RouteMetric) getAdapterCmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -SwitchName "%s" | ? MacAddress -EQ "%s" | Select-Object -Property Name | Format-Table -HideTableHeaders`, LocalHNSNetwork, testUplinkMACStr) renameAdapterCmd := fmt.Sprintf(`Get-VMNetworkAdapter -ManagementOS -ComputerName "$(hostname)" -Name "%s" | Rename-VMNetworkAdapter -NewName "%s"`, testAdapterName, testNewName) renameNetCmd := fmt.Sprintf(`Get-NetAdapter -Name "%s" | Rename-NetAdapter -NewName "%s"`, VirtualAdapterName(testNewName), testNewName) @@ -299,6 +315,7 @@ func TestPrepareHNSNetwork(t *testing.T) { commandErr error hnsNetworkRequestError error testNetInterfaceErr error + createRowErr error wantCmds []string wantErr error }{ @@ -320,7 +337,7 @@ func TestPrepareHNSNetwork(t *testing.T) { wantErr: fmt.Errorf("error creating HNSNetwork: invalid"), }, { - name: "Adapter Err", + name: "adapter Err", nodeIPNet: &ipv4PublicIPNet, dnsServers: testDNSServer, ipFound: true, @@ -359,7 +376,7 @@ func TestPrepareHNSNetwork(t *testing.T) { nodeIPNet: &ipv4ZeroIPNet, dnsServers: testDNSServer, ipFound: false, - wantCmds: []string{newIPCmd, setServerCmd, newRouteCmd, getVMCmd, setVMCmd}, + wantCmds: []string{newIPCmd, setServerCmd, getVMCmd, setVMCmd}, }, { name: "IP Not Found Configure Default Err", @@ -371,7 +388,7 @@ func TestPrepareHNSNetwork(t *testing.T) { wantErr: testInvalidErr, }, { - name: "IP Not Found Set Adapter Err", + name: "IP Not Found Set adapter Err", nodeIPNet: &ipv4ZeroIPNet, dnsServers: testDNSServer, ipFound: false, @@ -380,12 +397,12 @@ func TestPrepareHNSNetwork(t *testing.T) { wantErr: alreadyExistsErr, }, { - name: "IP Not Found New Net Route Err", - nodeIPNet: &ipv4ZeroIPNet, - ipFound: false, - commandErr: alreadyExistsErr, - wantCmds: []string{newIPCmd, newRouteCmd}, - wantErr: alreadyExistsErr, + name: "IP Not Found New Net Route Err", + nodeIPNet: &ipv4ZeroIPNet, + ipFound: false, + createRowErr: fmt.Errorf("ip route not found"), + wantCmds: []string{newIPCmd}, + wantErr: fmt.Errorf("failed to create new IPForward row: ip route not found"), }, } @@ -397,25 +414,62 @@ func TestPrepareHNSNetwork(t *testing.T) { defer mockHNSNetworkRequest(nil, tc.hnsNetworkRequestError)() defer mockHNSNetworkCreate(tc.hnsNetworkCreateErr)() defer mockHNSNetworkDelete(nil)() + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{CreateIPForwardEntryErr: tc.createRowErr})() gotErr := PrepareHNSNetwork(testSubnetCIDR, tc.nodeIPNet, testUplinkAdapter, "testGateway", tc.dnsServers, testRoutes, tc.newName) assert.Equal(t, tc.wantErr, gotErr) }) } } -func TestInterfaceIndexing(t *testing.T) { +func TestGetDefaultGatewayByInterfaceIndex(t *testing.T) { + _, subnet, _ := net.ParseCIDR("0.0.0.0/0") + testIndex := uint32(27) + testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.254")) + listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: ip route not found") + tests := []struct { + name string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error + wantGateway string + wantErr error + }{ + { + name: "Index Success", + listRows: []antreasyscall.MibIPForwardRow{testIPForwardRow}, + wantGateway: "1.1.1.254", + }, + { + name: "Index Error", + listRowsErr: fmt.Errorf("ip route not found"), + wantErr: listIPForwardRowsErr, + }, + { + name: "Routes not found", + listRows: []antreasyscall.MibIPForwardRow{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{ListIPForwardRowsErr: tc.listRowsErr, IPForwardRows: tc.listRows})() + gotGateway, err := GetDefaultGatewayByInterfaceIndex((int)(testIndex)) + assert.Equal(t, tc.wantGateway, gotGateway) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestGetDNServersByInterfaceIndex(t *testing.T) { testIndex := 1 tests := []struct { name string commandOut string commandErr error - wantDefaultGW string wantDNSServer string }{ { name: "Index Success", commandOut: "hello\r\nworld\r\n\r\n", - wantDefaultGW: "helloworld", wantDNSServer: "hello,world", }, { @@ -428,13 +482,8 @@ func TestInterfaceIndexing(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { defer mockRunCommand(t, []string{ - fmt.Sprintf("$(Get-NetRoute -InterfaceIndex %d -DestinationPrefix 0.0.0.0/0 ).NextHop", testIndex), fmt.Sprintf("$(Get-DnsClientServerAddress -InterfaceIndex %d -AddressFamily IPv4).ServerAddresses", testIndex), }, tc.commandOut, tc.commandErr, true)() - gotDefaultGW, err := GetDefaultGatewayByInterfaceIndex(testIndex) - assert.Equal(t, tc.wantDefaultGW, gotDefaultGW) - assert.Equal(t, tc.commandErr, err) - gotDNSServer, err := GetDNServersByInterfaceIndex(testIndex) assert.Equal(t, tc.wantDNSServer, gotDNSServer) assert.Equal(t, tc.commandErr, err) @@ -443,155 +492,148 @@ func TestInterfaceIndexing(t *testing.T) { } func TestHostInterfaceExists(t *testing.T) { - generateWantCmd := func(str string) []string { - return []string{fmt.Sprintf(`Get-NetAdapter -InterfaceAlias "%s"`, str)} - } tests := []struct { name string testNetInterfaceName string - testNetInterfaceErr error - commandErr error - wantCmds []string - wantExists bool + testAdapterAddresses *windows.IpAdapterAddresses }{ { name: "Normal Exist", testNetInterfaceName: "host", - wantExists: true, + testAdapterAddresses: createTestAdapterAddresses("host"), }, { - name: "Container vnic", - testNetInterfaceName: "vnic", - testNetInterfaceErr: fmt.Errorf("not found"), - wantCmds: generateWantCmd("vnic"), - wantExists: true, - }, - { - name: "Interface not exist", - testNetInterfaceName: "0", - testNetInterfaceErr: fmt.Errorf("not found"), - commandErr: testInvalidErr, - wantCmds: generateWantCmd("0"), - wantExists: false, + name: "Interface not exist", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, "success", tc.commandErr, true)() - defer mockNetInterfaceByName(&net.Interface{}, tc.testNetInterfaceErr)() + defer mockGetAdaptersAddresses(tc.testAdapterAddresses, nil)() gotExists := HostInterfaceExists(tc.testNetInterfaceName) - assert.Equal(t, tc.wantExists, gotExists) + assert.Equal(t, tc.testNetInterfaceName != "", gotExists) }) } } func TestSetInterfaceMTU(t *testing.T) { + testName := "host" + testAdapterAddresses := createTestAdapterAddresses(testName) + testMTU := 2 tests := []struct { - name string - commandOut string - commandErr error - wantErr error + name string + testNetInterfaceName string + testAdapterAddresses *windows.IpAdapterAddresses + getIPInterfaceErr error + setIPInterfaceErr error + wantErr error }{ { - name: "Set Interface MTU", - commandOut: "success", + name: "Set Success", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, }, { - name: "Set Err", - commandErr: testInvalidErr, - wantErr: testInvalidErr, + name: "Interface name invalid", + wantErr: fmt.Errorf("unable to find NetAdapter on host in all compartments with name %s: %v", "", + &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}), + }, + { + name: "Get Interface Err", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, + getIPInterfaceErr: fmt.Errorf("IP interface not found"), + wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: %v", testMTU, + fmt.Errorf("unable to get IPInterface entry with Index %d: IP interface not found", (int)(testAdapterAddresses.IfIndex))), + }, + { + name: "Set Interface Err", + testNetInterfaceName: testName, + testAdapterAddresses: testAdapterAddresses, + setIPInterfaceErr: fmt.Errorf("IP interface set error"), + wantErr: fmt.Errorf("unable to set IPInterface with MTU %d: IP interface set error", testMTU), }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, []string{ - fmt.Sprintf("Set-NetIPInterface -IncludeAllCompartments -InterfaceAlias \"%s\" -NlMtuBytes %d", - "test", 1), - }, tc.commandOut, tc.commandErr, true)() - gotErr := SetInterfaceMTU("test", 1) + defer mockGetAdaptersAddresses(tc.testAdapterAddresses, nil)() + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{GetIPInterfaceEntryErr: tc.getIPInterfaceErr, SetIPInterfaceEntryErr: tc.setIPInterfaceErr})() + gotErr := SetInterfaceMTU(tc.testNetInterfaceName, testMTU) assert.Equal(t, tc.wantErr, gotErr) }) } } func TestReplaceNetRoute(t *testing.T) { - _, testSubnet, _ := net.ParseCIDR("192.168.2.0/32") - testRoute := &Route{ - LinkIndex: 0, - DestinationSubnet: testSubnet, - GatewayAddress: net.ParseIP("192.168.2.0"), + _, subnet, _ := net.ParseCIDR("1.1.1.0/28") + testIP := net.ParseIP("1.1.1.254") + testIndex := uint32(27) + testIPForwardRow := createTestMibIPForwardRow(testIndex, subnet, testIP) + testRoute := Route{ + LinkIndex: (int)(testIPForwardRow.Index), + DestinationSubnet: subnet, + GatewayAddress: net.ParseIP("1.1.1.254"), RouteMetric: MetricDefault, } - getCmd := fmt.Sprintf("Get-NetRoute -InterfaceIndex %d -DestinationPrefix %s -ErrorAction Ignore | Format-Table -HideTableHeaders", - testRoute.LinkIndex, testRoute.DestinationSubnet.String()) - newCmd := fmt.Sprintf("New-NetRoute -InterfaceIndex %v -DestinationPrefix %v -NextHop %v -RouteMetric %d -Verbose", - testRoute.LinkIndex, testRoute.DestinationSubnet.String(), testRoute.GatewayAddress.String(), testRoute.RouteMetric) - removeCmd := fmt.Sprintf("Remove-NetRoute -InterfaceIndex %v -DestinationPrefix %v -Verbose -Confirm:$false", - testRoute.LinkIndex, testRoute.DestinationSubnet.String()) + listIPForwardRowsErr := fmt.Errorf("unable to list Windows IPForward rows: unable to list IP forward entry") + deleteIPForwardEntryErr := fmt.Errorf("failed to delete existing route with nextHop %s: unable to delete IP forward entry", testRoute.GatewayAddress) + createIPForwardEntryErr := fmt.Errorf("failed to create new IPForward row: unable to create IP forward entry") tests := []struct { - name string - route *Route - commandOut string - commandErr error - wantCmds []string - wantErrStr string + name string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error + createIPForwardErr error + deleteIPForwardErr error + wantErr error }{ { - name: "Replace Route", - route: testRoute, - commandOut: "0 192.168.1.0/24 192.168.1.0 256 nil nil", - wantCmds: []string{getCmd, removeCmd, newCmd}, + name: "Replace Success", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, }, { - name: "Get Route Err", - route: testRoute, - commandOut: "err 192.168.1.0/24 192.168.1.0 256 nil nil", - wantCmds: []string{getCmd}, - wantErrStr: "failed to parse the LinkIndex", + name: "Same GatewayAddress", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, testIP)}, }, { - name: "Get Route Not Exist", - route: testRoute, - wantCmds: []string{getCmd, newCmd}, + name: "List Rows Err", + listRowsErr: fmt.Errorf("unable to list IP forward entry"), + wantErr: listIPForwardRowsErr, }, { - name: "New Route Err", - route: testRoute, - commandErr: testInvalidErr, - wantCmds: []string{getCmd, newCmd}, - wantErrStr: "invalid", + name: "Delete Ip Forward Entry Err", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, + deleteIPForwardErr: fmt.Errorf("unable to delete IP forward entry"), + wantErr: deleteIPForwardEntryErr, }, { - name: "Duplicate Route", - route: testRoute, - commandOut: "0 192.168.2.0/24 192.168.2.0 256 nil nil", - wantCmds: []string{getCmd}, - }, - { - name: "Remove Route Err", - route: testRoute, - commandOut: "0 192.168.1.0/24 192.168.1.0 256 nil nil", - commandErr: testInvalidErr, - wantCmds: []string{getCmd, removeCmd}, - wantErrStr: "invalid", + name: "Add Route Err", + listRows: []antreasyscall.MibIPForwardRow{createTestMibIPForwardRow(testIndex, subnet, net.ParseIP("1.1.1.1"))}, + createIPForwardErr: fmt.Errorf("unable to create IP forward entry"), + wantErr: createIPForwardEntryErr, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, tc.commandErr, true)() - gotErr := ReplaceNetRoute(tc.route) - if tc.wantErrStr == "" { - require.NoError(t, gotErr) - } else { - assert.ErrorContains(t, gotErr, tc.wantErrStr) - } + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{CreateIPForwardEntryErr: tc.createIPForwardErr, DeleteIPForwardEntryErr: tc.deleteIPForwardErr, ListIPForwardRowsErr: tc.listRowsErr, IPForwardRows: tc.listRows})() + gotErr := ReplaceNetRoute(&testRoute) + assert.Equal(t, tc.wantErr, gotErr) }) } } +func TestGetNetRoutesAll(t *testing.T) { + gw, subnet, _ := net.ParseCIDR("192.168.2.0/24") + testRow := createTestMibIPForwardRow(0, subnet, gw) + listRows := []antreasyscall.MibIPForwardRow{testRow} + wantRoutes := []Route{*routeFromIPForwardRow(&testRow)} + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{IPForwardRows: listRows, ListIPForwardRowsErr: nil})() + gotRoutes, gotErr := GetNetRoutesAll() + assert.Equal(t, wantRoutes, gotRoutes) + assert.Nil(t, gotErr) +} + func TestNewNetNat(t *testing.T) { notFoundErr := fmt.Errorf("received error No MSFT_NetNat objects found") testNetNat := "test-nat" @@ -830,72 +872,50 @@ func TestVirtualAdapterName(t *testing.T) { func TestGetInterfaceConfig(t *testing.T) { gw, subnet, _ := net.ParseCIDR("192.168.2.0/24") - routes := []Route{{ - LinkIndex: 0, - DestinationSubnet: subnet, - GatewayAddress: gw, - RouteMetric: MetricDefault, - }} - testRoutes := createTestRoutes(routes) + testRow := createTestMibIPForwardRow(0, subnet, gw) + routes := []Route{*routeFromIPForwardRow(&testRow)} + testRoutes := convertTestRoutes(routes) testNetInterface := generateNetInterface("0") - wantCmds := []string{fmt.Sprintf("Get-NetRoute -InterfaceIndex %d -ErrorAction Ignore | Format-Table -HideTableHeaders", 0)} tests := []struct { name string testNetInterfaceErr error - commandOut string + listRows []antreasyscall.MibIPForwardRow + listRowsErr error wantAddrs []*net.IPNet wantRoutes []interface{} - wantCmds []string - wantErrStr string + wantErr error }{ { name: "Get Interface Config Success", - commandOut: "0 192.168.2.0/24 192.168.2.0 256 nil nil nil", + listRows: []antreasyscall.MibIPForwardRow{testRow}, wantAddrs: []*net.IPNet{&ipv4PublicIPNet}, wantRoutes: testRoutes, - wantCmds: wantCmds, }, { name: "Interface Err", testNetInterfaceErr: testInvalidErr, - commandOut: "0 192.168.2.0/24 192.168.2.0 256 nil nil nil", - wantErrStr: "failed to get interface 0: invalid", + wantErr: fmt.Errorf("failed to get interface %s: %v", "0", testInvalidErr), }, { - name: "Route Index Err", - commandOut: "err 192.168.2.0/24 192.168.2.0 256 nil nil nil", - wantErrStr: "failed to parse the LinkIndex", - wantCmds: wantCmds, - }, - { - name: "Route Subnet Err", - commandOut: "0 err 192.168.2.0 256 nil nil nil", - wantErrStr: "failed to parse the DestinationSubnet", - wantCmds: wantCmds, - }, - { - name: "Route Metric Err", - commandOut: "0 192.168.2.0/24 192.168.2.0 err nil nil nil", - wantErrStr: "failed to parse the RouteMetric", - wantCmds: wantCmds, + name: "Route Err", + listRows: []antreasyscall.MibIPForwardRow{testRow}, + listRowsErr: fmt.Errorf("unable to list IP forward rows"), + wantErr: fmt.Errorf("failed to get routes for interface index %d: %v", testNetInterface.Index, + fmt.Errorf("failed to get routes: unable to list Windows IPForward rows: unable to list IP forward rows")), }, } - for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - defer mockRunCommand(t, tc.wantCmds, tc.commandOut, nil, true)() defer mockNetInterfaceByName(&testNetInterface, tc.testNetInterfaceErr)() defer mockNetInterfaceAddrs(testNetInterface, nil)() + defer mockAntreaNetIO(&antreasyscalltest.MockNetIO{IPForwardRows: tc.listRows, ListIPForwardRowsErr: tc.listRowsErr})() gotInterface, gotAddrs, gotRoutes, gotErr := GetInterfaceConfig("0") - assert.Equal(t, tc.wantAddrs, gotAddrs) - assert.EqualValues(t, tc.wantRoutes, gotRoutes) - if tc.wantErrStr == "" { + if tc.wantErr == nil { assert.EqualValues(t, testNetInterface, *gotInterface) - require.NoError(t, gotErr) - } else { - assert.Nil(t, gotInterface) - assert.ErrorContains(t, gotErr, tc.wantErrStr) } + assert.Equal(t, tc.wantAddrs, gotAddrs) + assert.EqualValues(t, tc.wantRoutes, gotRoutes) + assert.Equal(t, tc.wantErr, gotErr) }) } } @@ -1055,7 +1075,61 @@ func TestGenHostInterfaceName(t *testing.T) { assert.Equal(t, "host", hostInterface) } -func createTestRoutes(routes []Route) []interface{} { +func TestGetAdapterInAllCompartmentsByName(t *testing.T) { + testName := "host" + testFlags := net.FlagUp | net.FlagBroadcast | net.FlagPointToPoint | net.FlagMulticast + testAdapter := adapter{ + Interface: net.Interface{ + Index: 1, + Name: testName, + Flags: testFlags, + MTU: 1, + HardwareAddr: testMACAddr, + }, + compartmentID: 1, + } + tests := []struct { + name string + testName string + testAdapters *windows.IpAdapterAddresses + testAdaptersErr error + wantAdapters *adapter + wantErr error + }{ + { + name: "Normal", + testName: testName, + testAdapters: createTestAdapterAddresses(testName), + wantAdapters: &testAdapter, + }, + { + name: "Invalid name", + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceName}, + }, + { + name: "adapter Err", + testName: testName, + testAdaptersErr: windows.ERROR_FILE_NOT_FOUND, + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: os.NewSyscallError("getadaptersaddresses", windows.ERROR_FILE_NOT_FOUND)}, + }, + { + name: "adapter not found", + testName: testName, + wantErr: &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errNoSuchInterface}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + defer mockGetAdaptersAddresses(tc.testAdapters, tc.testAdaptersErr)() + gotAdapters, gotErr := getAdapterInAllCompartmentsByName(tc.testName) + assert.EqualValues(t, tc.wantAdapters, gotAdapters) + assert.EqualValues(t, tc.wantErr, gotErr) + }) + } +} + +func convertTestRoutes(routes []Route) []interface{} { testRoutes := make([]interface{}, len(routes)) for i, v := range routes { testRoutes[i] = v @@ -1063,6 +1137,59 @@ func createTestRoutes(routes []Route) []interface{} { return testRoutes } +func createTestAdapterAddresses(name string) *windows.IpAdapterAddresses { + testPhysicalAddress := [8]byte{} + copy(testPhysicalAddress[:6], testMACAddr) + testName, _ := windows.UTF16FromString(name) + return &windows.IpAdapterAddresses{ + FriendlyName: &testName[0], + IfIndex: 1, + OperStatus: windows.IfOperStatusUp, + IfType: windows.IF_TYPE_ATM, + Mtu: 1, + PhysicalAddressLength: 6, + PhysicalAddress: testPhysicalAddress, + CompartmentId: 1, + } +} + +func createTestMibIPForwardRow(index uint32, subnet *net.IPNet, ip net.IP) antreasyscall.MibIPForwardRow { + return antreasyscall.MibIPForwardRow{ + Index: index, + Metric: MetricDefault, + DestinationPrefix: *antreasyscall.NewAddressPrefixFromIPNet(subnet), + NextHop: *antreasyscall.NewRawSockAddrInetFromIP(ip), + } +} + +func mockAntreaNetIO(mockNetIO *antreasyscalltest.MockNetIO) func() { + originalNetIO := antreaNetIO + antreaNetIO = mockNetIO + return func() { + antreaNetIO = originalNetIO + } +} + +func mockGetAdaptersAddresses(testAdaptersAddresses *windows.IpAdapterAddresses, err error) func() { + originalGetAdaptersAddresses := getAdaptersAddresses + getAdaptersAddresses = func(family uint32, flags uint32, reserved uintptr, adapterAddresses *windows.IpAdapterAddresses, sizePointer *uint32) (errcode error) { + if adapterAddresses != nil && testAdaptersAddresses != nil { + adapterAddresses.IfIndex = testAdaptersAddresses.IfIndex + adapterAddresses.FriendlyName = testAdaptersAddresses.FriendlyName + adapterAddresses.OperStatus = testAdaptersAddresses.OperStatus + adapterAddresses.IfType = testAdaptersAddresses.IfType + adapterAddresses.Mtu = testAdaptersAddresses.Mtu + adapterAddresses.PhysicalAddressLength = testAdaptersAddresses.PhysicalAddressLength + adapterAddresses.PhysicalAddress = testAdaptersAddresses.PhysicalAddress + adapterAddresses.CompartmentId = testAdaptersAddresses.CompartmentId + } + return err + } + return func() { + getAdaptersAddresses = originalGetAdaptersAddresses + } +} + // mockRunCommand mocks runCommand with a custom command output and error message. // If exactMatch is enabled, this function asserts that the executed commands are // exactly the same with wantCmds in terms of order and value. Otherwise, for tests diff --git a/pkg/agent/util/powershell/powershell_windows.go b/pkg/agent/util/powershell/powershell_windows.go index df8629de4cf..f34c01014c5 100644 --- a/pkg/agent/util/powershell/powershell_windows.go +++ b/pkg/agent/util/powershell/powershell_windows.go @@ -1,7 +1,7 @@ //go:build windows // +build windows -//Copyright 2021 Antrea Authors +// Copyright 2021 Antrea Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/agent/util/syscall/interfaces.go b/pkg/agent/util/syscall/interfaces.go deleted file mode 100644 index a6624530830..00000000000 --- a/pkg/agent/util/syscall/interfaces.go +++ /dev/null @@ -1,15 +0,0 @@ -//Copyright 2023 Antrea Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syscall diff --git a/pkg/agent/util/syscall/syscall_windows.go b/pkg/agent/util/syscall/syscall_windows.go new file mode 100644 index 00000000000..c0b5d29e9ca --- /dev/null +++ b/pkg/agent/util/syscall/syscall_windows.go @@ -0,0 +1,386 @@ +// Copyright 2023 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syscall + +import ( + "net" + "net/netip" + "os" + "strconv" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + + utilip "antrea.io/antrea/pkg/util/ip" +) + +const ( + AF_UNSPEC uint16 = uint16(windows.AF_UNSPEC) + AF_INET uint16 = uint16(windows.AF_INET) + AF_INET6 uint16 = uint16(windows.AF_INET6) +) + +// The following definitions are copied from Nldef header in Win32 API reference documentation. +// RouterDiscoveryBehavior defines the router discovery behavior. +type RouterDiscoveryBehavior int32 + +const ( + RouterDiscoveryDisabled RouterDiscoveryBehavior = 0 + RouterDiscoveryEnabled RouterDiscoveryBehavior = 1 + RouterDiscoveryDHCP RouterDiscoveryBehavior = 2 + RouterDiscoveryUnchanged RouterDiscoveryBehavior = -1 +) + +// LinkLocalAddressBehavior defines the link local address behavior. +type LinkLocalAddressBehavior int32 + +const ( + LinkLocalAlwaysOff LinkLocalAddressBehavior = 0 + LinkLocalDelayed LinkLocalAddressBehavior = 1 + LinkLocalAlwaysOn LinkLocalAddressBehavior = 2 + LinkLocalUnchanged LinkLocalAddressBehavior = -1 +) + +const ScopeLevelCount = 16 + +// NlInterfaceOffloadRodFlags specifies a set of flags that indicate the offload +// capabilities for an IP interface. +type NlInterfaceOffloadRodFlags uint8 + +const ( + NlChecksumSupported NlInterfaceOffloadRodFlags = 0x01 + nlOptionsSupported NlInterfaceOffloadRodFlags = 0x02 + TlDatagramChecksumSupported NlInterfaceOffloadRodFlags = 0x04 + TlStreamChecksumSupported NlInterfaceOffloadRodFlags = 0x08 + TlStreamOptionsSupported NlInterfaceOffloadRodFlags = 0x10 + FastPathCompatible NlInterfaceOffloadRodFlags = 0x20 + TlLargeSendOffloadSupported NlInterfaceOffloadRodFlags = 0x40 + TlGiantSendOffloadSupported NlInterfaceOffloadRodFlags = 0x80 +) + +type MibIPInterfaceRow struct { + Family uint16 + Luid uint64 + Index uint32 + MaxReassemblySize uint32 + Identifier uint64 + MinRouterAdvertisementInterval uint32 + MaxRouterAdvertisementInterval uint32 + AdvertisingEnabled bool + ForwardingEnabled bool + WeakHostSend bool + WeakHostReceive bool + UseAutomaticMetric bool + UseNeighborUnreachabilityDetection bool + ManagedAddressConfigurationSupported bool + OtherStatefulConfigurationSupported bool + AdvertiseDefaultRoute bool + RouterDiscoveryBehavior RouterDiscoveryBehavior + DadTransmits uint32 + BaseReachableTime uint32 + RetransmitTime uint32 + PathMtuDiscoveryTimeout uint32 + LinkLocalAddressBehavior LinkLocalAddressBehavior + LinkLocalAddressTimeout uint32 + ZoneIndices [ScopeLevelCount]uint32 + SitePrefixLength uint32 + Metric uint32 + NlMtu uint32 + Connected bool + SupportsWakeUpPatterns bool + SupportsNeighborDiscovery bool + SupportsRouterDiscovery bool + ReachableTime uint32 + TransmitOffload NlInterfaceOffloadRodFlags + ReceiveOffload NlInterfaceOffloadRodFlags + DisableDefaultRoutes bool +} + +type RawSockAddrInet struct { + Family uint16 + data [26]byte +} + +func (a *RawSockAddrInet) IP() net.IP { + if a == nil { + return nil + } + if a.Family == AF_INET { + addr := (*syscall.RawSockaddrInet4)(unsafe.Pointer(a)) + return net.IPv4(addr.Addr[0], addr.Addr[1], addr.Addr[2], addr.Addr[3]) + } + if a.Family == AF_INET6 { + addr := (*syscall.RawSockaddrInet6)(unsafe.Pointer(a)) + return addr.Addr[:] + } + return net.IPv6unspecified +} + +func (a *RawSockAddrInet) String() string { + return a.IP().String() +} + +func NewRawSockAddrInetFromIP(ip net.IP) *RawSockAddrInet { + sockAddrInet := new(RawSockAddrInet) + if ip.To4() != nil { + addr, _ := netip.AddrFromSlice(ip.To4()) + addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(sockAddrInet)) + addr4.Family = AF_INET + addr4.Addr = addr.As4() + addr4.Port = 0 + addr4.Zero = [8]byte{} + return sockAddrInet + } + addr, _ := netip.AddrFromSlice(ip) + addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(sockAddrInet)) + addr6.Family = AF_INET6 + addr6.Addr = addr.As16() + addr6.Port = 0 + addr6.Flowinfo = 0 + scopeId := uint32(0) + if z := addr.Zone(); z != "" { + if s, err := strconv.ParseUint(z, 10, 32); err == nil { + scopeId = uint32(s) + } + } + addr6.Scope_id = scopeId + return sockAddrInet +} + +type AddressPrefix struct { + Prefix RawSockAddrInet + prefixLength uint8 + _ [2]byte // Add two bytes to keep alignment. +} + +func (p *AddressPrefix) IPNet() *net.IPNet { + if p == nil { + return nil + } + sockAddr := p.Prefix + if sockAddr.Family == AF_INET { + return &net.IPNet{ + IP: (&sockAddr).IP().To4(), + Mask: net.CIDRMask(int(p.prefixLength), 8*net.IPv4len), + } + } + if p.Prefix.Family == AF_INET6 { + return &net.IPNet{ + IP: (&sockAddr).IP(), + Mask: net.CIDRMask(int(p.prefixLength), 8*net.IPv6len), + } + } + return nil +} + +func (p *AddressPrefix) EqualsTo(ipNet *net.IPNet) bool { + if ipNet == nil && p == nil { + return true + } else if ipNet == nil || p == nil { + return false + } + if p.prefixLength == 0 { + return ipNet.IP.Equal(net.IPv4zero) || ipNet.IP.Equal(net.IPv6zero) + } + return utilip.IPNetEqual(p.IPNet(), ipNet) +} + +func (p *AddressPrefix) String() string { + return p.IPNet().String() +} + +func NewAddressPrefixFromIPNet(ipnet *net.IPNet) *AddressPrefix { + if ipnet == nil { + return nil + } + sockAddr := NewRawSockAddrInetFromIP(ipnet.IP) + prefixLength, _ := ipnet.Mask.Size() + return &AddressPrefix{ + Prefix: *sockAddr, + prefixLength: uint8(prefixLength), + } +} + +// NlRouteProtocol defines the routing mechanism that an IP route was added with. +type NlRouteProtocol uint32 + +const ( + RouteProtocolOther NlRouteProtocol = 1 + RouteProtocolLocal NlRouteProtocol = 2 + RouteProtocolNetMgmt NlRouteProtocol = 3 + RouteProtocolIcmp NlRouteProtocol = 4 + RouteProtocolEgp NlRouteProtocol = 5 + RouteProtocolGgp NlRouteProtocol = 6 + RouteProtocolHello NlRouteProtocol = 7 + RouteProtocolRip NlRouteProtocol = 8 + RouteProtocolIsIs NlRouteProtocol = 9 + RouteProtocolEsIs NlRouteProtocol = 10 + RouteProtocolCisco NlRouteProtocol = 11 + RouteProtocolBbn NlRouteProtocol = 12 + RouteProtocolOspf NlRouteProtocol = 13 + RouteProtocolBgp NlRouteProtocol = 14 + RouteProtocolIdpr NlRouteProtocol = 15 + RouteProtocolEigrp NlRouteProtocol = 16 + RouteProtocolDvmrp NlRouteProtocol = 17 + RouteProtocolRpl NlRouteProtocol = 18 + RouteProtocolDhcp NlRouteProtocol = 19 + + // + // Windows-specific definitions. + // + NT_AUTOSTATIC NlRouteProtocol = 10002 + NT_STATIC NlRouteProtocol = 10006 + NT_STATIC_NON_DOD NlRouteProtocol = 10007 +) + +// NlRouteOrigin defines the origin of the IP route. +type NlRouteOrigin uint32 + +const ( + NlroManual NlRouteOrigin = 0 + NlroWellKnown NlRouteOrigin = 1 + NlroDHCP NlRouteOrigin = 2 + NlroRouterAdvertisement NlRouteOrigin = 3 + Nlro6to4 NlRouteOrigin = 4 +) + +type MibIPForwardRow struct { + Luid uint64 + Index uint32 + DestinationPrefix AddressPrefix + NextHop RawSockAddrInet + + SitePrefixLength uint8 + ValidLifetime uint32 + PreferredLifetime uint32 + Metric uint32 + Protocol NlRouteProtocol + + Loopback bool + AutoconfigureAddress bool + Publish bool + Immortal bool + + Age uint32 + Origin NlRouteOrigin +} + +type MibIPForwardTable struct { + NumEntries uint32 + Table [1]MibIPForwardRow +} + +var ( + modiphlpapi = syscall.NewLazyDLL("iphlpapi.dll") + + procGetIPInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry") + procSetIPInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry") + procCreateIPForwardEntry = modiphlpapi.NewProc("CreateIpForwardEntry2") + procDeleteIPForwardEntry = modiphlpapi.NewProc("DeleteIpForwardEntry2") + procGetIPForwardTable = modiphlpapi.NewProc("GetIpForwardTable2") + procFreeMibTable = modiphlpapi.NewProc("FreeMibTable") +) + +type NetIOInterface interface { + GetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) + + SetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) + + CreateIPForwardEntry(ipForwardEntry *MibIPForwardRow) (errcode error) + + DeleteIPForwardEntry(ipForwardEntry *MibIPForwardRow) (errcode error) + + ListIPForwardRows(family uint16) ([]MibIPForwardRow, error) +} + +type netIO struct { + syscallN func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) +} + +func NewNetIO() NetIOInterface { + return &netIO{syscallN: syscall.SyscallN} +} + +func (n *netIO) GetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) { + r0, _, _ := n.syscallN(procGetIPInterfaceEntry.Addr(), uintptr(unsafe.Pointer(ipInterfaceRow))) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func (n *netIO) SetIPInterfaceEntry(ipInterfaceRow *MibIPInterfaceRow) (errcode error) { + r0, _, _ := n.syscallN(procSetIPInterfaceEntry.Addr(), uintptr(unsafe.Pointer(ipInterfaceRow))) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func (n *netIO) CreateIPForwardEntry(ipForwardEntry *MibIPForwardRow) (errcode error) { + r0, _, _ := n.syscallN(procCreateIPForwardEntry.Addr(), uintptr(unsafe.Pointer(ipForwardEntry))) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func (n *netIO) DeleteIPForwardEntry(ipForwardEntry *MibIPForwardRow) (errcode error) { + r0, _, _ := n.syscallN(procDeleteIPForwardEntry.Addr(), uintptr(unsafe.Pointer(ipForwardEntry))) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func (n *netIO) freeMibTable(table unsafe.Pointer) { + n.syscallN(procFreeMibTable.Addr(), uintptr(table)) + return +} + +func (n *netIO) getIPForwardTable(family uint16, ipForwardTable **MibIPForwardTable) (errcode error) { + r0, _, _ := n.syscallN(procGetIPForwardTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(ipForwardTable))) + if r0 != 0 { + errcode = syscall.Errno(r0) + } + return +} + +func (n *netIO) ListIPForwardRows(family uint16) ([]MibIPForwardRow, error) { + var table *MibIPForwardTable + err := n.getIPForwardTable(family, &table) + if table != nil { + defer n.freeMibTable(unsafe.Pointer(table)) + } + if err != nil { + return nil, os.NewSyscallError("iphlpapi.GetIpForwardTable", err) + } + return unsafe.Slice(&table.Table[0], table.NumEntries), nil +} + +func NewIPForwardRow() *MibIPForwardRow { + return &MibIPForwardRow{ + SitePrefixLength: 255, + Metric: 0, + Loopback: true, + AutoconfigureAddress: true, + Publish: true, + Immortal: true, + ValidLifetime: 0xffffffff, + PreferredLifetime: 0xffffffff, + Protocol: RouteProtocolOther, + } +} diff --git a/pkg/agent/util/syscall/syscall_windows_test.go b/pkg/agent/util/syscall/syscall_windows_test.go new file mode 100644 index 00000000000..e7f9d36be42 --- /dev/null +++ b/pkg/agent/util/syscall/syscall_windows_test.go @@ -0,0 +1,219 @@ +// Copyright 2023 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syscall + +import ( + "net" + "os" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRawSockAddrTranslation(t *testing.T) { + for _, ipStr := range []string{ + "1.1.1.2", + "abcd:12:03::adb3", + } { + ip := net.ParseIP(ipStr) + sockAddr := NewRawSockAddrInetFromIP(ip) + parsedIP := sockAddr.IP() + assert.True(t, ip.Equal(parsedIP)) + } +} + +func TestAddressPrefixTranslation(t *testing.T) { + for _, ipnet := range []*net.IPNet{ + { + IP: net.ParseIP("1.1.1.0"), + Mask: net.CIDRMask(28, 32), + }, + { + IP: net.ParseIP("1.1.1.2"), + Mask: net.CIDRMask(28, 32), + }, + { + IP: net.ParseIP("abcd:12:03::adb3"), + Mask: net.CIDRMask(96, 128), + }, + { + IP: net.ParseIP("abcd:12:03::"), + Mask: net.CIDRMask(96, 128), + }, + { + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + }, + { + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 128), + }, + } { + sockAddr := NewAddressPrefixFromIPNet(ipnet) + parsedIPNet := sockAddr.IPNet() + assert.True(t, ipnet.IP.Equal(parsedIPNet.IP)) + assert.Equal(t, ipnet.String(), parsedIPNet.String()) + } +} + +func TestRawSockAddrInetBasics(t *testing.T) { + tests := []struct { + name string + testInet *RawSockAddrInet + wantIP net.IP + }{ + { + name: "IPv4", + testInet: NewRawSockAddrInetFromIP(net.IPv4bcast), + wantIP: net.IPv4bcast, + }, + { + name: "IPv6", + testInet: NewRawSockAddrInetFromIP(net.IPv6zero), + wantIP: net.IPv6zero, + }, + { + name: "Unspecified", + testInet: &RawSockAddrInet{Family: AF_UNSPEC}, + wantIP: net.IPv6unspecified, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.wantIP, tc.testInet.IP()) + assert.Equal(t, tc.wantIP.String(), tc.testInet.String()) + }) + } +} + +func TestAddressPrefixBasics(t *testing.T) { + testIPv4Net := &net.IPNet{ + IP: net.ParseIP("1.1.1.0").To4(), + Mask: net.CIDRMask(28, 32), + } + testIpv6Net := &net.IPNet{ + IP: net.ParseIP("abcd:12:03::adb3"), + Mask: net.CIDRMask(96, 128), + } + testDiffNet := &net.IPNet{ + IP: net.ParseIP("1.1.2.0"), + Mask: net.CIDRMask(28, 32), + } + tests := []struct { + name string + testInet *AddressPrefix + wantIPNet *net.IPNet + }{ + { + name: "IPv4", + testInet: NewAddressPrefixFromIPNet(testIPv4Net), + wantIPNet: testIPv4Net, + }, + { + name: "IPv6", + testInet: NewAddressPrefixFromIPNet(testIpv6Net), + wantIPNet: testIpv6Net, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.wantIPNet, tc.testInet.IPNet()) + assert.Equal(t, tc.wantIPNet.String(), tc.testInet.String()) + assert.True(t, tc.testInet.EqualsTo(tc.wantIPNet)) + assert.False(t, tc.testInet.EqualsTo(testDiffNet)) + }) + } + // Test more cases AddressPrefix EqualsTo + testZeroNet := &net.IPNet{ + IP: net.IPv4zero, + Mask: net.CIDRMask(0, 32), + } + assert.True(t, NewAddressPrefixFromIPNet(testZeroNet).EqualsTo(testZeroNet)) +} + +func TestIPInterfaceEntryOperations(t *testing.T) { + tests := []struct { + name string + syscallR1 uintptr + wantErr error + }{ + { + name: "Normal", + syscallR1: 0, + }, + { + name: "Get Err", + syscallR1: 22, + wantErr: syscall.Errno(22), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testNetIO := NewTestNetIO(tc.syscallR1) + gotErr := testNetIO.GetIPInterfaceEntry(&MibIPInterfaceRow{}) + assert.Equal(t, tc.wantErr, gotErr) + gotErr = testNetIO.SetIPInterfaceEntry(&MibIPInterfaceRow{}) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestIPForwardEntryOperations(t *testing.T) { + tests := []struct { + name string + syscallR1 uintptr + wantErr error + }{ + { + name: "Normal", + syscallR1: 0, + }, + { + name: "Get Err", + syscallR1: 22, + wantErr: syscall.Errno(22), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testNetIO := NewTestNetIO(tc.syscallR1) + gotErr := testNetIO.CreateIPForwardEntry(&MibIPForwardRow{}) + assert.Equal(t, tc.wantErr, gotErr) + gotErr = testNetIO.DeleteIPForwardEntry(&MibIPForwardRow{}) + assert.Equal(t, tc.wantErr, gotErr) + }) + } +} + +func TestListIPForwardRows(t *testing.T) { + wantErr := os.NewSyscallError("iphlpapi.GetIpForwardTable", syscall.Errno(22)) + testNetIO := NewTestNetIO(22) + // Skipping no error case because converting uintptr back to Pointer is not valid in general. + gotRow, gotErr := testNetIO.ListIPForwardRows(AF_INET) + assert.Nil(t, gotRow) + assert.Equal(t, wantErr, gotErr) +} + +func NewTestNetIO(wantR1 uintptr) NetIOInterface { + mockSyscallN := func(trap uintptr, args ...uintptr) (r1, r2 uintptr, err syscall.Errno) { + return wantR1, 0, 0 + } + return &netIO{syscallN: mockSyscallN} +} diff --git a/pkg/agent/util/syscall/testing/mock_syscall_windows.go b/pkg/agent/util/syscall/testing/mock_syscall_windows.go new file mode 100644 index 00000000000..b5e35e537dd --- /dev/null +++ b/pkg/agent/util/syscall/testing/mock_syscall_windows.go @@ -0,0 +1,55 @@ +// Copyright 2023 Antrea Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import ( + antreasyscall "antrea.io/antrea/pkg/agent/util/syscall" +) + +// MockNetIO is a custom defined mock struct, not generated by MockGen. +type MockNetIO struct { + GetIPInterfaceEntryErr error + SetIPInterfaceEntryErr error + CreateIPForwardEntryErr error + DeleteIPForwardEntryErr error + ListIPForwardRowsErr error + IPForwardRows []antreasyscall.MibIPForwardRow +} + +func NewMockNetIO(testMibIPForwardRows []antreasyscall.MibIPForwardRow) antreasyscall.NetIOInterface { + return &MockNetIO{ + IPForwardRows: testMibIPForwardRows, + } +} + +func (n *MockNetIO) GetIPInterfaceEntry(_ *antreasyscall.MibIPInterfaceRow) (errcode error) { + return n.GetIPInterfaceEntryErr +} + +func (n *MockNetIO) SetIPInterfaceEntry(_ *antreasyscall.MibIPInterfaceRow) (errcode error) { + return n.SetIPInterfaceEntryErr +} + +func (n *MockNetIO) CreateIPForwardEntry(_ *antreasyscall.MibIPForwardRow) (errcode error) { + return n.CreateIPForwardEntryErr +} + +func (n *MockNetIO) DeleteIPForwardEntry(_ *antreasyscall.MibIPForwardRow) (errcode error) { + return n.DeleteIPForwardEntryErr +} + +func (n *MockNetIO) ListIPForwardRows(_ uint16) ([]antreasyscall.MibIPForwardRow, error) { + return n.IPForwardRows, n.ListIPForwardRowsErr +}