Skip to content

Commit

Permalink
all: imp code
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Aug 16, 2024
1 parent dd45ba1 commit 7b95e9d
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 32 deletions.
9 changes: 0 additions & 9 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"syscall"
"time"

"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/handler"
"github.com/AdguardTeam/dnsproxy/internal/version"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
Expand Down Expand Up @@ -107,13 +105,6 @@ func runProxy(ctx context.Context, l *slog.Logger, options *Options) (err error)
return fmt.Errorf("creating proxy: %w", err)
}

reqHdlr := handler.NewDefault(&handler.DefaultConfig{
Logger: l.With(slogutil.KeyPrefix, "default_handler"),
MessageConstructor: dnsmsg.DefaultMessageConstructor{},
DisableIPv6: options.IPv6Disabled,
})
dnsProxy.RequestHandler = reqHdlr.HandleRequest

// Start the proxy server.
err = dnsProxy.Start(ctx)
if err != nil {
Expand Down
10 changes: 10 additions & 0 deletions internal/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"os"
"strings"

"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/handler"
proxynetutil "github.com/AdguardTeam/dnsproxy/internal/netutil"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
Expand Down Expand Up @@ -47,6 +49,13 @@ func createProxyConfig(
l *slog.Logger,
options *Options,
) (conf *proxy.Config, err error) {
reqHdlr := handler.NewDefault(&handler.DefaultConfig{
Logger: l.With(slogutil.KeyPrefix, "default_handler"),
// TODO(e.burkov): Use the configured message constructor.
MessageConstructor: dnsmsg.DefaultMessageConstructor{},
HaltIPv6: options.IPv6Disabled,
})

conf = &proxy.Config{
Logger: l.With(slogutil.KeyPrefix, proxy.LogPrefix),

Expand Down Expand Up @@ -74,6 +83,7 @@ func createProxyConfig(
MaxGoroutines: options.MaxGoRoutines,
UsePrivateRDNS: options.UsePrivateRDNS,
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
RequestHandler: reqHdlr.HandleRequest,
}

if uiStr := options.HTTPSUserinfo; uiStr != "" {
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsmsg/constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (DefaultMessageConstructor) NewMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
},
}

if strings.HasPrefix(zone, ".") {
if !strings.HasPrefix(zone, ".") {
soa.Mbox += zone
}

Expand Down
4 changes: 2 additions & 2 deletions internal/dnsproxytest/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func (u *FakeUpstream) Close() (err error) {
return u.OnClose()
}

// TestMessageConstructor is a mock message constructor implementation to
// simplify testing.
// TestMessageConstructor is a fake [dnsmsg.MessageConstructor] implementation
// for tests.
type TestMessageConstructor struct {
OnNewMsgNXDOMAIN func(req *dns.Msg) (resp *dns.Msg)
OnNewMsgSERVFAIL func(req *dns.Msg) (resp *dns.Msg)
Expand Down
6 changes: 3 additions & 3 deletions internal/dnsproxytest/interface_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package dnsproxytest_test

import (
"github.com/AdguardTeam/dnsproxy/internal/dnsmsg"
"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
)

// type checks
var (
_ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil)
_ proxy.MessageConstructor = (*dnsproxytest.TestMessageConstructor)(nil)
_ upstream.Upstream = (*dnsproxytest.FakeUpstream)(nil)
_ dnsmsg.MessageConstructor = (*dnsproxytest.TestMessageConstructor)(nil)
)
40 changes: 23 additions & 17 deletions internal/handler/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,68 @@
package handler

import (
"context"
"log/slog"

"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
)

// DefaultConfig is the configuration for the default handler.
// DefaultConfig is the configuration for [Default].
type DefaultConfig struct {
// Logger is the logger. It must not be nil.
Logger *slog.Logger

// MessageConstructor constructs DNS messages. It must not be nil.
MessageConstructor proxy.MessageConstructor

// DisableIPv6 halts the processing of AAAA requests.
DisableIPv6 bool
// HaltIPv6 halts the processing of AAAA requests and makes the handler
// reply with NODATA to them.
HaltIPv6 bool
}

// Default implements the default configurable [proxy.RequestHandler].
type Default struct {
logger *slog.Logger
messageConstructor proxy.MessageConstructor
disableIPv6 bool
isIPv6Halted bool
}

// NewDefault creates a new [Default] handler.
func NewDefault(conf *DefaultConfig) (d *Default) {
return &Default{
logger: conf.Logger,
disableIPv6: conf.DisableIPv6,
isIPv6Halted: conf.HaltIPv6,
messageConstructor: conf.MessageConstructor,
}
}

// HandleRequest checks the IPv6 configuration for current session before
// resolving.
func (h Default) HandleRequest(p *proxy.Proxy, ctx *proxy.DNSContext) (err error) {
if !h.haltAAAA(ctx) {
func (h Default) HandleRequest(p *proxy.Proxy, proxyCtx *proxy.DNSContext) (err error) {
// TODO(e.burkov): Use the [*context.Context] instead of
// [*proxy.DNSContext] when the interface-based handler is implemented.
ctx := context.TODO()

if proxyCtx.Res = h.haltAAAA(ctx, proxyCtx.Req); proxyCtx.Res != nil {
return nil
}

return p.Resolve(ctx)
return p.Resolve(proxyCtx)
}

// haltAAAA halts the processing of AAAA requests if IPv6 is disabled.
func (h *Default) haltAAAA(ctx *proxy.DNSContext) (cont bool) {
if h.disableIPv6 && ctx.Req.Question[0].Qtype == dns.TypeAAAA {
h.logger.Debug(
// haltAAAA halts the processing of AAAA requests if IPv6 is disabled. req must
// not be nil.
func (h *Default) haltAAAA(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
if h.isIPv6Halted && req.Question[0].Qtype == dns.TypeAAAA {
h.logger.DebugContext(
ctx,
"ipv6 is disabled; replying with empty response",
"req", ctx.Req.Question[0].Name,
"req", req.Question[0].Name,
)

ctx.Res = h.messageConstructor.NewMsgNODATA(ctx.Req)

return false
return h.messageConstructor.NewMsgNODATA(req)
}

return true
return nil
}

0 comments on commit 7b95e9d

Please sign in to comment.