From 20205f14590519a10fca9625727ccb8a44b2c0cb Mon Sep 17 00:00:00 2001 From: Shawn Carey Date: Thu, 1 Feb 2024 14:54:03 -0500 Subject: [PATCH] assign dns ip range to loopback instead of one-ip-at-a-time --- router/xgress_edge_tunnel/tunneler.go | 1 + tunnel/intercept/interceptor.go | 28 ++++++++++++++--------- tunnel/intercept/iputils.go | 30 ++++++++++++++++++++----- tunnel/intercept/tproxy/tproxy_linux.go | 9 ++++---- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/router/xgress_edge_tunnel/tunneler.go b/router/xgress_edge_tunnel/tunneler.go index cf1161141..daf1c9f6a 100644 --- a/router/xgress_edge_tunnel/tunneler.go +++ b/router/xgress_edge_tunnel/tunneler.go @@ -150,6 +150,7 @@ func (self *tunneler) Listen(_ string, bindHandler xgress.BindHandler) error { func (self *tunneler) Close() error { self.interceptor.Stop() + intercept.ClearDnsInterceptIpRange() return nil } diff --git a/tunnel/intercept/interceptor.go b/tunnel/intercept/interceptor.go index 547077d84..b7b5b8fde 100644 --- a/tunnel/intercept/interceptor.go +++ b/tunnel/intercept/interceptor.go @@ -43,12 +43,13 @@ type Interceptor interface { // - service name - when a service is removed (e.g. from an appwan) type InterceptAddress struct { - cidr *net.IPNet - lowPort uint16 - highPort uint16 - protocol string - TproxySpec []string - AcceptSpec []string + cidr *net.IPNet + routeRequired bool + lowPort uint16 + highPort uint16 + protocol string + TproxySpec []string + AcceptSpec []string } func (addr *InterceptAddress) Proto() string { @@ -59,6 +60,10 @@ func (addr *InterceptAddress) IpNet() *net.IPNet { return addr.cidr } +func (addr *InterceptAddress) RouteRequired() bool { + return addr.routeRequired +} + func (addr *InterceptAddress) LowPort() uint16 { return addr.lowPort } @@ -82,14 +87,15 @@ type InterceptAddrCB interface { func GetInterceptAddresses(service *entities.Service, protocols []string, resolver dns.Resolver, addressCB InterceptAddrCB) error { for _, addr := range service.InterceptV1Config.Addresses { - err := getInterceptIP(service, addr, resolver, func(ipNet *net.IPNet) { + err := getInterceptIP(service, addr, resolver, func(ipNet *net.IPNet, routeRequired bool) { for _, protocol := range protocols { for _, portRange := range service.InterceptV1Config.PortRanges { addr := &InterceptAddress{ - cidr: ipNet, - lowPort: portRange.Low, - highPort: portRange.High, - protocol: protocol} + cidr: ipNet, + routeRequired: routeRequired, + lowPort: portRange.Low, + highPort: portRange.High, + protocol: protocol} addressCB.Apply(addr) } } diff --git a/tunnel/intercept/iputils.go b/tunnel/intercept/iputils.go index f5be68504..a7f98832b 100644 --- a/tunnel/intercept/iputils.go +++ b/tunnel/intercept/iputils.go @@ -22,6 +22,7 @@ import ( "github.com/michaelquigley/pfxlog" "github.com/openziti/ziti/tunnel/dns" "github.com/openziti/ziti/tunnel/entities" + "github.com/openziti/ziti/tunnel/router" "github.com/openziti/ziti/tunnel/utils" "net" "net/netip" @@ -46,7 +47,23 @@ func SetDnsInterceptIpRange(cidr string) error { dnsCurrentIp = dnsPrefix.Addr() dnsCurrentIpMtx.Unlock() pfxlog.Logger().Infof("dns intercept IP range: %v - %v", dnsCurrentIp, dnsIpHigh) - return nil + dnsNet := &net.IPNet{ + IP: dnsPrefix.Addr().AsSlice(), + Mask: net.CIDRMask(dnsPrefix.Bits(), dnsPrefix.Addr().BitLen()), + } + err = router.AddLocalAddress(dnsNet, "lo") + if err != nil { + pfxlog.Logger().WithError(err).Errorf("failed assigning dns cidr to loopback interface") + } + return err +} + +func ClearDnsInterceptIpRange() { + dnsNet := &net.IPNet{ + IP: dnsPrefix.Addr().AsSlice(), + Mask: net.CIDRMask(dnsPrefix.Bits(), dnsPrefix.Addr().BitLen()), + } + router.RemoveLocalAddress(dnsNet, "lo") } func cleanUpFunc(hostname string, resolver dns.Resolver) func() { @@ -70,19 +87,20 @@ func incDnsIp() (err error) { return } -func getDnsIp(host string, addrCB func(*net.IPNet), svc *entities.Service, resolver dns.Resolver) (net.IP, error) { +func getDnsIp(host string, addrCB func(*net.IPNet, bool), svc *entities.Service, resolver dns.Resolver) (net.IP, error) { err := incDnsIp() if err == nil { - addrCB(&net.IPNet{ + addr := &net.IPNet{ IP: dnsCurrentIp.AsSlice(), Mask: net.CIDRMask(dnsCurrentIp.BitLen(), dnsCurrentIp.BitLen()), - }) + } + addrCB(addr, false) svc.AddCleanupAction(cleanUpFunc(host, resolver)) } return dnsCurrentIp.AsSlice(), err } -func getInterceptIP(svc *entities.Service, hostname string, resolver dns.Resolver, addrCB func(ipNet *net.IPNet)) error { +func getInterceptIP(svc *entities.Service, hostname string, resolver dns.Resolver, addrCB func(*net.IPNet, bool)) error { logger := pfxlog.Logger() // handle wildcard domain - IPs will be allocated when matching hostnames are queried @@ -96,7 +114,7 @@ func getInterceptIP(svc *entities.Service, hostname string, resolver dns.Resolve // handle IP or CIDR ipNet, err := utils.GetCidr(hostname) if err == nil { - addrCB(ipNet) + addrCB(ipNet, true) return err } diff --git a/tunnel/intercept/tproxy/tproxy_linux.go b/tunnel/intercept/tproxy/tproxy_linux.go index 61c3f5540..10b2719b6 100644 --- a/tunnel/intercept/tproxy/tproxy_linux.go +++ b/tunnel/intercept/tproxy/tproxy_linux.go @@ -505,10 +505,11 @@ func (self *tProxy) intercept(service *entities.Service, resolver dns.Resolver, func (self *tProxy) addInterceptAddr(interceptAddr *intercept.InterceptAddress, service *entities.Service, port IPPortAddr, tracker intercept.AddressTracker) error { ipNet := interceptAddr.IpNet() - if err := router.AddLocalAddress(ipNet, "lo"); err != nil { - return errors.Wrapf(err, "failed to add local route %v", ipNet) + if interceptAddr.RouteRequired() { + if err := router.AddLocalAddress(ipNet, "lo"); err != nil { + return errors.Wrapf(err, "failed to add local route %v", ipNet) + } } - tracker.AddAddress(ipNet.String()) self.addresses = append(self.addresses, interceptAddr) if self.interceptor.diverter != "" { @@ -608,7 +609,7 @@ func (self *tProxy) StopIntercepting(tracker intercept.AddressTracker) error { } ipNet := addr.IpNet() - if tracker.RemoveAddress(ipNet.String()) { + if tracker.RemoveAddress(ipNet.String()) && addr.RouteRequired() { err := router.RemoveLocalAddress(ipNet, "lo") if err != nil { errorList = append(errorList, err)