Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util/addr: Fixes findIP to return the correct public IP #2673

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions util/addr/addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func IsLocal(addr string) bool {
}

// Extract returns a valid IP address. If the address provided is a valid
// address, it will be returned directly. Otherwise the available interfaces
// be itterated over to find an IP address, prefferably private.
// address, it will be returned directly. Otherwise, the available interfaces
// will be iterated over to find an IP address, preferably private.
func Extract(addr string) (string, error) {
// if addr is already specified then it's directly returned
if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") {
Expand Down Expand Up @@ -115,10 +115,12 @@ func IPs() []string {
return ipAddrs
}

// findIP will return the first private IP available in the list,
// if no private IP is available it will return a public IP if present.
// findIP will return the first private IP available in the list.
// If no private IP is available it will return the first public IP, if present.
// If no public IP is available, it will return the first loopback IP, if present.
func findIP(addresses []net.Addr) (net.IP, error) {
var publicIP net.IP
var localIP net.IP

for _, rawAddr := range addresses {
var ip net.IP
Expand All @@ -131,8 +133,17 @@ func findIP(addresses []net.Addr) (net.IP, error) {
continue
}

if ip.IsLoopback() {
if localIP == nil {
localIP = ip
}
continue
}

if !ip.IsPrivate() {
publicIP = ip
if publicIP == nil {
publicIP = ip
}
continue
}

Expand All @@ -145,5 +156,10 @@ func findIP(addresses []net.Addr) (net.IP, error) {
return publicIP, nil
}

// Return local IP
if len(localIP) > 0 {
return localIP, nil
}

return nil, ErrIPNotFound
}
76 changes: 76 additions & 0 deletions util/addr/addr_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package addr

import (
"github.com/stretchr/testify/assert"
"net"
"testing"
)
Expand Down Expand Up @@ -54,3 +55,78 @@ func TestExtractor(t *testing.T) {
}
}
}

func TestFindIP(t *testing.T) {
localhost, _ := net.ResolveIPAddr("ip", "127.0.0.1")
localhostIPv6, _ := net.ResolveIPAddr("ip", "::1")
privateIP, _ := net.ResolveIPAddr("ip", "10.0.0.1")
publicIP, _ := net.ResolveIPAddr("ip", "100.0.0.1")
publicIPv6, _ := net.ResolveIPAddr("ip", "2001:0db8:85a3:0000:0000:8a2e:0370:7334")

testCases := []struct {
addrs []net.Addr
ip net.IP
errMsg string
}{
{
addrs: []net.Addr{},
ip: nil,
errMsg: ErrIPNotFound.Error(),
},
{
addrs: []net.Addr{localhost},
ip: localhost.IP,
},
{
addrs: []net.Addr{localhost, localhostIPv6},
ip: localhost.IP,
},
{
addrs: []net.Addr{localhostIPv6},
ip: localhostIPv6.IP,
},
{
addrs: []net.Addr{privateIP, localhost},
ip: privateIP.IP,
},
{
addrs: []net.Addr{privateIP, publicIP, localhost},
ip: privateIP.IP,
},
{
addrs: []net.Addr{publicIP, privateIP, localhost},
ip: privateIP.IP,
},
{
addrs: []net.Addr{publicIP, localhost},
ip: publicIP.IP,
},
{
addrs: []net.Addr{publicIP, localhostIPv6},
ip: publicIP.IP,
},
{
addrs: []net.Addr{localhostIPv6, publicIP},
ip: publicIP.IP,
},
{
addrs: []net.Addr{localhostIPv6, publicIPv6, publicIP},
ip: publicIPv6.IP,
},
{
addrs: []net.Addr{publicIP, publicIPv6},
ip: publicIP.IP,
},
}

for _, tc := range testCases {
ip, err := findIP(tc.addrs)
if tc.errMsg == "" {
assert.Nil(t, err)
assert.Equal(t, tc.ip.String(), ip.String())
} else {
assert.NotNil(t, err)
assert.Equal(t, tc.errMsg, err.Error())
}
}
}
Loading