Skip to content

Commit

Permalink
Merge pull request #3523 from telepresenceio/thallgren/routing-prio
Browse files Browse the repository at this point in the history
Ensure that a smaller allow-proxy isn't dropped by larger subnet.
  • Loading branch information
thallgren authored Feb 19, 2024
2 parents 0347431 + ae1767d commit a0b95a8
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 127 deletions.
58 changes: 53 additions & 5 deletions integration_test/kubeconfig_extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/telepresenceio/telepresence/v2/pkg/filelocation"
"github.com/telepresenceio/telepresence/v2/pkg/iputil"
"github.com/telepresenceio/telepresence/v2/pkg/routing"
"github.com/telepresenceio/telepresence/v2/pkg/slice"
)

func getClusterIPs(cluster *api.Cluster) ([]net.IP, error) {
Expand Down Expand Up @@ -161,8 +162,29 @@ func (s *notConnectedSuite) Test_NeverProxy() {
func (s *notConnectedSuite) Test_ConflictingProxies() {
ctx := s.Context()

s.TelepresenceConnect(ctx)
st := itest.TelepresenceStatusOk(ctx)
itest.TelepresenceQuitOk(ctx)
rq := s.Require()
rq.True(len(st.RootDaemon.Subnets) > 0)
svcCIDR := st.RootDaemon.Subnets[0]
ones, bits := svcCIDR.Mask.Size()
if ones != 16 || bits != 32 {
s.T().Skip("test requires an IPv4 service subnet with a 16 bit mask")
}

base := svcCIDR.IP.Mask(svcCIDR.Mask)
largeCIDR := &net.IPNet{
IP: base,
Mask: net.CIDRMask(24, 32),
}
smallCIDR := &net.IPNet{
IP: base,
Mask: net.CIDRMask(28, 32),
}
// testIP is an IP that is covered by smallCIDR
testIP := &net.IPNet{
IP: net.ParseIP("10.128.0.32"),
IP: net.IP{base[0], base[1], 0, 4},
Mask: net.CIDRMask(32, 32),
}
// We don't really care if we can't route this with TP disconnected provided the result is the same once we connect
Expand All @@ -173,13 +195,13 @@ func (s *notConnectedSuite) Test_ConflictingProxies() {
expectEq bool
}{
"Never Proxy wins": {
alsoProxy: []string{"10.128.0.0/16"},
neverProxy: []string{"10.128.0.0/24"},
alsoProxy: []string{largeCIDR.String()},
neverProxy: []string{smallCIDR.String()},
expectEq: true,
},
"Also Proxy wins": {
alsoProxy: []string{"10.128.0.0/24"},
neverProxy: []string{"10.128.0.0/16"},
alsoProxy: []string{smallCIDR.String()},
neverProxy: []string{largeCIDR.String()},
expectEq: false,
},
} {
Expand Down Expand Up @@ -215,6 +237,32 @@ func (s *notConnectedSuite) Test_ConflictingProxies() {
}
}

func (s *notConnectedSuite) Test_AlsoNeverProxyDocker() {
if s.IsCI() && !(runtime.GOOS == "linux" && runtime.GOARCH == "amd64") {
s.T().Skip("CI can't run linux docker containers inside non-linux runners")
}
alsoProxy := []string{"10.128.0.0/16"}
neverProxy := []string{"10.128.0.0/24"}
ctx := itest.WithKubeConfigExtension(s.Context(), func(cluster *api.Cluster) map[string]any {
return map[string]any{
"never-proxy": neverProxy,
"also-proxy": alsoProxy,
}
})
cidrsToStrings := func(cidrs []*iputil.Subnet) []string {
ss := make([]string, len(cidrs))
for i, cidr := range cidrs {
ss[i] = cidr.String()
}
return ss
}
s.TelepresenceConnect(ctx, "--context", "extra", "--docker")
defer itest.TelepresenceQuitOk(ctx)
st := itest.TelepresenceStatusOk(ctx)
s.True(slice.ContainsAll(cidrsToStrings(st.ContainerizedDaemon.AlsoProxy), alsoProxy))
s.True(slice.ContainsAll(cidrsToStrings(st.ContainerizedDaemon.NeverProxy), neverProxy))
}

func (s *notConnectedSuite) Test_DNSSuffixRules() {
if s.IsCI() && runtime.GOOS == "linux" && runtime.GOARCH == "arm64" {
s.T().Skip("The DNS on the linux-arm64 GitHub runner is not configured correctly")
Expand Down
58 changes: 39 additions & 19 deletions pkg/client/rootd/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -767,11 +767,28 @@ func (s *Session) onClusterInfo(ctx context.Context, mgrInfo *manager.ClusterInf
)
}

subnets = subnet.Unique(subnets)
dontProxy := slices.Clone(s.neverProxySubnets)
last := len(dontProxy) - 1
proxy, neverProxy, neverProxyOverrides := computeNeverProxyOverrides(ctx, subnets, s.neverProxySubnets)

// Fire and forget to send metrics out.
go func() {
scout.Report(ctx, "update_routes",
scout.Entry{Key: "subnets", Value: len(proxy)},
scout.Entry{Key: "allow_conflicting_subnets", Value: len(s.allowConflictingSubnets)},
)
}()
if s.tunVif == nil {
return nil
}
rt := s.tunVif.Router
rt.UpdateWhitelist(s.allowConflictingSubnets)
return rt.UpdateRoutes(ctx, proxy, neverProxy, neverProxyOverrides)
}

func computeNeverProxyOverrides(ctx context.Context, subnets, nvp []*net.IPNet) (proxy, neverProxy, neverProxyOverrides []*net.IPNet) {
neverProxy = slices.Clone(nvp)
last := len(neverProxy) - 1
for i := 0; i <= last; {
nps := dontProxy[i]
nps := neverProxy[i]
found := false
for _, ds := range subnets {
if subnet.Overlaps(ds, nps) {
Expand All @@ -783,28 +800,31 @@ func (s *Session) onClusterInfo(ctx context.Context, mgrInfo *manager.ClusterInf
// This never-proxy is pointless because it's not a subnet that we are routing
dlog.Infof(ctx, "Dropping never-proxy %q because it is not routed", nps)
if last > i {
dontProxy[i] = dontProxy[last]
neverProxy[i] = neverProxy[last]
}
last--
} else {
i++
}
}
dontProxy = dontProxy[:last+1]
neverProxy = neverProxy[:last+1]

// Fire and forget to send metrics out.
go func() {
scout.Report(ctx, "update_routes",
scout.Entry{Key: "subnets", Value: len(subnets)},
scout.Entry{Key: "allow_conflicting_subnets", Value: len(s.allowConflictingSubnets)},
)
}()
if s.tunVif == nil {
return nil
}
rt := s.tunVif.Router
rt.UpdateWhitelist(s.allowConflictingSubnets)
return rt.UpdateRoutes(ctx, subnets, dontProxy)
proxy, neverProxyOverrides = subnet.Partition(subnets, func(i int, isn *net.IPNet) bool {
for r, rsn := range subnets {
if i == r {
continue
}
if subnet.Covers(rsn, isn) && !subnet.Equal(rsn, isn) {
for _, dsn := range neverProxy {
if subnet.Covers(dsn, isn) {
return false
}
}
}
}
return true
})
return subnet.Unique(proxy), neverProxy, neverProxyOverrides
}

func validateSubnets(name string, sns []*manager.IPNet, allowLoopback func() bool) ([]*net.IPNet, error) {
Expand Down
16 changes: 9 additions & 7 deletions pkg/routing/routing_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net"
"regexp"
"strconv"
"strings"
"time"

Expand All @@ -20,7 +21,7 @@ import (

type table struct{}

func rowAsRoute(ctx context.Context, row *winipcfg.MibIPforwardRow2, localIP net.IP) (*Route, error) {
func rowAsRoute(row *winipcfg.MibIPforwardRow2, localIP net.IP) (*Route, error) {
dst := row.DestinationPrefix.Prefix()
if !dst.IsValid() {
return nil, nil
Expand Down Expand Up @@ -74,7 +75,7 @@ func getConsistentRoutingTable(ctx context.Context) ([]*Route, error) {
}
routes := []*Route{}
for _, row := range table {
r, err := rowAsRoute(ctx, &row, nil)
r, err := rowAsRoute(&row, nil)
if err != nil {
return nil, err
}
Expand All @@ -85,7 +86,7 @@ func getConsistentRoutingTable(ctx context.Context) ([]*Route, error) {
return routes, nil
}

func getRouteForIP(ctx context.Context, localIP net.IP) (*Route, error) {
func getRouteForIP(localIP net.IP) (*Route, error) {
retryInconsistent:
for i := 0; i < maxInconsistentRetries; i++ {
table, err := winipcfg.GetIPForwardTable2(windows.AF_UNSPEC)
Expand All @@ -98,7 +99,7 @@ retryInconsistent:
if addrs, err := iface.Addrs(); err == nil {
for _, addr := range addrs {
if ip, _, err := net.ParseCIDR(addr.String()); err == nil && ip.Equal(localIP) {
r, err := rowAsRoute(ctx, &row, ip)
r, err := rowAsRoute(&row, ip)
if err != nil {
if err == errInconsistentRT {
time.Sleep(inconsistentRetryDelay)
Expand Down Expand Up @@ -144,7 +145,7 @@ func GetRoute(ctx context.Context, routedNet *net.IPNet) (*Route, error) {
if localIP == nil {
return nil, fmt.Errorf("unable to parse local IP from %q", string(out))
}
return getRouteForIP(ctx, localIP)
return getRouteForIP(localIP)
}

func maskToIP(mask net.IPMask) (ip net.IP) {
Expand All @@ -154,14 +155,15 @@ func maskToIP(mask net.IPMask) (ip net.IP) {
}

func (r *Route) addStatic(ctx context.Context) error {
mask := maskToIP(r.RoutedNet.Mask)
cmd := proc.CommandContext(ctx,
"route",
"ADD",
r.RoutedNet.IP.String(),
"MASK",
mask.String(),
maskToIP(r.RoutedNet.Mask).String(),
r.Gateway.String(),
"IF",
strconv.Itoa(r.Interface.Index),
)
cmd.DisableLogging = true
out, err := cmd.Output()
Expand Down
43 changes: 5 additions & 38 deletions pkg/vif/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"

"github.com/datawire/dlib/dlog"
"github.com/telepresenceio/telepresence/v2/pkg/routing"
"github.com/telepresenceio/telepresence/v2/pkg/tracing"
vifBuffer "github.com/telepresenceio/telepresence/v2/pkg/vif/buffer"
)

type device struct {
*channel.Endpoint
ctx context.Context
wg sync.WaitGroup
dev *nativeDevice
table routing.Table
ctx context.Context
wg sync.WaitGroup
dev *nativeDevice
}

type Device interface {
Expand All @@ -49,7 +47,7 @@ const defaultDevOutQueueLen = 1024
var _ Device = (*device)(nil)

// OpenTun creates a new TUN device and ensures that it is up and running.
func OpenTun(ctx context.Context, routingTable routing.Table) (Device, error) {
func OpenTun(ctx context.Context) (Device, error) {
dev, err := openTun(ctx)
if err != nil {
return nil, err
Expand All @@ -59,7 +57,6 @@ func OpenTun(ctx context.Context, routingTable routing.Table) (Device, error) {
Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
ctx: ctx,
dev: dev,
table: routingTable,
}, nil
}

Expand All @@ -78,35 +75,12 @@ func (d *device) Attach(dp stack.NetworkDispatcher) {
}()
}

func (d *device) subnetToRoute(subnet *net.IPNet) (*routing.Route, error) {
gw := make(net.IP, len(subnet.IP))
copy(gw, subnet.IP)
gw[len(gw)-1] += 1
iface, err := net.InterfaceByName(d.Name())
if err != nil {
return nil, err
}
return &routing.Route{
LocalIP: subnet.IP,
RoutedNet: subnet,
Interface: iface,
Gateway: gw,
}, nil
}

// AddSubnet adds a subnet to this TUN device and creates a route for that subnet which
// is associated with the device (removing the device will automatically remove the route).
func (d *device) AddSubnet(ctx context.Context, subnet *net.IPNet) (err error) {
ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "AddSubnet", trace.WithAttributes(attribute.Stringer("tel2.subnet", subnet)))
defer tracing.EndAndRecord(span, err)
if err := d.dev.addSubnet(ctx, subnet); err != nil {
return err
}
route, err := d.subnetToRoute(subnet)
if err != nil {
return err
}
return d.table.Add(ctx, route)
return d.dev.addSubnet(ctx, subnet)
}

func (d *device) Close() error {
Expand Down Expand Up @@ -135,13 +109,6 @@ func (d *device) SetMTU(mtu int) error {
// RemoveSubnet removes a subnet from this TUN device and also removes the route for that subnet which
// is associated with the device.
func (d *device) RemoveSubnet(ctx context.Context, subnet *net.IPNet) (err error) {
route, err := d.subnetToRoute(subnet)
if err != nil {
return err
}
if err := d.table.Remove(ctx, route); err != nil {
return err
}
// Staticcheck screams if this is ctx, span := because it thinks the context argument is being overwritten before being used.
sCtx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "RemoveSubnet", trace.WithAttributes(attribute.Stringer("tel2.subnet", subnet)))
defer tracing.EndAndRecord(span, err)
Expand Down
Loading

0 comments on commit a0b95a8

Please sign in to comment.