diff --git a/gateway.go b/gateway.go index 2f1021e..41ad765 100644 --- a/gateway.go +++ b/gateway.go @@ -43,6 +43,15 @@ func (e *ErrInvalidRouteFileFormat) Error() string { // DiscoverGateway is the OS independent function to get the default gateway func DiscoverGateway() (ip net.IP, err error) { + ips, err := DiscoverGateways() + if err != nil { + return nil, err + } + return ips[0], nil +} + +// DiscoverGateways is the OS independent function to get all gateways +func DiscoverGateways() (ips []net.IP, err error) { return discoverGatewayOSSpecific() } diff --git a/gatewayForBSDs.go b/gatewayForBSDs.go index e46c9fe..28546da 100644 --- a/gatewayForBSDs.go +++ b/gatewayForBSDs.go @@ -15,7 +15,7 @@ func readNetstat() ([]byte, error) { return routeCmd.CombinedOutput() } -func discoverGatewayOSSpecific() (ip net.IP, err error) { +func discoverGatewayOSSpecific() (ips []net.IP, err error) { rib, err := route.FetchRIB(syscall.AF_INET, syscall.NET_RT_DUMP, 0) if err != nil { return nil, err @@ -26,6 +26,7 @@ func discoverGatewayOSSpecific() (ip net.IP, err error) { return nil, err } + var result []net.IP for _, m := range msgs { switch m := m.(type) { case *route.RouteMessage: @@ -33,15 +34,18 @@ func discoverGatewayOSSpecific() (ip net.IP, err error) { switch sa := m.Addrs[syscall.RTAX_GATEWAY].(type) { case *route.Inet4Addr: ip = net.IPv4(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3]) - return ip, nil + result = append(result, ip) case *route.Inet6Addr: ip = make(net.IP, net.IPv6len) copy(ip, sa.IP[:]) - return ip, nil + result = append(result, ip) } } } - return nil, &ErrNoGateway{} + if len(result) == 0 { + return nil, &ErrNoGateway{} + } + return result, nil } func discoverGatewayInterfaceOSSpecific() (ip net.IP, err error) { diff --git a/gateway_linux.go b/gateway_linux.go index 6bb2adf..3afe3e9 100644 --- a/gateway_linux.go +++ b/gateway_linux.go @@ -30,7 +30,7 @@ func readRoutes() ([]byte, error) { return bytes, nil } -func discoverGatewayOSSpecific() (ip net.IP, err error) { +func discoverGatewayOSSpecific() (ips []net.IP, err error) { bytes, err := readRoutes() if err != nil { return nil, err diff --git a/gateway_parsers.go b/gateway_parsers.go index c36218b..7895e4e 100644 --- a/gateway_parsers.go +++ b/gateway_parsers.go @@ -103,7 +103,7 @@ func flagsContain(flags string, flag ...string) bool { return contain } -func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) { +func parseToWindowsRouteStruct(output []byte) ([]windowsRouteStruct, error) { // Windows route output format is always like this: // =========================================================================== // Interface List @@ -137,7 +137,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) { if sep == 3 { // We just entered the 2nd section containing "Active Routes:" if len(lines) <= idx+2 { - return windowsRouteStruct{}, &ErrNoGateway{} + return nil, &ErrNoGateway{} } inputLine := lines[idx+2] @@ -147,7 +147,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) { } fields := strings.Fields(inputLine) if len(fields) < 5 || !ipRegex.MatchString(fields[0]) { - return windowsRouteStruct{}, &ErrCantParse{} + return nil, &ErrCantParse{} } if fields[0] != "0.0.0.0" { @@ -159,7 +159,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) { metric, err := strconv.Atoi(fields[4]) if err != nil { - return windowsRouteStruct{}, err + return nil, err } defaultRoutes = append(defaultRoutes, gatewayEntry{ @@ -176,25 +176,32 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) { if sep == 0 { // We saw no separator lines, so input must have been garbage. - return windowsRouteStruct{}, &ErrCantParse{} + return nil, &ErrCantParse{} } if len(defaultRoutes) == 0 { - return windowsRouteStruct{}, &ErrNoGateway{} + return nil, &ErrNoGateway{} } - minDefaultRoute := slices.MinFunc(defaultRoutes, + slices.SortFunc(defaultRoutes, func(a, b gatewayEntry) int { return a.metric - b.metric }) - return windowsRouteStruct{ - Gateway: minDefaultRoute.gateway, - Interface: minDefaultRoute.iface, - }, nil + result := make([]windowsRouteStruct, 0, len(defaultRoutes)) + for _, defaultRoute := range defaultRoutes { + result = append(result, windowsRouteStruct{ + Gateway: defaultRoute.gateway, + Interface: defaultRoute.iface, + }) + } + if len(result) == 0 { + return nil, &ErrNoGateway{} + } + return result, nil } -func parseToLinuxRouteStruct(output []byte) (linuxRouteStruct, error) { +func parseToLinuxRouteStruct(output []byte) ([]linuxRouteStruct, error) { // parseLinuxProcNetRoute parses the route file located at /proc/net/route // and returns the IP address of the default gateway. The default gateway // is the one with Destination value of 0.0.0.0. @@ -218,17 +225,18 @@ func parseToLinuxRouteStruct(output []byte) (linuxRouteStruct, error) { if !scanner.Scan() { err := scanner.Err() if err == nil { - return linuxRouteStruct{}, &ErrNoGateway{} + return nil, &ErrNoGateway{} } - return linuxRouteStruct{}, err + return nil, err } + var result []linuxRouteStruct for scanner.Scan() { row := scanner.Text() tokens := strings.Split(row, sep) if len(tokens) < 11 { - return linuxRouteStruct{}, &ErrInvalidRouteFileFormat{row: row} + return nil, &ErrInvalidRouteFileFormat{row: row} } // The default interface is the one that's 0 for both destination and mask. @@ -236,59 +244,74 @@ func parseToLinuxRouteStruct(output []byte) (linuxRouteStruct, error) { continue } - return linuxRouteStruct{ + result = append(result, linuxRouteStruct{ Iface: tokens[0], Gateway: tokens[2], - }, nil + }) + } + if len(result) == 0 { + return nil, &ErrNoGateway{} } - return linuxRouteStruct{}, &ErrNoGateway{} + return result, nil } -func parseWindowsGatewayIP(output []byte) (net.IP, error) { - parsedOutput, err := parseToWindowsRouteStruct(output) +func parseWindowsGatewayIP(output []byte) ([]net.IP, error) { + parsedOutputs, err := parseToWindowsRouteStruct(output) if err != nil { return nil, err } - ip := net.ParseIP(parsedOutput.Gateway) - if ip == nil { - return nil, &ErrCantParse{} + result := make([]net.IP, 0, len(parsedOutputs)) + for _, parsedOutput := range parsedOutputs { + ip := net.ParseIP(parsedOutput.Gateway) + if ip == nil { + return nil, &ErrCantParse{} + } + result = append(result, ip) } - return ip, nil + return result, nil } -func parseWindowsInterfaceIP(output []byte) (net.IP, error) { - parsedOutput, err := parseToWindowsRouteStruct(output) +func parseWindowsInterfaceIP(output []byte) ([]net.IP, error) { + parsedOutputs, err := parseToWindowsRouteStruct(output) if err != nil { return nil, err } - ip := net.ParseIP(parsedOutput.Interface) - if ip == nil { - return nil, &ErrCantParse{} + result := make([]net.IP, 0, len(parsedOutputs)) + for _, parsedOutput := range parsedOutputs { + ip := net.ParseIP(parsedOutput.Interface) + if ip == nil { + return nil, &ErrCantParse{} + } + result = append(result, ip) } - return ip, nil + return result, nil } -func parseLinuxGatewayIP(output []byte) (net.IP, error) { - parsedStruct, err := parseToLinuxRouteStruct(output) +func parseLinuxGatewayIP(output []byte) ([]net.IP, error) { + parsedStructs, err := parseToLinuxRouteStruct(output) if err != nil { return nil, err } - // cast hex address to uint32 - d, err := strconv.ParseUint(parsedStruct.Gateway, 16, 32) - if err != nil { - return nil, fmt.Errorf( - "parsing default interface address field hex %q: %w", - parsedStruct.Gateway, - err, - ) + result := make([]net.IP, 0, len(parsedStructs)) + for _, parsedStruct := range parsedStructs { + // cast hex address to uint32 + d, err := strconv.ParseUint(parsedStruct.Gateway, 16, 32) + if err != nil { + return nil, fmt.Errorf( + "parsing default interface address field hex %q: %w", + parsedStruct.Gateway, + err, + ) + } + // make net.IP address from uint32 + ipd32 := make(net.IP, 4) + binary.LittleEndian.PutUint32(ipd32, uint32(d)) + result = append(result, ipd32) } - // make net.IP address from uint32 - ipd32 := make(net.IP, 4) - binary.LittleEndian.PutUint32(ipd32, uint32(d)) - return ipd32, nil + return result, nil } func parseLinuxInterfaceIP(output []byte) (net.IP, error) { @@ -298,12 +321,12 @@ func parseLinuxInterfaceIP(output []byte) (net.IP, error) { func parseLinuxInterfaceIPImpl(output []byte, ifaceGetter interfaceGetter) (net.IP, error) { // Mockable implemenation - parsedStruct, err := parseToLinuxRouteStruct(output) + parsedStructs, err := parseToLinuxRouteStruct(output) if err != nil { return nil, err } - return getInterfaceIP4(parsedStruct.Iface, ifaceGetter) + return getInterfaceIP4(parsedStructs[0].Iface, ifaceGetter) } func parseUnixInterfaceIP(output []byte) (net.IP, error) { @@ -313,12 +336,12 @@ func parseUnixInterfaceIP(output []byte) (net.IP, error) { func parseUnixInterfaceIPImpl(output []byte, ifaceGetter interfaceGetter) (net.IP, error) { // Mockable implemenation - parsedStruct, err := parseNetstatToRouteStruct(output) + parsedStructs, err := parseNetstatToRouteStruct(output) if err != nil { return nil, err } - return getInterfaceIP4(parsedStruct.Iface, ifaceGetter) + return getInterfaceIP4(parsedStructs[0].Iface, ifaceGetter) } func getInterfaceIP4(name string, ifaceGetter interfaceGetter) (net.IP, error) { @@ -350,32 +373,36 @@ func getInterfaceIP4(name string, ifaceGetter interfaceGetter) (net.IP, error) { name) } -func parseUnixGatewayIP(output []byte) (net.IP, error) { +func parseUnixGatewayIP(output []byte) ([]net.IP, error) { // Extract default gateway IP from netstat route table - parsedStruct, err := parseNetstatToRouteStruct(output) + parsedStructs, err := parseNetstatToRouteStruct(output) if err != nil { return nil, err } - ip := net.ParseIP(parsedStruct.Gateway) - if ip == nil { - return nil, &ErrCantParse{} + result := make([]net.IP, 0, len(parsedStructs)) + for _, parsedStruct := range parsedStructs { + ip := net.ParseIP(parsedStruct.Gateway) + if ip == nil { + return nil, &ErrCantParse{} + } + result = append(result, ip) } - - return ip, nil + return result, nil } // Parse any netstat -rn output -func parseNetstatToRouteStruct(output []byte) (unixRouteStruct, error) { +func parseNetstatToRouteStruct(output []byte) ([]unixRouteStruct, error) { startLine, nsFields := discoverFields(output) if startLine == -1 { // Unable to find required column headers in netstat output - return unixRouteStruct{}, &ErrCantParse{} + return nil, &ErrCantParse{} } outputLines := strings.Split(string(output), "\n") + var result []unixRouteStruct for lineNo, line := range outputLines { if lineNo <= startLine || strings.Contains(line, "-----") { // Skip until past column headers and heading underlines (solaris) @@ -394,12 +421,14 @@ func parseNetstatToRouteStruct(output []byte) (unixRouteStruct, error) { if ifaceIdx := nsFields[ns_netif]; ifaceIdx < len(fields) { iface = fields[nsFields[ns_netif]] } - return unixRouteStruct{ + result = append(result, unixRouteStruct{ Iface: iface, Gateway: fields[nsFields[ns_gateway]], - }, nil + }) } } - - return unixRouteStruct{}, &ErrNoGateway{} + if len(result) == 0 { + return nil, &ErrNoGateway{} + } + return result, nil } diff --git a/gateway_solaris.go b/gateway_solaris.go index 9a653dc..19981f9 100644 --- a/gateway_solaris.go +++ b/gateway_solaris.go @@ -13,7 +13,7 @@ func readNetstat() ([]byte, error) { return routeCmd.CombinedOutput() } -func discoverGatewayOSSpecific() (ip net.IP, err error) { +func discoverGatewayOSSpecific() (ips []net.IP, err error) { bytes, err := readNetstat() if err != nil { return nil, err diff --git a/gateway_test.go b/gateway_test.go index d2759de..b74fc20 100644 --- a/gateway_test.go +++ b/gateway_test.go @@ -161,14 +161,15 @@ func TestParseUnix(t *testing.T) { }) } -func testGatewayAddress(t *testing.T, testcases []ipTestCase, fn func([]byte) (net.IP, error)) { +func testGatewayAddress(t *testing.T, testcases []ipTestCase, fn func([]byte) ([]net.IP, error)) { for i, tc := range testcases { t.Run(tc.tableName, func(t *testing.T) { - net, err := fn(routeTables[tc.tableName]) + nets, err := fn(routeTables[tc.tableName]) if tc.ok { if err != nil { t.Errorf("Unexpected error in test #%d: %v", i, err) } + net := nets[0] if net.String() != tc.ifaceIP { t.Errorf("Unexpected gateway address %v != %s", net, tc.ifaceIP) } diff --git a/gateway_unimplemented.go b/gateway_unimplemented.go index a63d942..e1b9144 100644 --- a/gateway_unimplemented.go +++ b/gateway_unimplemented.go @@ -7,8 +7,8 @@ import ( "net" ) -func discoverGatewayOSSpecific() (ip net.IP, err error) { - return ip, &ErrNotImplemented{} +func discoverGatewayOSSpecific() (ips []net.IP, err error) { + return nil, &ErrNotImplemented{} } func discoverGatewayInterfaceOSSpecific() (ip net.IP, err error) { diff --git a/gateway_windows.go b/gateway_windows.go index ab236b7..2cce637 100644 --- a/gateway_windows.go +++ b/gateway_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package gateway @@ -8,7 +9,7 @@ import ( "syscall" ) -func discoverGatewayOSSpecific() (ip net.IP, err error) { +func discoverGatewayOSSpecific() (ips []net.IP, err error) { routeCmd := exec.Command("route", "print", "0.0.0.0") routeCmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} output, err := routeCmd.CombinedOutput() @@ -27,5 +28,9 @@ func discoverGatewayInterfaceOSSpecific() (ip net.IP, err error) { return nil, err } - return parseWindowsInterfaceIP(output) + ips, err := parseWindowsInterfaceIP(output) + if err != nil { + return nil, err + } + return ips[0], nil }