Skip to content

Commit

Permalink
proxy: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizzick committed Dec 12, 2023
1 parent d5b40a4 commit 9b7e21e
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Application Options:
--ipv6-disabled If specified, all AAAA requests will be replied with NoError RCode and empty answer
--bogus-nxdomain= Transform the responses containing at least a single IP that matches specified addresses and CIDRs into NXDOMAIN. Can be specified multiple times.
--udp-buf-size= Set the size of the UDP buffer in bytes. A value <= 0 will use the system default.
--max-go-routines= Set the maximum number of go routines. A value <= 0 will not not set a maximum.
--max-go-routines= Set the maximum number of go routines. A zero value will not not set a maximum.
--pprof If present, exposes pprof information on localhost:6060.
--version Prints the program version
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ type Options struct {
// UDP buffer size value
UDPBufferSize int `yaml:"udp-buf-size" long:"udp-buf-size" description:"Set the size of the UDP buffer in bytes. A value <= 0 will use the system default."`

// The maximum number of go routines
MaxGoRoutines int `yaml:"max-go-routines" long:"max-go-routines" description:"Set the maximum number of go routines. A value <= 0 will not not set a maximum."`
// MaxGoRoutines is the maximum number of goroutines.
MaxGoRoutines uint `yaml:"max-go-routines" long:"max-go-routines" description:"Set the maximum number of go routines. A zero value will not not set a maximum."`

// Pprof defines whether the pprof information needs to be exposed via
// localhost:6060 or not.
Expand Down
2 changes: 1 addition & 1 deletion proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ type Config struct {
// TODO(a.garipov): Rename this to something like
// “MaxDNSRequestGoroutines” in a later major version, as it doesn't
// actually limit all goroutines.
MaxGoroutines int
MaxGoroutines uint

// The size of the read buffer on the underlying socket. Larger read buffers can handle
// larger bursts of requests before packets get dropped.
Expand Down
8 changes: 4 additions & 4 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,15 @@ type Proxy struct {
// RWMutex protects the whole proxy.
sync.RWMutex

// requestGoroutinesSema limits the number of simultaneous requests.
// requestsSema limits the number of simultaneous requests.
//
// TODO(a.garipov): Currently we have to pass this exact semaphore to
// the workers, to prevent races on restart. In the future we will need
// a better restarting mechanism that completely prevents such invalid
// states.
//
// See also: https://github.com/AdguardTeam/AdGuardHome/issues/2242.
requestGoroutinesSema syncutil.Semaphore
requestsSema syncutil.Semaphore

// Config is the proxy configuration.
//
Expand All @@ -196,9 +196,9 @@ func (p *Proxy) Init() (err error) {
if p.MaxGoroutines > 0 {
log.Info("dnsproxy: max goroutines is set to %d", p.MaxGoroutines)

p.requestGoroutinesSema = syncutil.NewChanSemaphore(uint(p.MaxGoroutines))
p.requestsSema = syncutil.NewChanSemaphore(p.MaxGoroutines)
} else {
p.requestGoroutinesSema = syncutil.EmptySemaphore{}
p.requestsSema = syncutil.EmptySemaphore{}
}

p.udpOOBSize = proxynetutil.UDPGetOOBSize()
Expand Down
8 changes: 4 additions & 4 deletions proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ func (p *Proxy) startListeners(ctx context.Context) error {
}

for _, l := range p.udpListen {
go p.udpPacketLoop(l, p.requestGoroutinesSema)
go p.udpPacketLoop(l, p.requestsSema)
}

for _, l := range p.tcpListen {
go p.tcpPacketLoop(l, ProtoTCP, p.requestGoroutinesSema)
go p.tcpPacketLoop(l, ProtoTCP, p.requestsSema)
}

for _, l := range p.tlsListen {
go p.tcpPacketLoop(l, ProtoTLS, p.requestGoroutinesSema)
go p.tcpPacketLoop(l, ProtoTLS, p.requestsSema)
}

for _, l := range p.httpsListen {
Expand All @@ -64,7 +64,7 @@ func (p *Proxy) startListeners(ctx context.Context) error {
}

for _, l := range p.quicListen {
go p.quicPacketLoop(l, p.requestGoroutinesSema)
go p.quicPacketLoop(l, p.requestsSema)
}

for _, l := range p.dnsCryptUDPListen {
Expand Down
11 changes: 6 additions & 5 deletions proxy/server_dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (p *Proxy) createDNSCryptListeners() (err error) {
Handler: &dnsCryptHandler{
proxy: p,

requestGoroutinesSema: p.requestGoroutinesSema,
reqSema: p.requestsSema,
},
}

Expand Down Expand Up @@ -63,7 +63,7 @@ func (p *Proxy) createDNSCryptListeners() (err error) {
type dnsCryptHandler struct {
proxy *Proxy

requestGoroutinesSema syncutil.Semaphore
reqSema syncutil.Semaphore
}

// compile-time type check
Expand All @@ -75,11 +75,12 @@ func (h *dnsCryptHandler) ServeDNS(rw dnscrypt.ResponseWriter, req *dns.Msg) (er
d.Addr = netutil.NetAddrToAddrPort(rw.RemoteAddr())
d.DNSCryptResponseWriter = rw

err = h.requestGoroutinesSema.Acquire(context.Background())
// TODO(d.kolyshev): Pass and use context from above.
err = h.reqSema.Acquire(context.Background())
if err != nil {
return fmt.Errorf("acquiring semaphore: %w", err)
return fmt.Errorf("dnsproxy: dnscrypt: acquiring semaphore: %w", err)
}
defer h.requestGoroutinesSema.Release()
defer h.reqSema.Release()

return h.proxy.handleDNSRequest(d)
}
Expand Down
28 changes: 13 additions & 15 deletions proxy/server_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ func (p *Proxy) createQUICListeners() error {

// quicPacketLoop listens for incoming QUIC packets.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema syncutil.Semaphore) {
// See also the comment on Proxy.requestsSema.
func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, reqSema syncutil.Semaphore) {
log.Info("Entering the DNS-over-QUIC listener loop on %s", l.Addr())
for {
ctx := context.Background()
Expand All @@ -99,27 +99,25 @@ func (p *Proxy) quicPacketLoop(l *quic.EarlyListener, requestGoroutinesSema sync
break
}

err = requestGoroutinesSema.Acquire(ctx)
err = reqSema.Acquire(ctx)
if err != nil {
log.Error("acquiring semaphore: %s", err)
log.Error("dnsproxy: quic: acquiring semaphore: %s", err)

break
}
go func() {
p.handleQUICConnection(conn, requestGoroutinesSema)
requestGoroutinesSema.Release()
defer reqSema.Release()

p.handleQUICConnection(conn, reqSema)
}()
}
}

// handleQUICConnection handles a new QUIC connection. It waits for new streams
// and passes them to handleQUICStream.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) handleQUICConnection(
conn quic.Connection,
requestGoroutinesSema syncutil.Semaphore,
) {
// See also the comment on Proxy.requestsSema.
func (p *Proxy) handleQUICConnection(conn quic.Connection, reqSema syncutil.Semaphore) {
for {
ctx := context.Background()

Expand All @@ -142,24 +140,24 @@ func (p *Proxy) handleQUICConnection(
return
}

err = requestGoroutinesSema.Acquire(ctx)
err = reqSema.Acquire(ctx)
if err != nil {
log.Error("acquiring semaphore: %s", err)
log.Error("dnsproxy: quic: acquiring semaphore: %s", err)

// Close the connection to make sure resources are freed.
closeQUICConn(conn, DoQCodeNoError)

return
}
go func() {
defer reqSema.Release()

p.handleQUICStream(stream, conn)

// The server MUST send the response(s) on the same stream and MUST
// indicate, after the last response, through the STREAM FIN
// mechanism that no further data will be sent on that stream.
_ = stream.Close()

requestGoroutinesSema.Release()
}()
}
}
Expand Down
16 changes: 7 additions & 9 deletions proxy/server_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,8 @@ func (p *Proxy) createTLSListeners() (err error) {
// tcpPacketLoop listens for incoming TCP packets. proto must be either "tcp"
// or "tls".
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) tcpPacketLoop(
l net.Listener,
proto Proto,
requestGoroutinesSema syncutil.Semaphore,
) {
// See also the comment on Proxy.requestsSema.
func (p *Proxy) tcpPacketLoop(l net.Listener, proto Proto, reqSema syncutil.Semaphore) {
log.Info("dnsproxy: entering %s listener loop on %s", proto, l.Addr())

for {
Expand All @@ -81,15 +77,17 @@ func (p *Proxy) tcpPacketLoop(
break
}

err = requestGoroutinesSema.Acquire(context.Background())
// TODO(d.kolyshev): Pass and use context from above.
err = reqSema.Acquire(context.Background())
if err != nil {
log.Error("acquiring semaphore: %s", err)
log.Error("dnsproxy: tcp: acquiring semaphore: %s", err)

break
}
go func() {
defer reqSema.Release()

p.handleTCPConnection(clientConn, proto)
requestGoroutinesSema.Release()
}()
}
}
Expand Down
12 changes: 7 additions & 5 deletions proxy/server_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func (p *Proxy) udpCreate(ctx context.Context, udpAddr *net.UDPAddr) (*net.UDPCo

// udpPacketLoop listens for incoming UDP packets.
//
// See also the comment on Proxy.requestGoroutinesSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema syncutil.Semaphore) {
// See also the comment on Proxy.requestsSema.
func (p *Proxy) udpPacketLoop(conn *net.UDPConn, reqSema syncutil.Semaphore) {
log.Info("dnsproxy: entering udp listener loop on %s", conn.LocalAddr())

b := make([]byte, dns.MaxMsgSize)
Expand All @@ -81,15 +81,17 @@ func (p *Proxy) udpPacketLoop(conn *net.UDPConn, requestGoroutinesSema syncutil.
packet := make([]byte, n)
copy(packet, b)

sErr := requestGoroutinesSema.Acquire(context.Background())
// TODO(d.kolyshev): Pass and use context from above.
sErr := reqSema.Acquire(context.Background())
if sErr != nil {
log.Error("acquiring semaphore: %s", sErr)
log.Error("dnsproxy: udp: acquiring semaphore: %s", sErr)

break
}
go func() {
defer reqSema.Release()

p.udpHandlePacket(packet, localIP, remoteAddr, conn)
requestGoroutinesSema.Release()
}()
}
if err != nil {
Expand Down

0 comments on commit 9b7e21e

Please sign in to comment.