diff --git a/ipvs_test.go b/ipvs_test.go index 4054f6d..b0804cf 100644 --- a/ipvs_test.go +++ b/ipvs_test.go @@ -66,7 +66,9 @@ func checkDestination(t *testing.T, i *Handle, s *Service, d *Destination, check assert.NilError(t, err) for _, dst := range dstArray { - if dst.Address.Equal(d.Address) && dst.Port == d.Port && lookupFwMethod(dst.ConnectionFlags) == lookupFwMethod(d.ConnectionFlags) { + if dst.Address.Equal(d.Address) && dst.Port == d.Port && + lookupFwMethod(dst.ConnectionFlags) == lookupFwMethod(d.ConnectionFlags) && + dst.AddressFamily == d.AddressFamily { dstFound = true break } diff --git a/netlink.go b/netlink.go index 8a98ebb..ac48d05 100644 --- a/netlink.go +++ b/netlink.go @@ -5,6 +5,7 @@ package ipvs import ( "bytes" "encoding/binary" + "errors" "fmt" "net" "os/exec" @@ -351,17 +352,6 @@ func assembleService(attrs []syscall.NetlinkRouteAttr) (*Service, error) { } - // in older kernels (< 3.18), the svc address family attribute may not exist so we must - // assume it based on the svc address provided. - if s.AddressFamily == 0 { - addr := (net.IP)(addressBytes) - if addr.To4() != nil { - s.AddressFamily = syscall.AF_INET - } else { - s.AddressFamily = syscall.AF_INET6 - } - } - // parse Address after parse AddressFamily incase of parseIP error if addressBytes != nil { ip, err := parseIP(addressBytes, s.AddressFamily) @@ -472,12 +462,14 @@ func assembleDestination(attrs []syscall.NetlinkRouteAttr) (*Destination, error) // in older kernels (< 3.18), the destination address family attribute doesn't exist so we must // assume it based on the destination address provided. if d.AddressFamily == 0 { - addr := (net.IP)(addressBytes) - if addr.To4() != nil { - d.AddressFamily = syscall.AF_INET - } else { - d.AddressFamily = syscall.AF_INET6 + // we can't check the address family using net stdlib because netlink returns + // IPv4 addresses as the first 4 bytes in a []byte of length 16 where as + // stdlib expects it as the last 4 bytes. + addressFamily, err := getIPFamily(addressBytes) + if err != nil { + return nil, err } + d.AddressFamily = addressFamily } // parse Address after parse AddressFamily incase of parseIP error @@ -492,6 +484,37 @@ func assembleDestination(attrs []syscall.NetlinkRouteAttr) (*Destination, error) return &d, nil } +// getIPFamily parses the IP family based on raw data from netlink. +// For AF_INET, netlink will set the first 4 bytes with trailing zeros +// 10.0.0.1 -> [10 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0] +// For AF_INET6, the full 16 byte array is used: +// 2001:db8:3c4d:15::1a00 -> [32 1 13 184 60 77 0 21 0 0 0 0 0 0 26 0] +func getIPFamily(address []byte) (uint16, error) { + if len(address) == 4 { + return syscall.AF_INET, nil + } + + if isZeros(address) { + return 0, errors.New("could not parse IP family from address data") + } + + // assume IPv4 if first 4 bytes are non-zero but rest of the data is trailing zeros + if !isZeros(address[:4]) && isZeros(address[4:]) { + return syscall.AF_INET, nil + } + + return syscall.AF_INET6, nil +} + +func isZeros(b []byte) bool { + for i := 0; i < len(b); i++ { + if b[i] != 0 { + return false + } + } + return true +} + // parseDestination given a ipvs netlink response this function will respond with a valid destination entry, an error otherwise func (i *Handle) parseDestination(msg []byte) (*Destination, error) { var dst *Destination diff --git a/netlink_test.go b/netlink_test.go new file mode 100644 index 0000000..443f274 --- /dev/null +++ b/netlink_test.go @@ -0,0 +1,55 @@ +// +build linux + +package ipvs + +import ( + "errors" + "reflect" + "syscall" + "testing" +) + +func Test_getIPFamily(t *testing.T) { + testcases := []struct { + name string + address []byte + expectedFamily uint16 + expectedErr error + }{ + { + name: "16 byte IPv4 10.0.0.1", + address: []byte{10, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + expectedFamily: syscall.AF_INET, + expectedErr: nil, + }, + { + name: "16 byte IPv6 2001:db8:3c4d:15::1a00", + address: []byte{32, 1, 13, 184, 60, 77, 0, 21, 0, 0, 0, 0, 0, 0, 26, 0}, + expectedFamily: syscall.AF_INET6, + expectedErr: nil, + }, + { + name: "zero address", + address: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + expectedFamily: 0, + expectedErr: errors.New("could not parse IP family from address data"), + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + family, err := getIPFamily(testcase.address) + if !reflect.DeepEqual(err, testcase.expectedErr) { + t.Logf("got err: %v", err) + t.Logf("expected err: %v", testcase.expectedErr) + t.Errorf("unexpected error") + } + + if family != testcase.expectedFamily { + t.Logf("got IP family: %v", family) + t.Logf("expected IP family: %v", testcase.expectedFamily) + t.Errorf("unexpected IP family") + } + }) + } +}