Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: connection leaks #624

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*.tmp
bpf_bpfeb*.go
bpf_bpfel*.go
bpf_*_bpfeb*.go
bpf_*_bpfel*.go
dae
outline.json
go-mod/
Expand Down
17 changes: 9 additions & 8 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ var (
Compress: true,
}
}
log := logger.NewLogger(conf.Global.LogLevel, disableTimestamp, logOpts)
logrus.SetLevel(log.Level)
log := logrus.New()
logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts)
logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts)

log.Infof("Include config files: [%v]", strings.Join(includes, ", "))
if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil {
Expand Down Expand Up @@ -238,9 +239,11 @@ loop:
}
// New logger.
oldLogOutput := log.Out
log = logger.NewLogger(newConf.Global.LogLevel, disableTimestamp, nil)
log = logrus.New()
logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil)
logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil)
log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK.
logrus.SetLevel(log.Level)
logrus.SetOutput(oldLogOutput)

// New control plane.
obj := c.EjectBpf()
Expand Down Expand Up @@ -330,8 +333,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -372,8 +374,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect}
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions common/netutils/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
}

// Dial and write.
cd := &netproxy.ContextDialerConverter{Dialer: d}
c, err := cd.DialContext(ctx, network, dns.String())
c, err := d.DialContext(ctx, network, dns.String())
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions component/outbound/dialer/connectivity_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -600,12 +600,11 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
if method == "" {
method = http.MethodGet
}
cd := &netproxy.ContextDialerConverter{Dialer: d.Dialer}
cli := http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip".
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port()))
conn, err := d.Dialer.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port()))
if err != nil {
return nil, err
}
Expand Down
7 changes: 2 additions & 5 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,16 +562,13 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte

ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
bestContextDialer := netproxy.ContextDialerConverter{
Dialer: dialArgument.bestDialer,
}

switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP:
// Get udp endpoint.

// TODO: connection pool.
conn, err = bestContextDialer.DialContext(
conn, err = dialArgument.bestDialer.DialContext(
ctxDial,
common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp),
dialArgument.bestTarget.String(),
Expand Down Expand Up @@ -636,7 +633,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine.

conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}
Expand Down
5 changes: 1 addition & 4 deletions control/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,7 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
cd := netproxy.ContextDialerConverter{
Dialer: d,
}
return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
return d.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
}

type WriteCloser interface {
Expand Down
5 changes: 1 addition & 4 deletions control/udp_endpoint_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,9 @@ begin:
if err != nil {
return nil, false, err
}
cd := netproxy.ContextDialerConverter{
Dialer: dialOption.Dialer,
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target)
udpConn, err := dialOption.Dialer.DialContext(ctx, dialOption.Network, dialOption.Target)
if err != nil {
return nil, true, err
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/bits-and-blooms/bloom/v3 v3.5.0
github.com/cilium/ebpf v0.12.3
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5
github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f
github.com/fsnotify/fsnotify v1.7.0
github.com/json-iterator/go v1.1.12
github.com/mholt/archiver/v3 v3.5.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBS
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d h1:hnC39MjR7xt5kZjrKlef7DXKFDkiX8MIcDXYC/6Jf9Q=
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d/go.mod h1:VGWGgv7pCP5WGyHGUyb9+nq/gW0yBm+i/GfCNATOJ1M=
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5 h1:L450vqT1TO+Ygzd8buBMna8d4/0asT0q74qitGTWSl4=
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs=
github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f h1:HB2IMJcU6FqLFqgDHbhhK9F0At6AFfpDRKk/oZz3T2A=
github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs=
github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810 h1:YtEYouFaNrg9sV9vf3UabvKShKn6sD0QaCdOxCwaF3g=
github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810/go.mod h1:61o2uZUGLrlv1i+oO2rx9sVX0vbf8cHzdSHt7h6lMnM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
6 changes: 1 addition & 5 deletions pkg/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2"
)

func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) *logrus.Logger {
log := logrus.New()

func SetLogger(log *logrus.Logger, logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) {
level, err := logrus.ParseLevel(logLevel)
if err != nil {
level = logrus.InfoLevel
Expand All @@ -28,6 +26,4 @@ func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Lo
if logFileOpt != nil {
log.SetOutput(logFileOpt)
}

return log
}
Loading