Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return all gateways #43

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ func (e *ErrInvalidRouteFileFormat) Error() string {

// DiscoverGateway is the OS independent function to get the default gateway
func DiscoverGateway() (ip net.IP, err error) {
ips, err := DiscoverGateways()
if err != nil {
return nil, err
}
return ips[0], nil
}

// DiscoverGateways is the OS independent function to get all gateways
func DiscoverGateways() (ips []net.IP, err error) {
return discoverGatewayOSSpecific()
}

Expand Down
12 changes: 8 additions & 4 deletions gatewayForBSDs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func readNetstat() ([]byte, error) {
return routeCmd.CombinedOutput()
}

func discoverGatewayOSSpecific() (ip net.IP, err error) {
func discoverGatewayOSSpecific() (ips []net.IP, err error) {
rib, err := route.FetchRIB(syscall.AF_INET, syscall.NET_RT_DUMP, 0)
if err != nil {
return nil, err
Expand All @@ -26,22 +26,26 @@ func discoverGatewayOSSpecific() (ip net.IP, err error) {
return nil, err
}

var result []net.IP
for _, m := range msgs {
switch m := m.(type) {
case *route.RouteMessage:
var ip net.IP
switch sa := m.Addrs[syscall.RTAX_GATEWAY].(type) {
case *route.Inet4Addr:
ip = net.IPv4(sa.IP[0], sa.IP[1], sa.IP[2], sa.IP[3])
return ip, nil
result = append(result, ip)
case *route.Inet6Addr:
ip = make(net.IP, net.IPv6len)
copy(ip, sa.IP[:])
return ip, nil
result = append(result, ip)
}
}
}
return nil, &ErrNoGateway{}
if len(result) == 0 {
return nil, &ErrNoGateway{}
}
return result, nil
}

func discoverGatewayInterfaceOSSpecific() (ip net.IP, err error) {
Expand Down
2 changes: 1 addition & 1 deletion gateway_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func readRoutes() ([]byte, error) {
return bytes, nil
}

func discoverGatewayOSSpecific() (ip net.IP, err error) {
func discoverGatewayOSSpecific() (ips []net.IP, err error) {
bytes, err := readRoutes()
if err != nil {
return nil, err
Expand Down
151 changes: 90 additions & 61 deletions gateway_parsers.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func flagsContain(flags string, flag ...string) bool {
return contain
}

func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) {
func parseToWindowsRouteStruct(output []byte) ([]windowsRouteStruct, error) {
// Windows route output format is always like this:
// ===========================================================================
// Interface List
Expand Down Expand Up @@ -137,7 +137,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) {
if sep == 3 {
// We just entered the 2nd section containing "Active Routes:"
if len(lines) <= idx+2 {
return windowsRouteStruct{}, &ErrNoGateway{}
return nil, &ErrNoGateway{}
}

inputLine := lines[idx+2]
Expand All @@ -147,7 +147,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) {
}
fields := strings.Fields(inputLine)
if len(fields) < 5 || !ipRegex.MatchString(fields[0]) {
return windowsRouteStruct{}, &ErrCantParse{}
return nil, &ErrCantParse{}
}

if fields[0] != "0.0.0.0" {
Expand All @@ -159,7 +159,7 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) {
metric, err := strconv.Atoi(fields[4])

if err != nil {
return windowsRouteStruct{}, err
return nil, err
}

defaultRoutes = append(defaultRoutes, gatewayEntry{
Expand All @@ -176,25 +176,32 @@ func parseToWindowsRouteStruct(output []byte) (windowsRouteStruct, error) {

if sep == 0 {
// We saw no separator lines, so input must have been garbage.
return windowsRouteStruct{}, &ErrCantParse{}
return nil, &ErrCantParse{}
}

if len(defaultRoutes) == 0 {
return windowsRouteStruct{}, &ErrNoGateway{}
return nil, &ErrNoGateway{}
}

minDefaultRoute := slices.MinFunc(defaultRoutes,
slices.SortFunc(defaultRoutes,
func(a, b gatewayEntry) int {
return a.metric - b.metric
})

return windowsRouteStruct{
Gateway: minDefaultRoute.gateway,
Interface: minDefaultRoute.iface,
}, nil
result := make([]windowsRouteStruct, 0, len(defaultRoutes))
for _, defaultRoute := range defaultRoutes {
result = append(result, windowsRouteStruct{
Gateway: defaultRoute.gateway,
Interface: defaultRoute.iface,
})
}
if len(result) == 0 {
return nil, &ErrNoGateway{}
}
return result, nil
}

func parseToLinuxRouteStruct(output []byte) (linuxRouteStruct, error) {
func parseToLinuxRouteStruct(output []byte) ([]linuxRouteStruct, error) {
// parseLinuxProcNetRoute parses the route file located at /proc/net/route
// and returns the IP address of the default gateway. The default gateway
// is the one with Destination value of 0.0.0.0.
Expand All @@ -218,77 +225,93 @@ func parseToLinuxRouteStruct(output []byte) (linuxRouteStruct, error) {
if !scanner.Scan() {
err := scanner.Err()
if err == nil {
return linuxRouteStruct{}, &ErrNoGateway{}
return nil, &ErrNoGateway{}
}

return linuxRouteStruct{}, err
return nil, err
}

var result []linuxRouteStruct
for scanner.Scan() {
row := scanner.Text()
tokens := strings.Split(row, sep)
if len(tokens) < 11 {
return linuxRouteStruct{}, &ErrInvalidRouteFileFormat{row: row}
return nil, &ErrInvalidRouteFileFormat{row: row}
}

// The default interface is the one that's 0 for both destination and mask.
if !(tokens[destinationField] == "00000000" && tokens[maskField] == "00000000") {
continue
}

return linuxRouteStruct{
result = append(result, linuxRouteStruct{
Iface: tokens[0],
Gateway: tokens[2],
}, nil
})
}
if len(result) == 0 {
return nil, &ErrNoGateway{}
}
return linuxRouteStruct{}, &ErrNoGateway{}
return result, nil
}

func parseWindowsGatewayIP(output []byte) (net.IP, error) {
parsedOutput, err := parseToWindowsRouteStruct(output)
func parseWindowsGatewayIP(output []byte) ([]net.IP, error) {
parsedOutputs, err := parseToWindowsRouteStruct(output)
if err != nil {
return nil, err
}

ip := net.ParseIP(parsedOutput.Gateway)
if ip == nil {
return nil, &ErrCantParse{}
result := make([]net.IP, 0, len(parsedOutputs))
for _, parsedOutput := range parsedOutputs {
ip := net.ParseIP(parsedOutput.Gateway)
if ip == nil {
return nil, &ErrCantParse{}
}
result = append(result, ip)
}
return ip, nil
return result, nil
}

func parseWindowsInterfaceIP(output []byte) (net.IP, error) {
parsedOutput, err := parseToWindowsRouteStruct(output)
func parseWindowsInterfaceIP(output []byte) ([]net.IP, error) {
parsedOutputs, err := parseToWindowsRouteStruct(output)
if err != nil {
return nil, err
}

ip := net.ParseIP(parsedOutput.Interface)
if ip == nil {
return nil, &ErrCantParse{}
result := make([]net.IP, 0, len(parsedOutputs))
for _, parsedOutput := range parsedOutputs {
ip := net.ParseIP(parsedOutput.Interface)
if ip == nil {
return nil, &ErrCantParse{}
}
result = append(result, ip)
}
return ip, nil
return result, nil
}

func parseLinuxGatewayIP(output []byte) (net.IP, error) {
parsedStruct, err := parseToLinuxRouteStruct(output)
func parseLinuxGatewayIP(output []byte) ([]net.IP, error) {
parsedStructs, err := parseToLinuxRouteStruct(output)
if err != nil {
return nil, err
}

// cast hex address to uint32
d, err := strconv.ParseUint(parsedStruct.Gateway, 16, 32)
if err != nil {
return nil, fmt.Errorf(
"parsing default interface address field hex %q: %w",
parsedStruct.Gateway,
err,
)
result := make([]net.IP, 0, len(parsedStructs))
for _, parsedStruct := range parsedStructs {
// cast hex address to uint32
d, err := strconv.ParseUint(parsedStruct.Gateway, 16, 32)
if err != nil {
return nil, fmt.Errorf(
"parsing default interface address field hex %q: %w",
parsedStruct.Gateway,
err,
)
}
// make net.IP address from uint32
ipd32 := make(net.IP, 4)
binary.LittleEndian.PutUint32(ipd32, uint32(d))
result = append(result, ipd32)
}
// make net.IP address from uint32
ipd32 := make(net.IP, 4)
binary.LittleEndian.PutUint32(ipd32, uint32(d))
return ipd32, nil
return result, nil
}

func parseLinuxInterfaceIP(output []byte) (net.IP, error) {
Expand All @@ -298,12 +321,12 @@ func parseLinuxInterfaceIP(output []byte) (net.IP, error) {

func parseLinuxInterfaceIPImpl(output []byte, ifaceGetter interfaceGetter) (net.IP, error) {
// Mockable implemenation
parsedStruct, err := parseToLinuxRouteStruct(output)
parsedStructs, err := parseToLinuxRouteStruct(output)
if err != nil {
return nil, err
}

return getInterfaceIP4(parsedStruct.Iface, ifaceGetter)
return getInterfaceIP4(parsedStructs[0].Iface, ifaceGetter)
}

func parseUnixInterfaceIP(output []byte) (net.IP, error) {
Expand All @@ -313,12 +336,12 @@ func parseUnixInterfaceIP(output []byte) (net.IP, error) {

func parseUnixInterfaceIPImpl(output []byte, ifaceGetter interfaceGetter) (net.IP, error) {
// Mockable implemenation
parsedStruct, err := parseNetstatToRouteStruct(output)
parsedStructs, err := parseNetstatToRouteStruct(output)
if err != nil {
return nil, err
}

return getInterfaceIP4(parsedStruct.Iface, ifaceGetter)
return getInterfaceIP4(parsedStructs[0].Iface, ifaceGetter)
}

func getInterfaceIP4(name string, ifaceGetter interfaceGetter) (net.IP, error) {
Expand Down Expand Up @@ -350,32 +373,36 @@ func getInterfaceIP4(name string, ifaceGetter interfaceGetter) (net.IP, error) {
name)
}

func parseUnixGatewayIP(output []byte) (net.IP, error) {
func parseUnixGatewayIP(output []byte) ([]net.IP, error) {
// Extract default gateway IP from netstat route table
parsedStruct, err := parseNetstatToRouteStruct(output)
parsedStructs, err := parseNetstatToRouteStruct(output)
if err != nil {
return nil, err
}

ip := net.ParseIP(parsedStruct.Gateway)
if ip == nil {
return nil, &ErrCantParse{}
result := make([]net.IP, 0, len(parsedStructs))
for _, parsedStruct := range parsedStructs {
ip := net.ParseIP(parsedStruct.Gateway)
if ip == nil {
return nil, &ErrCantParse{}
}
result = append(result, ip)
}

return ip, nil
return result, nil
}

// Parse any netstat -rn output
func parseNetstatToRouteStruct(output []byte) (unixRouteStruct, error) {
func parseNetstatToRouteStruct(output []byte) ([]unixRouteStruct, error) {
startLine, nsFields := discoverFields(output)

if startLine == -1 {
// Unable to find required column headers in netstat output
return unixRouteStruct{}, &ErrCantParse{}
return nil, &ErrCantParse{}
}

outputLines := strings.Split(string(output), "\n")

var result []unixRouteStruct
for lineNo, line := range outputLines {
if lineNo <= startLine || strings.Contains(line, "-----") {
// Skip until past column headers and heading underlines (solaris)
Expand All @@ -394,12 +421,14 @@ func parseNetstatToRouteStruct(output []byte) (unixRouteStruct, error) {
if ifaceIdx := nsFields[ns_netif]; ifaceIdx < len(fields) {
iface = fields[nsFields[ns_netif]]
}
return unixRouteStruct{
result = append(result, unixRouteStruct{
Iface: iface,
Gateway: fields[nsFields[ns_gateway]],
}, nil
})
}
}

return unixRouteStruct{}, &ErrNoGateway{}
if len(result) == 0 {
return nil, &ErrNoGateway{}
}
return result, nil
}
2 changes: 1 addition & 1 deletion gateway_solaris.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func readNetstat() ([]byte, error) {
return routeCmd.CombinedOutput()
}

func discoverGatewayOSSpecific() (ip net.IP, err error) {
func discoverGatewayOSSpecific() (ips []net.IP, err error) {
bytes, err := readNetstat()
if err != nil {
return nil, err
Expand Down
5 changes: 3 additions & 2 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,15 @@ func TestParseUnix(t *testing.T) {
})
}

func testGatewayAddress(t *testing.T, testcases []ipTestCase, fn func([]byte) (net.IP, error)) {
func testGatewayAddress(t *testing.T, testcases []ipTestCase, fn func([]byte) ([]net.IP, error)) {
for i, tc := range testcases {
t.Run(tc.tableName, func(t *testing.T) {
net, err := fn(routeTables[tc.tableName])
nets, err := fn(routeTables[tc.tableName])
if tc.ok {
if err != nil {
t.Errorf("Unexpected error in test #%d: %v", i, err)
}
net := nets[0]
if net.String() != tc.ifaceIP {
t.Errorf("Unexpected gateway address %v != %s", net, tc.ifaceIP)
}
Expand Down
4 changes: 2 additions & 2 deletions gateway_unimplemented.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"net"
)

func discoverGatewayOSSpecific() (ip net.IP, err error) {
return ip, &ErrNotImplemented{}
func discoverGatewayOSSpecific() (ips []net.IP, err error) {
return nil, &ErrNotImplemented{}
}

func discoverGatewayInterfaceOSSpecific() (ip net.IP, err error) {
Expand Down
Loading