diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..3f80115 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.github +.git +.idea \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..3938344 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..49033d7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,59 @@ +--- +name: CI + +# Controls when the workflow will run +on: + push: + branches: + - main + pull_request: + branches: + - main +permissions: + contents: read +jobs: + golangci-lint: + permissions: + contents: read + pull-requests: read + runs-on: ubuntu-latest + steps: + # Get the repositery's code + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Golang + uses: actions/setup-go@v3 + with: + go-version: '1.20.x' + check-latest: true + cache: true + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v3.4.0 + with: + version: latest + args: --verbose + test: + strategy: + fail-fast: false + matrix: + platform: + - ubuntu + go: + - 20 + name: 'tests on ${{ matrix.platform }} | 1.${{ matrix.go }}.x' + runs-on: ${{ matrix.platform }}-latest + steps: + # Get the repositery's code + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Golang + uses: actions/setup-go@v3 + with: + go-version: '1.${{ matrix.go }}.x' + cache: true + + - name: Run tests + run: go clean -testcache && go test -race -cover -covermode=atomic ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..07e2b73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +.idea + +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..3c86dda --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM golang:1.20.7 as build + +WORKDIR /go/src/app +COPY . . + +RUN go mod download +RUN CGO_ENABLED=0 go build -o /go/bin/dns-gateway ./cmd/dns-gateway + +FROM debian:bookworm-slim + +COPY --from=build /go/bin/dns-gateway /usr/sbin/dns-gateway + +ENTRYPOINT ["/usr/sbin/dns-gateway"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cbd0664 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Andrew Krasichkov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/cmd/dns-gateway/.gitignore b/cmd/dns-gateway/.gitignore new file mode 100644 index 0000000..dfd4f9a --- /dev/null +++ b/cmd/dns-gateway/.gitignore @@ -0,0 +1 @@ +gateway \ No newline at end of file diff --git a/cmd/dns-gateway/main.go b/cmd/dns-gateway/main.go new file mode 100644 index 0000000..c5b2157 --- /dev/null +++ b/cmd/dns-gateway/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "fmt" + "os" + + _ "go.uber.org/automaxprocs" + + "github.com/buglloc/DNSGateway/internal/commands" +) + +func main() { + if err := commands.Execute(); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} diff --git a/example/config.yaml b/example/config.yaml new file mode 100644 index 0000000..652f279 --- /dev/null +++ b/example/config.yaml @@ -0,0 +1,16 @@ +listener: + kind: rfc2136 + rfc2136: + addr: :5454 + clients: + - name: tst. + secret: NzBjOTU4OTVlOTZlOTg5OGQwYTUxYTdjNWYzNTI3NzA5YjIyZTIxNWVjOTc3NWMxNzIxZjdjN2ExNjliNDc1ZCAgLQo= + zones: + test.lala. +upstream: + kind: adguard + adguard: + api_server_url: https://g.buglloc.cc + login: buglloc + password: kek-cheburek + auto_ptr: true \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..57bb5df --- /dev/null +++ b/go.mod @@ -0,0 +1,46 @@ +module github.com/buglloc/DNSGateway + +go 1.20 + +require ( + github.com/buglloc/certifi v0.9.1 + github.com/go-resty/resty/v2 v2.7.0 + github.com/knadh/koanf/parsers/yaml v0.1.0 + github.com/knadh/koanf/providers/env v0.1.0 + github.com/knadh/koanf/providers/file v0.1.0 + github.com/knadh/koanf/v2 v2.0.1 + github.com/labstack/echo/v4 v4.11.1 + github.com/miekg/dns v1.1.55 + github.com/rs/zerolog v1.30.0 + github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.4 + go.uber.org/automaxprocs v1.5.3 + golang.org/x/sync v0.3.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/knadh/koanf/maps v0.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/labstack/gommon v0.4.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.2 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect + golang.org/x/tools v0.11.0 // indirect + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..44d7b18 --- /dev/null +++ b/go.sum @@ -0,0 +1,107 @@ +github.com/buglloc/certifi v0.9.1 h1:Xgeg9xlCWCSLNYSx0moU7mB++RRrRFFdtmvqlRHRqms= +github.com/buglloc/certifi v0.9.1/go.mod h1:gfBmljk3gl0dTYtpx18dLMjx8fXutictjX9RWyBcQT0= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY= +github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/knadh/koanf/maps v0.1.1 h1:G5TjmUh2D7G2YWf5SQQqSiHRJEjaicvU0KpypqB3NIs= +github.com/knadh/koanf/maps v0.1.1/go.mod h1:npD/QZY3V6ghQDdcQzl1W4ICNVTkohC8E73eI2xW4yI= +github.com/knadh/koanf/parsers/yaml v0.1.0 h1:ZZ8/iGfRLvKSaMEECEBPM1HQslrZADk8fP1XFUxVI5w= +github.com/knadh/koanf/parsers/yaml v0.1.0/go.mod h1:cvbUDC7AL23pImuQP0oRw/hPuccrNBS2bps8asS0CwY= +github.com/knadh/koanf/providers/env v0.1.0 h1:LqKteXqfOWyx5Ab9VfGHmjY9BvRXi+clwyZozgVRiKg= +github.com/knadh/koanf/providers/env v0.1.0/go.mod h1:RE8K9GbACJkeEnkl8L/Qcj8p4ZyPXZIQ191HJi44ZaQ= +github.com/knadh/koanf/providers/file v0.1.0 h1:fs6U7nrV58d3CFAFh8VTde8TM262ObYf3ODrc//Lp+c= +github.com/knadh/koanf/providers/file v0.1.0/go.mod h1:rjJ/nHQl64iYCtAW2QQnF0eSmDEX/YZ/eNFj5yR6BvA= +github.com/knadh/koanf/v2 v2.0.1 h1:1dYGITt1I23x8cfx8ZnldtezdyaZtfAuRtIFOiRzK7g= +github.com/knadh/koanf/v2 v2.0.1/go.mod h1:ZeiIlIDXTE7w1lMT6UVcNiRAS2/rCeLn/GdLNvY1Dus= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/labstack/echo/v4 v4.11.1 h1:dEpLU2FLg4UVmvCGPuk/APjlH6GDpbEPti61srUUUs4= +github.com/labstack/echo/v4 v4.11.1/go.mod h1:YuYRTSM3CHs2ybfrL8Px48bO6BAnYIN4l8wSTMP6BDQ= +github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= +github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= +github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= +github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= +github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= +go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= +go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.11.0 h1:EMCa6U9S2LtZXLAMoWiR/R8dAQFRqbAitmbJ2UKhoi8= +golang.org/x/tools v0.11.0/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/commands/root.go b/internal/commands/root.go new file mode 100644 index 0000000..2e7ffce --- /dev/null +++ b/internal/commands/root.go @@ -0,0 +1,57 @@ +package commands + +import ( + "fmt" + "log" + "os" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + + "github.com/buglloc/DNSGateway/internal/config" +) + +var rootArgs struct { + Configs []string +} + +var cfg *config.Config + +var rootCmd = &cobra.Command{ + Use: "gateway", + SilenceUsage: true, + SilenceErrors: true, + Short: `Dynamic DNS gateway for AdGuardHome`, +} + +func init() { + cobra.OnInitialize( + initConfig, + initLogger, + ) + + flags := rootCmd.PersistentFlags() + flags.StringSliceVar(&rootArgs.Configs, "config", nil, "config file") + + rootCmd.AddCommand( + startCmd, + ) +} + +func Execute() error { + return rootCmd.Execute() +} + +func initConfig() { + var err error + cfg, err = config.LoadConfig(rootArgs.Configs...) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "unable to load config: %v\n", err) + os.Exit(1) + } +} + +func initLogger() { + log.SetOutput(os.Stderr) + zerolog.SetGlobalLevel(zerolog.InfoLevel) +} diff --git a/internal/commands/start.go b/internal/commands/start.go new file mode 100644 index 0000000..6b566d3 --- /dev/null +++ b/internal/commands/start.go @@ -0,0 +1,60 @@ +package commands + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +var startCmd = &cobra.Command{ + Use: "start", + SilenceUsage: true, + SilenceErrors: true, + Short: "Starts server", + RunE: func(_ *cobra.Command, _ []string) error { + runtime, err := cfg.NewRuntime() + if err != nil { + return fmt.Errorf("create runtime: %w", err) + } + + instance, err := runtime.NewListener() + if err != nil { + return fmt.Errorf("create gateway: %w", err) + } + + errChan := make(chan error, 1) + okChan := make(chan struct{}) + go func() { + err := instance.ListenAndServe() + if err != nil { + errChan <- err + return + } + + close(okChan) + }() + + stopChan := make(chan os.Signal, 1) + signal.Notify(stopChan, syscall.SIGINT, syscall.SIGTERM) + select { + case <-stopChan: + log.Info().Msg("shutting down gracefully by signal") + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + return instance.Shutdown(ctx) + case err := <-errChan: + log.Error().Err(err).Msg("start failed") + return err + case <-okChan: + } + return nil + }, +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..34b82cb --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,62 @@ +package config + +import ( + "fmt" + + "github.com/knadh/koanf/parsers/yaml" + "github.com/knadh/koanf/providers/env" + "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/v2" +) + +type Config struct { + Listener Listener `koanf:"listener"` + Upstream Upstream `koanf:"upstream"` +} + +func (c *Config) Validate() error { + return nil +} + +type Runtime struct { + cfg *Config +} + +func LoadConfig(files ...string) (*Config, error) { + out := Config{ + Listener: Listener{ + Kind: ListenerKindRFC2136, + RFC2136: RFC2136Listener{ + Addr: ":53", + Nets: []string{ + "udp", + "tcp", + }, + }, + }, + } + + k := koanf.New(".") + if err := k.Load(env.Provider("DG", "_", nil), nil); err != nil { + return nil, fmt.Errorf("load env config: %w", err) + } + + yamlParser := yaml.Parser() + for _, fpath := range files { + if err := k.Load(file.Provider(fpath), yamlParser); err != nil { + return nil, fmt.Errorf("load %q config: %w", fpath, err) + } + } + + return &out, k.Unmarshal("", &out) +} + +func (c *Config) NewRuntime() (*Runtime, error) { + if err := c.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + return &Runtime{ + cfg: c, + }, nil +} diff --git a/internal/config/listener.go b/internal/config/listener.go new file mode 100644 index 0000000..056edc5 --- /dev/null +++ b/internal/config/listener.go @@ -0,0 +1,112 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + _ "github.com/knadh/koanf/v2" + + "github.com/buglloc/DNSGateway/internal/listener" + "github.com/buglloc/DNSGateway/internal/listener/lrfc2136" + "github.com/buglloc/DNSGateway/internal/upstream" +) + +type ListenerKind string + +const ( + ListenerKindNone ListenerKind = "" + ListenerKindRFC2136 ListenerKind = "rfc2136" +) + +func (k *ListenerKind) UnmarshalText(data []byte) error { + switch strings.ToLower(string(data)) { + case "", "none": + *k = ListenerKindNone + case "rfc2136": + *k = ListenerKindRFC2136 + default: + return fmt.Errorf("invalid listener kind: %s", string(data)) + } + return nil +} + +func (k ListenerKind) MarshalText() ([]byte, error) { + return []byte(k), nil +} + +type Client struct { + Name string `koanf:"name"` + Secret string `koanf:"secret"` + Zones []string `koanf:"zones"` +} + +type RFC2136Listener struct { + Addr string `koanf:"addr"` + Nets []string `koanf:"nets"` + Clients []Client `koanf:"clients"` +} + +type Listener struct { + Kind ListenerKind `koanf:"kind"` + RFC2136 RFC2136Listener `koanf:"rfc2136"` +} + +func (l *RFC2136Listener) Validate() error { + if l.Addr == "" { + return errors.New("addr is empty") + } + + names := make(map[string]struct{}) + for _, cl := range l.Clients { + _, exists := names[cl.Name] + if exists { + return fmt.Errorf("duplicate client name: %s", cl.Name) + } + names[cl.Name] = struct{}{} + + if len(cl.Secret) < 32 { + return fmt.Errorf("invalid client %q secret: too short: 32 chars min", cl.Name) + } + } + return nil +} + +func (r *Runtime) NewListener() (listener.Listener, error) { + u, err := r.NewUpstream() + if err != nil { + return nil, fmt.Errorf("unable to create upstream for listener: %w", err) + } + + switch r.cfg.Listener.Kind { + case ListenerKindRFC2136: + return r.newRFC2136Listener(u, r.cfg.Listener.RFC2136) + default: + return nil, fmt.Errorf("unsupported listener kind: %s", r.cfg.Listener.Kind) + } +} + +func (r *Runtime) newRFC2136Listener(u upstream.Upstream, cfg RFC2136Listener) (*lrfc2136.Listener, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid rfc2136 config: %w", err) + } + + lCfg := lrfc2136.NewConfig(). + Addr(cfg.Addr). + Upstream(u) + + for _, cl := range cfg.Clients { + lCfg.AppendClient(lrfc2136.Client{ + Name: cl.Name, + Secret: cl.Secret, + Zones: cl.Zones, + }) + } + + gw, err := lrfc2136.NewListener(lCfg) + if err != nil { + return nil, fmt.Errorf("create rfc2136 listener: %w", err) + } + + return gw, nil +} diff --git a/internal/config/upstream.go b/internal/config/upstream.go new file mode 100644 index 0000000..b4e2295 --- /dev/null +++ b/internal/config/upstream.go @@ -0,0 +1,79 @@ +package config + +import ( + "errors" + "fmt" + "strings" + + "github.com/buglloc/DNSGateway/internal/upstream" + "github.com/buglloc/DNSGateway/internal/upstream/uadguard" +) + +type UpstreamKind string + +const ( + UpstreamKindNone UpstreamKind = "" + UpstreamKindAdGuard UpstreamKind = "adguard" +) + +func (k *UpstreamKind) UnmarshalText(data []byte) error { + switch strings.ToLower(string(data)) { + case "", "none": + *k = UpstreamKindNone + case "adguard": + *k = UpstreamKindAdGuard + default: + return fmt.Errorf("invalid upstream kind: %s", string(data)) + } + return nil +} + +func (k UpstreamKind) MarshalText() ([]byte, error) { + return []byte(k), nil +} + +type AdguardUpstream struct { + APIServerURL string `koanf:"api_server_url"` + Login string `koanf:"login"` + Password string `koanf:"password"` + AutoPTR bool `koanf:"auto_ptr"` +} + +type Upstream struct { + Kind UpstreamKind `koanf:"kind"` + Adguard AdguardUpstream `koanf:"adguard"` +} + +func (l *AdguardUpstream) Validate() error { + if l.APIServerURL == "" { + return errors.New("addr is empty") + } + + return nil +} + +func (r *Runtime) NewUpstream() (upstream.Upstream, error) { + switch r.cfg.Upstream.Kind { + case UpstreamKindAdGuard: + return r.newAdguardUpstream(r.cfg.Upstream.Adguard) + default: + return nil, fmt.Errorf("unsupported upstream kind: %s", r.cfg.Listener.Kind) + } +} + +func (r *Runtime) newAdguardUpstream(cfg AdguardUpstream) (*uadguard.Upstream, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid adguard config: %w", err) + } + + gw, err := uadguard.NewUpstream( + uadguard.WithUpstream(cfg.APIServerURL), + uadguard.WithBasicAuth(cfg.Login, cfg.Password), + uadguard.WithAutoPTR(cfg.AutoPTR), + ) + if err != nil { + return nil, fmt.Errorf("create adguard upstream: %w", err) + } + + return gw, nil +} diff --git a/internal/listener/listener.go b/internal/listener/listener.go new file mode 100644 index 0000000..4d1d00c --- /dev/null +++ b/internal/listener/listener.go @@ -0,0 +1,8 @@ +package listener + +import "context" + +type Listener interface { + ListenAndServe() error + Shutdown(ctx context.Context) error +} diff --git a/internal/listener/lrfc2136/acl.go b/internal/listener/lrfc2136/acl.go new file mode 100644 index 0000000..2f089ca --- /dev/null +++ b/internal/listener/lrfc2136/acl.go @@ -0,0 +1,55 @@ +package lrfc2136 + +import ( + "fmt" + "strings" +) + +type Client struct { + Name string + Secret string + Zones []string +} + +type ACL struct { + clients map[string][]string +} + +func (a *ACL) IsAllow(tsigName, fqdn string) bool { + fqdn = "." + fqdn + for _, zone := range a.clients[tsigName] { + if strings.HasSuffix(fqdn, zone) { + return true + } + } + + return false +} + +func TsigSecrets(clients ...Client) (map[string]string, error) { + out := make(map[string]string, len(clients)) + for _, c := range clients { + if _, exists := out[c.Name]; exists { + return nil, fmt.Errorf("duplicate client name: %s", c.Name) + } + + out[c.Name] = c.Secret + } + + return out, nil +} + +func TsigACL(clients ...Client) (*ACL, error) { + out := ACL{ + clients: make(map[string][]string, len(clients)), + } + for _, c := range clients { + if _, exists := out.clients[c.Name]; exists { + return nil, fmt.Errorf("duplicate client name: %s", c.Name) + } + + out.clients[c.Name] = c.Zones + } + + return &out, nil +} diff --git a/internal/listener/lrfc2136/config.go b/internal/listener/lrfc2136/config.go new file mode 100644 index 0000000..50be4db --- /dev/null +++ b/internal/listener/lrfc2136/config.go @@ -0,0 +1,70 @@ +package lrfc2136 + +import ( + "errors" + + "github.com/buglloc/DNSGateway/internal/upstream" +) + +type Config struct { + addr string + nets []string + upstream upstream.Upstream + clients []Client +} + +func NewConfig() *Config { + return &Config{ + addr: ":53", + nets: []string{ + "tcp", + "udp", + }, + } +} + +func (c *Config) Addr(addr string) *Config { + c.addr = addr + return c +} + +func (c *Config) Nets(nets ...string) *Config { + c.nets = nets + return c +} + +func (c *Config) Upstream(upstream upstream.Upstream) *Config { + c.upstream = upstream + return c +} + +func (c *Config) Clients(clients ...Client) *Config { + c.clients = clients + return c +} + +func (c *Config) AppendClient(client Client) *Config { + c.clients = append(c.clients, client) + return c +} + +func (c *Config) Validate() error { + var errs []error + if c.addr == "" { + errs = append(errs, errors.New(".Addr is required")) + } + + if len(c.nets) == 0 { + errs = append(errs, errors.New(".Nets is required")) + } + + if c.upstream == nil { + errs = append(errs, errors.New(".Upstream is required")) + } + + if len(c.clients) == 0 { + errs = append(errs, errors.New(".Clients is required")) + } + + return errors.Join(errs...) +} diff --git a/internal/listener/lrfc2136/listener.go b/internal/listener/lrfc2136/listener.go new file mode 100644 index 0000000..28b2354 --- /dev/null +++ b/internal/listener/lrfc2136/listener.go @@ -0,0 +1,262 @@ +package lrfc2136 + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/miekg/dns" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "golang.org/x/sync/errgroup" + + "github.com/buglloc/DNSGateway/internal/upstream" +) + +type Listener struct { + listeners []*dns.Server + upsc upstream.Upstream + acl *ACL + mu sync.Mutex + log zerolog.Logger +} + +func NewListener(cfg *Config) (*Listener, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + tsigSecrets, err := TsigSecrets(cfg.clients...) + if err != nil { + return nil, fmt.Errorf("parse TSIG secrets: %w", err) + } + + tsigACL, err := TsigACL(cfg.clients...) + if err != nil { + return nil, fmt.Errorf("parse TSIG ACLs: %w", err) + } + + logger := log.With(). + Str("source", "rfc2136-listener"). + Logger() + + app := &Listener{ + listeners: make([]*dns.Server, len(cfg.nets)), + upsc: cfg.upstream, + acl: tsigACL, + log: logger, + } + + for i, net := range cfg.nets { + net := net + app.listeners[i] = &dns.Server{ + Addr: cfg.addr, + Net: net, + TsigSecret: tsigSecrets, + NotifyStartedFunc: func() { + logger.Info(). + Str("net", net). + Str("addr", cfg.addr). + Msg("started") + }, + MsgAcceptFunc: dnsMsgAcceptFunc, + Handler: app, + } + } + + return app, nil +} + +func (a *Listener) ListenAndServe() error { + var g errgroup.Group + for _, l := range a.listeners { + l := l + g.Go(func() error { + if err := l.ListenAndServe(); err != nil { + return fmt.Errorf("listener for net %q failed: %w", l.Net, err) + } + + return nil + }) + } + + return g.Wait() +} + +func (a *Listener) Shutdown(ctx context.Context) error { + var errs []error + for _, l := range a.listeners { + if err := l.ShutdownContext(ctx); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) +} + +func (a *Listener) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + l := log.Logger.With(). + Stringer("client", w.RemoteAddr()). + Logger() + + ctx := l.WithContext(context.Background()) + + a.mu.Lock() + defer a.mu.Unlock() + if err := a.lockedServeDNS(ctx, w, r); err != nil { + l.Error().Err(err).Msg("request failed") + + m := new(dns.Msg) + m.SetReply(r) + _ = w.WriteMsg(m) + } +} + +func (a *Listener) lockedServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) error { + now := time.Now() + tsig := r.IsTsig() + if tsig == nil { + return errors.New("missing TSIG") + } + + if err := w.TsigStatus(); err != nil { + return fmt.Errorf("invalid TSIG: %w", err) + } + + m := new(dns.Msg) + m.SetReply(r) + m.SetTsig(tsig.Hdr.Name, dns.HmacSHA256, 300, time.Now().Unix()) + m.Compress = false + + switch r.Opcode { + case dns.OpcodeQuery: + a.handleQuery(ctx, m, r) + + case dns.OpcodeUpdate: + a.handleUpdates(ctx, m, r) + } + + if err := w.WriteMsg(m); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("write failed") + return nil + } + + log.Ctx(ctx).Info().Dur("elapsed", time.Since(now)).Msg("finished") + return nil +} + +func (a *Listener) handleQuery(ctx context.Context, m *dns.Msg, r *dns.Msg) { + log.Ctx(ctx).Info().Msg("handle query") + for _, q := range r.Question { + rules, err := a.upsc.Query(ctx, upstream.Rule{ + Name: q.Name, + Type: q.Qtype, + }) + if err != nil { + continue + } + + for _, rule := range rules { + rr, err := rule.RR() + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("name", rule.Name).Msg("unable to generate rr") + continue + } + + m.Answer = append(m.Answer, rr) + } + } +} + +func (a *Listener) handleUpdates(ctx context.Context, m *dns.Msg, r *dns.Msg) { + log.Ctx(ctx).Info().Msg("handle updates") + + tx, err := a.upsc.Tx(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("create tx") + return + } + + handleUpdate := func(rr dns.RR) error { + header := rr.Header() + name := header.Name + l := log.Ctx(ctx).With().Str("name", name).Logger() + + if _, ok := dns.IsDomainName(name); !ok { + return errors.New("invalid domain name") + } + + if !a.acl.IsAllow(m.IsTsig().Hdr.Name, name) { + return fmt.Errorf("%q is not allowed for client %q", name, m.IsTsig().Hdr.Name) + } + + if header.Class == dns.ClassANY && header.Rdlength == 0 { + err := tx.Delete(upstream.Rule{ + Name: name, + Type: header.Rrtype, + }) + if err != nil { + return fmt.Errorf("delete: %w", err) + } + l.Info().Str("name", name).Msg("deleted") + return nil + } + + rule, err := upstream.RuleFromRR(rr) + if err != nil { + return fmt.Errorf("parse RR: %w", err) + } + + if err := tx.Append(rule); err != nil { + return fmt.Errorf("update: %w", err) + } + + l.Info().Any("rr", rr).Msg("updated") + return nil + } + + if len(r.Question) > 0 { + a.handleQuery(ctx, m, r) + } + + for _, rr := range r.Ns { + if err := handleUpdate(rr); err != nil { + log.Error().Err(err).Any("rr", rr).Msg("update failed") + } + } + + if err := tx.Commit(context.Background()); err != nil { + log.Error().Err(err).Msg("commit failed") + } +} + +func dnsMsgAcceptFunc(dh dns.Header) dns.MsgAcceptAction { + const qrBits = 1 << 15 // query/response (response=1) + if isResponse := dh.Bits&qrBits != 0; isResponse { + return dns.MsgIgnore + } + + // Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs. + opcode := int(dh.Bits>>11) & 0xF + switch opcode { + case dns.OpcodeQuery: + case dns.OpcodeNotify: + case dns.OpcodeUpdate: + default: + return dns.MsgRejectNotImplemented + } + + if dh.Qdcount != 1 { + return dns.MsgReject + } + // NOTIFY requests can have a SOA in the ANSWER section. See RFC 1996 Section 3.7 and 3.11. + if dh.Ancount > 1 { + return dns.MsgReject + } + + if dh.Arcount > 2 { + return dns.MsgReject + } + return dns.MsgAccept +} diff --git a/internal/upstream/uadguard/adgh_types.go b/internal/upstream/uadguard/adgh_types.go new file mode 100644 index 0000000..bcf4902 --- /dev/null +++ b/internal/upstream/uadguard/adgh_types.go @@ -0,0 +1,9 @@ +package uadguard + +type ErrorRsp struct { + Message string `json:"message"` +} + +type FilteringStatusRsp struct { + Rules []string `json:"user_rules"` +} diff --git a/internal/upstream/uadguard/client.go b/internal/upstream/uadguard/client.go new file mode 100644 index 0000000..25783fb --- /dev/null +++ b/internal/upstream/uadguard/client.go @@ -0,0 +1,101 @@ +package uadguard + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/buglloc/DNSGateway/internal/upstream" + "github.com/buglloc/DNSGateway/internal/upstream/uadguard/rules" + "github.com/buglloc/DNSGateway/internal/xhttp" +) + +const ( + DefaultRetries = 3 + DefaultTimeout = 1 * time.Minute + rulesMarkerBegin = "# ---- DNSGateway rules begin ----" + rulesMarkerEnd = "# ---- DNSGateway rules end ----" +) + +var _ upstream.Upstream = (*Upstream)(nil) + +type Upstream struct { + httpc *resty.Client + parser *rules.Parser + log zerolog.Logger + autoPTR bool +} + +func NewUpstream(opts ...Option) (*Upstream, error) { + return NewUpstreamWithHTTP(xhttp.NewHTTPClient(), opts...) +} + +func NewUpstreamWithHTTP(httpc *http.Client, opts ...Option) (*Upstream, error) { + client := &Upstream{ + httpc: resty.NewWithClient(httpc). + SetHeader("User-Agent", "DNSGateway"). + SetHeader("Content-Type", "application/json"). + SetRetryCount(DefaultRetries). + SetTimeout(DefaultTimeout), + log: log.With(). + Str("source", "agh-upstream"). + Logger(), + parser: rules.NewParser(rulesMarkerBegin, rulesMarkerEnd), + } + + for _, opt := range opts { + opt(client) + } + + if client.httpc.BaseURL == "" { + return nil, errors.New("no upstream configured, use WithUpstream()") + } + return client, nil +} + +func (c *Upstream) Query(ctx context.Context, r upstream.Rule) ([]upstream.Rule, error) { + rh, err := c.fetchRules(ctx) + if err != nil { + return nil, err + } + + return rh.Query(r), nil +} + +func (c *Upstream) Tx(ctx context.Context) (upstream.Tx, error) { + rh, err := c.fetchRules(ctx) + if err != nil { + return nil, err + } + + return &Tx{ + httpc: c.httpc, + rules: rh, + autoPTR: c.autoPTR, + }, nil +} + +func (c *Upstream) fetchRules(ctx context.Context) (*rules.Storage, error) { + var rsp FilteringStatusRsp + var errRsp ErrorRsp + httpRsp, err := c.httpc.R(). + SetContext(ctx). + SetResult(&rsp). + SetError(&errRsp). + Get("/control/filtering/status") + if err != nil { + return nil, fmt.Errorf("make http request: %w", err) + } + + if httpRsp.IsError() { + return nil, errors.New(errRsp.Message) + } + + return c.parser.Parse(rsp.Rules) +} diff --git a/internal/upstream/uadguard/client_test.go b/internal/upstream/uadguard/client_test.go new file mode 100644 index 0000000..205b447 --- /dev/null +++ b/internal/upstream/uadguard/client_test.go @@ -0,0 +1,140 @@ +package uadguard_test + +import ( + "context" + "net" + "net/http/httptest" + "sync" + "testing" + + "github.com/labstack/echo/v4" + "github.com/miekg/dns" + "github.com/stretchr/testify/require" + + "github.com/buglloc/DNSGateway/internal/upstream" + "github.com/buglloc/DNSGateway/internal/upstream/uadguard" +) + +func TestSrvMock(t *testing.T) { + adghApp := echo.New() + var rulesMu sync.Mutex + rules := []string{ + "lol", + "kek", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "# ---- DNSGateway rules begin ----", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "|3.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;ya.ru", + "# ---- DNSGateway rules end ----", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "kek", + "lol", + } + + adghApp.GET("/control/filtering/status", func(c echo.Context) error { + rulesMu.Lock() + defer rulesMu.Unlock() + + return c.JSON(200, struct { + Rules []string `json:"user_rules"` + Internal int `json:"interval"` + Enabled bool `json:"enabled"` + }{ + Rules: rules, + Internal: 24, + Enabled: true, + }) + }) + adghApp.POST("/control/filtering/set_rules", func(c echo.Context) error { + rulesMu.Lock() + defer rulesMu.Unlock() + + var req struct { + Rules []string `json:"rules"` + } + if err := c.Bind(&req); err != nil { + return c.String(500, err.Error()) + } + + rules = req.Rules + return c.String(200, "") + }) + + srv := httptest.NewServer(adghApp) + defer srv.Close() + + c, err := uadguard.NewUpstream( + uadguard.WithUpstream(srv.URL), + uadguard.WithAutoPTR(true), + ) + require.NoError(t, err) + + rr, err := c.Query(context.Background(), upstream.Rule{ + Name: "lol", + Type: dns.TypePTR, + }) + require.NoError(t, err) + require.Len(t, rr, 0) + + rr, err = c.Query(context.Background(), upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + }) + require.NoError(t, err) + require.Len(t, rr, 1) + re := upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + Value: net.ParseIP("1.2.3.3"), + ValueStr: "1.2.3.3", + } + require.EqualValues(t, re, rr[0]) + + rr, err = c.Query(context.Background(), upstream.Rule{ + Type: dns.TypePTR, + ValueStr: "ya.ru.", + }) + require.NoError(t, err) + require.Len(t, rr, 1) + re = upstream.Rule{ + Name: "3.3.2.1.in-addr.arpa.", + Type: dns.TypePTR, + Value: "ya.ru.", + ValueStr: "ya.ru.", + } + require.EqualValues(t, re, rr[0]) + + tx, err := c.Tx(context.Background()) + require.NoError(t, err) + + err = tx.Delete(upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + }) + require.NoError(t, err) + + err = tx.Append(upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + Value: net.ParseIP("1.2.4.5"), + }) + require.NoError(t, err) + + err = tx.Commit(context.Background()) + require.NoError(t, err) + + expected := []string{ + "lol", + "kek", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "# ---- DNSGateway rules begin ----", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.4.5", + "|5.4.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;ya.ru", + "# ---- DNSGateway rules end ----", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "kek", + "lol", + } + + require.EqualValues(t, expected, rules) +} diff --git a/internal/upstream/uadguard/opts.go b/internal/upstream/uadguard/opts.go new file mode 100644 index 0000000..0263936 --- /dev/null +++ b/internal/upstream/uadguard/opts.go @@ -0,0 +1,31 @@ +package uadguard + +type Option func(*Upstream) + +func WithUpstream(upstream string) Option { + return func(client *Upstream) { + if upstream == "" { + return + } + + client.httpc.SetBaseURL(upstream) + } +} + +func WithBasicAuth(login, password string) Option { + return func(client *Upstream) { + client.httpc.SetBasicAuth(login, password) + } +} + +func WithAutoPTR(enabled bool) Option { + return func(client *Upstream) { + client.autoPTR = enabled + } +} + +func WithDebug(verbose bool) Option { + return func(client *Upstream) { + client.httpc.SetDebug(verbose) + } +} diff --git a/internal/upstream/uadguard/rules/parser.go b/internal/upstream/uadguard/rules/parser.go new file mode 100644 index 0000000..64ba6c0 --- /dev/null +++ b/internal/upstream/uadguard/rules/parser.go @@ -0,0 +1,222 @@ +package rules + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + + "github.com/miekg/dns" +) + +type Parser struct { + beginMarker string + endMarker string +} + +func NewParser(beginMarker, endMarker string) *Parser { + return &Parser{ + beginMarker: beginMarker, + endMarker: endMarker, + } +} + +func (p *Parser) Parse(in []string) (*Storage, error) { + beforeLen := -1 + afterRulesIdx := -1 + ourRules := make([]Rule, 0) + inOurRules := false + for i, r := range in { + if !inOurRules { + if r == p.beginMarker { + inOurRules = true + beforeLen = i + 1 + } + continue + } + + if r == p.endMarker { + afterRulesIdx = i + break + } + + rule, err := p.ParseRule([]byte(r)) + if err != nil { + return nil, fmt.Errorf("invalid rule %q: %w", r, err) + } + + ourRules = append(ourRules, rule) + } + + if beforeLen == -1 { + beforeLen = len(in) + } + + if afterRulesIdx == -1 { + afterRulesIdx = len(in) + } + + before := make([]string, beforeLen) + copy(before, in[:beforeLen]) + if len(before) == 0 || before[len(before)-1] != p.beginMarker { + before = append(before, p.beginMarker) + } + + after := make([]string, len(in)-afterRulesIdx) + copy(after, in[afterRulesIdx:]) + if len(after) == 0 || after[0] != p.endMarker { + after = append([]string{p.endMarker}, after...) + } + + return &Storage{ + before: before, + after: after, + rules: ourRules, + }, nil +} + +// ParseRule parses AdBlock rule with dnsrewrite option +// AdBlock syntax: https://github.com/AdguardTeam/AdGuardHome/wiki/Hosts-Blocklists#adblock-style +// examples: +// +// |4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;example.net. +// |2.0.0.0.0.0.0.0.4.f.7.0.0.0.0.0.0.0.4.3.0.0.0.0.8.b.6.0.2.0.a.2.ip6.arpa^$dnsrewrite=NOERROR;PTR;example.net. +// |ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3 +// |ya.ru^$dnsrewrite=NOERROR;AAAA;::1 +// +// Limitation: +// - only dnsrewrite +// - only strict match +func (p *Parser) ParseRule(in []byte) (Rule, error) { + in = bytes.TrimSpace(in) + if len(in) == 0 { + return Rule{}, io.EOF + } + + if in[0] != '|' { + return Rule{}, errors.New("must starts with '|'") + } + in = in[1:] + + idx := bytes.IndexByte(in, '^') + if idx == -1 { + return Rule{}, errors.New("expected '^' as end of name") + } + + nameBytes := in[:idx] + if nameBytes[0] == '|' { + nameBytes = append([]byte{'*', '.'}, nameBytes[1:]...) + } + name := fqdn(string(nameBytes)) + + in = in[idx+1:] + idx = bytes.IndexByte(in, ';') + if idx == -1 { + return Rule{}, errors.New("expected ';' as end of RRCode") + } + if !bytes.Equal(in[:idx], []byte("$dnsrewrite=NOERROR")) { + return Rule{}, fmt.Errorf("invalid dnsrewrite option: %s", string(in[:idx])) + } + in = in[idx+1:] + + idx = bytes.IndexByte(in, ';') + if idx == -1 { + return Rule{}, errors.New("expected ';' as end of RRType") + } + rrType, err := strToRRType(string(in[:idx])) + if err != nil { + return Rule{}, fmt.Errorf("invalid RRType: %w", err) + } + in = in[idx+1:] + + if idx = indexComment(in); idx != -1 { + in = in[:idx] + } + + return newRule(name, rrType, string(in)) +} + +func strToRRType(s string) (rr uint16, err error) { + // TypeNone and TypeReserved are special cases in package dns. + if strings.EqualFold(s, "none") || strings.EqualFold(s, "reserved") { + return 0, errors.New("dns rr type is none or reserved") + } + + typ, ok := dns.StringToType[strings.ToUpper(s)] + if !ok { + return 0, fmt.Errorf("dns rr type %q is invalid", s) + } + + return typ, nil +} + +func indexComment(in []byte) int { + for i, b := range in { + switch b { + case '!', '#': + return i + } + } + + return -1 +} + +func unFqdn(s string) string { + return strings.TrimSuffix(s, ".") +} + +func fqdn(s string) string { + if isFqdn(s) { + return s + } + return s + "." +} + +func isFqdn(s string) bool { + return len(s) > 1 && s[len(s)-1] == '.' +} + +func validateHostname(fqdn string) (err error) { + l := len(fqdn) + if l == 0 { + return fmt.Errorf("invalid hostname length: %d", l) + } + + parts := strings.Split(fqdn, ".") + lastPart := len(parts) - 1 + for i, p := range parts { + if len(p) == 0 { + if i == lastPart { + break + } + + return fmt.Errorf("empty hostname part at index %d", i) + } + + if r := p[0]; !isValidHostFirstRune(rune(r)) { + return fmt.Errorf("invalid hostname part at index %d: invalid char %q at index %d", i, r, 0) + } + + for j, r := range p[1:] { + if !isValidHostRune(r) { + return fmt.Errorf("invalid hostname part at index %d: invalid char %q at index %d", i, r, j+1) + } + } + } + + return nil +} + +// isValidHostRune returns true if r is a valid rune for a hostname part. +func isValidHostRune(r rune) (ok bool) { + return r == '-' || isValidHostFirstRune(r) +} + +// isValidHostFirstRune returns true if r is a valid first rune for a hostname +// part. +func isValidHostFirstRune(r rune) (ok bool) { + return (r >= 'a' && r <= 'z') || + (r >= 'A' && r <= 'Z') || + (r >= '0' && r <= '9') +} diff --git a/internal/upstream/uadguard/rules/parser_test.go b/internal/upstream/uadguard/rules/parser_test.go new file mode 100644 index 0000000..1c0ece2 --- /dev/null +++ b/internal/upstream/uadguard/rules/parser_test.go @@ -0,0 +1,221 @@ +package rules + +import ( + "net" + "strings" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/require" + + "github.com/buglloc/DNSGateway/internal/upstream" +) + +func TestParse(t *testing.T) { + cases := []struct { + in []string + out *Storage + }{ + { + in: []string{ + "lol", + "kek", + }, + out: &Storage{ + before: []string{ + "lol", + "kek", + "--b--", + }, + rules: []Rule{}, + after: []string{ + "--e--", + }, + }, + }, + { + in: []string{ + "lol", + "kek", + "--b--", + }, + out: &Storage{ + before: []string{ + "lol", + "kek", + "--b--", + }, + rules: []Rule{}, + after: []string{ + "--e--", + }, + }, + }, + { + in: []string{ + "lol", + "kek", + "--b--", + "--e--", + }, + out: &Storage{ + before: []string{ + "lol", + "kek", + "--b--", + }, + rules: []Rule{}, + after: []string{ + "--e--", + }, + }, + }, + { + in: []string{ + "lol", + "kek", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "--b--", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "--e--", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "kek", + "lol", + }, + out: &Storage{ + before: []string{ + "lol", + "kek", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "--b--", + }, + rules: []Rule{ + { + Rule: &upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + Value: net.ParseIP("1.2.3.3"), + ValueStr: "1.2.3.3", + }, + }, + }, + after: []string{ + "--e--", + "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + "kek", + "lol", + }, + }, + }, + } + + p := NewParser("--b--", "--e--") + for _, tc := range cases { + t.Run(strings.Join(tc.in, "%%"), func(t *testing.T) { + s, err := p.Parse(tc.in) + require.NoError(t, err) + + require.EqualValues(t, tc.out, s) + }) + } +} + +func TestParseRule(t *testing.T) { + cases := []struct { + in string + out Rule + err bool + }{ + { + in: "|4.3.2.1.in-addr.arpa^$dnsrewrite=NOERROR;PTR;example.net.", + out: Rule{ + Rule: &upstream.Rule{ + Name: "4.3.2.1.in-addr.arpa.", + Type: dns.TypePTR, + Value: "example.net.", + ValueStr: "example.net.", + }, + }, + }, + { + in: "|2.0.0.0.0.0.0.0.4.f.7.0.0.0.0.0.0.0.4.3.0.0.0.0.8.b.6.0.2.0.a.2.ip6.arpa^$dnsrewrite=NOERROR;PTR;example.net", + out: Rule{ + Rule: &upstream.Rule{ + Name: "2.0.0.0.0.0.0.0.4.f.7.0.0.0.0.0.0.0.4.3.0.0.0.0.8.b.6.0.2.0.a.2.ip6.arpa.", + Type: dns.TypePTR, + Value: "example.net.", + ValueStr: "example.net.", + }, + }, + }, + { + in: "|ya.ru^$dnsrewrite=NOERROR;A;1.2.3.3", + out: Rule{ + Rule: &upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeA, + Value: net.ParseIP("1.2.3.3"), + ValueStr: "1.2.3.3", + }, + }, + }, + { + in: "|ya.ru^$dnsrewrite=NOERROR;AAAA;::1", + out: Rule{ + Rule: &upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeAAAA, + Value: net.ParseIP("::1"), + ValueStr: "::1", + }, + }, + }, + { + in: "|ya.ru^$dnsrewrite=NOERROR;CNAME;google.com", + out: Rule{ + Rule: &upstream.Rule{ + Name: "ya.ru.", + Type: dns.TypeCNAME, + Value: "google.com.", + ValueStr: "google.com.", + }, + }, + }, + { + in: "||ya.ru^$dnsrewrite=NOERROR;CNAME;google.com", + out: Rule{ + Rule: &upstream.Rule{ + Name: "*.ya.ru.", + Type: dns.TypeCNAME, + Value: "google.com.", + ValueStr: "google.com.", + }, + }, + }, + { + in: "|ya.ru^$dnsrewrite=REFUSED;;", + err: true, + }, + { + in: "|canon.example.com^$dnstype=~CNAME", + err: true, + }, + } + + p := NewParser("b", "e") + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + actual, err := p.ParseRule([]byte(tc.in)) + if tc.err { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.EqualExportedValues(t, tc.out, actual) + + ruleStr := strings.TrimSuffix(tc.in, ".") + require.Equal(t, ruleStr, actual.Format()) + }) + } +} diff --git a/internal/upstream/uadguard/rules/rule.go b/internal/upstream/uadguard/rules/rule.go new file mode 100644 index 0000000..6a3df66 --- /dev/null +++ b/internal/upstream/uadguard/rules/rule.go @@ -0,0 +1,246 @@ +package rules + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/miekg/dns" + + "github.com/buglloc/DNSGateway/internal/upstream" +) + +type Rule struct { + *upstream.Rule +} + +func (r *Rule) Same(other Rule) bool { + return r.SameUpstreamRule(other.Rule) +} + +func (r *Rule) SameUpstreamRule(other *upstream.Rule) bool { + if other.Type != 0 && other.Type != r.Type { + return false + } + + if other.Name != "" && other.Name != r.Name { + return false + } + + if other.ValueStr != "" && other.ValueStr != r.ValueStr { + return false + } + + return true +} + +func (r *Rule) Format() string { + value := fmt.Sprint(r.Value) + name := unFqdn(r.Name) + if strictName := strings.TrimPrefix(name, "*."); strictName != name { + name = "|" + strictName + } + + return fmt.Sprintf( + "|%s^$dnsrewrite=NOERROR;%s;%s", + name, dns.TypeToString[r.Type], unFqdn(value), + ) +} + +func newRule(name string, rrType uint16, valStr string) (Rule, error) { + var uRule *upstream.Rule + var err error + switch rrType { + case dns.TypeA: + uRule, err = newRuleA(name, valStr) + + case dns.TypeAAAA: + uRule, err = newRuleAAAA(name, valStr) + + case dns.TypeCNAME: + uRule, err = newRuleCNAME(name, valStr) + + case dns.TypeMX: + uRule, err = newRuleMX(name, valStr) + + case dns.TypePTR: + uRule, err = newRulePTR(name, valStr) + + case dns.TypeTXT: + uRule, err = newRuleTXT(name, valStr) + + case dns.TypeSRV: + uRule, err = newRuleSRV(name, valStr) + + default: + return Rule{}, fmt.Errorf("unsupported rrType %d: %s", rrType, dns.TypeToString[rrType]) + } + + if err != nil { + return Rule{}, fmt.Errorf("invalid rule %d[%s]: %w", rrType, dns.TypeToString[rrType], err) + } + + return Rule{ + Rule: uRule, + }, nil +} + +func newRuleA(name string, valStr string) (*upstream.Rule, error) { + ip := parseIP(valStr) + if ip == nil { + return nil, fmt.Errorf("invalid ipv4: %q", valStr) + } + + if ip4 := ip.To4(); ip4 == nil { + return nil, fmt.Errorf("invalid ipv4: %q", valStr) + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypeA, + Value: ip, + ValueStr: valStr, + }, nil +} + +func newRuleAAAA(name string, valStr string) (*upstream.Rule, error) { + ip := parseIP(valStr) + if ip == nil { + return nil, fmt.Errorf("invalid ipv6: %q", valStr) + } else if ip4 := ip.To4(); ip4 != nil { + return nil, fmt.Errorf("want ipv6, got ipv4: %q", valStr) + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypeAAAA, + Value: ip, + ValueStr: valStr, + }, nil +} + +func newRuleCNAME(name string, valStr string) (*upstream.Rule, error) { + domain := fqdn(valStr) + if err := validateHostname(domain); err != nil { + return nil, fmt.Errorf("invalid new domain %q: %w", valStr, err) + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypeCNAME, + Value: domain, + ValueStr: domain, + }, nil +} + +func newRuleMX(name string, valStr string) (*upstream.Rule, error) { + parts := strings.SplitN(valStr, " ", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid mx: %q", valStr) + } + + pref64, err := strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid mx preference: %w", err) + } + + exch := parts[1] + if err := validateHostname(exch); err != nil { + return nil, fmt.Errorf("invalid mx exchange %q: %w", exch, err) + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypeMX, + Value: &dns.MX{ + Preference: uint16(pref64), + Mx: exch, + }, + ValueStr: valStr, + }, nil +} + +func newRulePTR(name string, valStr string) (*upstream.Rule, error) { + domain := fqdn(valStr) + if err := validateHostname(domain); err != nil { + return nil, fmt.Errorf("invalid ptr host %q: %w", valStr, err) + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypePTR, + Value: domain, + ValueStr: domain, + }, nil +} + +func newRuleTXT(name string, valStr string) (*upstream.Rule, error) { + return &upstream.Rule{ + Name: name, + Type: dns.TypeTXT, + Value: valStr, + ValueStr: valStr, + }, nil +} + +func newRuleSRV(name string, valStr string) (*upstream.Rule, error) { + fields := strings.Split(valStr, " ") + if len(fields) < 4 { + return nil, fmt.Errorf("invalid srv %q: need four fields", valStr) + } + + prio64, err := strconv.ParseUint(fields[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid srv priority: %w", err) + } + + weight64, err := strconv.ParseUint(fields[1], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid srv weight: %w", err) + } + + port64, err := strconv.ParseUint(fields[2], 10, 16) + if err != nil { + return nil, fmt.Errorf("invalid srv port: %w", err) + } + + target := fields[3] + + // From RFC 2782: + // + // A Target of "." means that the service is decidedly not available + // at this domain. + // + if target != "." { + if err := validateHostname(target); err != nil { + return nil, fmt.Errorf("invalid srv target %q: %w", target, err) + } + } + + return &upstream.Rule{ + Name: name, + Type: dns.TypeSRV, + Value: &dns.SRV{ + Priority: uint16(prio64), + Weight: uint16(weight64), + Port: uint16(port64), + Target: target, + }, + ValueStr: valStr, + }, nil +} + +func parseIP(in string) net.IP { + for _, c := range in { + if c != '.' && c != ':' && + (c < '0' || c > '9') && + (c < 'A' || c > 'F') && + (c < 'a' || c > 'f') && + c != '[' && c != ']' { + return nil + } + } + + return net.ParseIP(in) +} diff --git a/internal/upstream/uadguard/rules/storage.go b/internal/upstream/uadguard/rules/storage.go new file mode 100644 index 0000000..d0afbd6 --- /dev/null +++ b/internal/upstream/uadguard/rules/storage.go @@ -0,0 +1,62 @@ +package rules + +import ( + "github.com/buglloc/DNSGateway/internal/upstream" +) + +type Storage struct { + before []string + after []string + rules []Rule +} + +func (s *Storage) Query(q upstream.Rule) []upstream.Rule { + var out []upstream.Rule + for _, rule := range s.rules { + if !rule.SameUpstreamRule(&q) { + continue + } + + out = append(out, *rule.Rule) + } + + return out +} + +func (s *Storage) Delete(q upstream.Rule) []upstream.Rule { + n := 0 + var deleted []upstream.Rule + for _, rule := range s.rules { + if rule.SameUpstreamRule(&q) { + deleted = append(deleted, *rule.Rule) + continue + } + + s.rules[n] = rule + n++ + } + + s.rules = s.rules[:n] + return deleted +} + +func (s *Storage) Append(r upstream.Rule) { + s.rules = append(s.rules, Rule{ + Rule: &r, + }) +} + +func (s *Storage) Dump() []string { + out := make([]string, len(s.before)) + copy(out, s.before) + + for _, rule := range s.rules { + out = append(out, rule.Format()) + } + + for _, rule := range s.after { + out = append(out, rule) + } + + return out +} diff --git a/internal/upstream/uadguard/tx.go b/internal/upstream/uadguard/tx.go new file mode 100644 index 0000000..b74a8ae --- /dev/null +++ b/internal/upstream/uadguard/tx.go @@ -0,0 +1,106 @@ +package uadguard + +import ( + "context" + "fmt" + + "github.com/go-resty/resty/v2" + "github.com/miekg/dns" + + "github.com/buglloc/DNSGateway/internal/upstream" + "github.com/buglloc/DNSGateway/internal/upstream/uadguard/rules" +) + +var _ upstream.Tx = (*Tx)(nil) + +type Tx struct { + httpc *resty.Client + rules *rules.Storage + autoPTR bool + changed bool +} + +func (t *Tx) Delete(r upstream.Rule) error { + deleted := t.rules.Delete(r) + if len(deleted) == 0 { + return nil + } + + t.changed = true + if !t.needPTR(r.Type) { + return nil + } + + _ = t.rules.Delete(upstream.Rule{ + Name: r.ValueStr, + Type: dns.TypePTR, + }) + + return nil +} + +func (t *Tx) Append(r upstream.Rule) error { + if r.ValueStr == "" { + // TODO(buglloc): fix me + r.ValueStr = fmt.Sprint(r.Value) + } + + t.changed = true + t.rules.Append(r) + if !t.needPTR(r.Type) { + return nil + } + + arpa, err := dns.ReverseAddr(r.ValueStr) + if err != nil { + return fmt.Errorf("unable to generate arpa address: %w", err) + } + + _ = t.rules.Delete(upstream.Rule{ + Name: arpa, + Type: dns.TypePTR, + }) + + t.rules.Append(upstream.Rule{ + Name: arpa, + Type: dns.TypePTR, + Value: r.Name, + ValueStr: r.Name, + }) + return nil +} + +func (t *Tx) Commit(ctx context.Context) error { + if !t.changed { + return nil + } + + httpRsp, err := t.httpc.R(). + SetContext(ctx). + SetHeader("Content-Type", "application/json"). + SetBody(struct { + Rules []string `json:"rules"` + }{ + Rules: t.rules.Dump(), + }). + Post("/control/filtering/set_rules") + if err != nil { + return fmt.Errorf("make http request: %w", err) + } + + if httpRsp.IsError() { + return fmt.Errorf("non-200 response: %s", string(httpRsp.Body())) + } + + return nil +} + +func (t *Tx) Close() {} + +func (t *Tx) needPTR(rrType uint16) bool { + if !t.autoPTR { + return false + } + + return rrType == dns.TypeA || rrType == dns.TypeAAAA +} diff --git a/internal/upstream/upstream.go b/internal/upstream/upstream.go new file mode 100644 index 0000000..a63a3a8 --- /dev/null +++ b/internal/upstream/upstream.go @@ -0,0 +1,128 @@ +package upstream + +import ( + "context" + "fmt" + "net" + + "github.com/miekg/dns" +) + +type Upstream interface { + Tx(ctx context.Context) (Tx, error) + Query(ctx context.Context, q Rule) ([]Rule, error) +} + +type Tx interface { + Delete(r Rule) error + Append(r Rule) error + Commit(ctx context.Context) error + Close() +} + +type RValue any + +type RType = uint16 + +type Rule struct { + Name string + Type RType + Value RValue + ValueStr string +} + +func RuleFromRR(rr dns.RR) (Rule, error) { + var value any + switch v := rr.(type) { + case *dns.A: + value = v.A + + case *dns.AAAA: + value = v.AAAA + + case *dns.CNAME: + value = v.Target + + case *dns.MX: + value = v + + case *dns.PTR: + value = v.Ptr + + case *dns.TXT: + value = v.Txt + + case *dns.SRV: + value = v + + default: + return Rule{}, fmt.Errorf("unsupported rr: %s", rr) + } + + return Rule{ + Name: rr.Header().Name, + Type: rr.Header().Rrtype, + Value: value, + ValueStr: fmt.Sprint(value), + }, nil +} + +func (r *Rule) RR() (dns.RR, error) { + hdr := dns.RR_Header{ + Name: r.Name, + Rrtype: r.Type, + } + + switch r.Type { + case dns.TypeA: + return &dns.A{ + Hdr: hdr, + A: r.Value.(net.IP), + }, nil + + case dns.TypeAAAA: + return &dns.AAAA{ + Hdr: hdr, + AAAA: r.Value.(net.IP), + }, nil + + case dns.TypeCNAME: + return &dns.CNAME{ + Hdr: hdr, + Target: r.Value.(string), + }, nil + + case dns.TypeMX: + mx := r.Value.(*dns.MX) + return &dns.MX{ + Hdr: hdr, + Preference: mx.Preference, + Mx: mx.Mx, + }, nil + + case dns.TypePTR: + return &dns.PTR{ + Hdr: hdr, + Ptr: r.Value.(string), + }, nil + + case dns.TypeTXT: + return &dns.TXT{ + Hdr: hdr, + Txt: []string{r.Value.(string)}, + }, nil + + case dns.TypeSRV: + srv := r.Value.(*dns.SRV) + return &dns.SRV{ + Hdr: hdr, + Priority: srv.Priority, + Weight: srv.Weight, + Port: srv.Port, + Target: srv.Target, + }, nil + + default: + return nil, fmt.Errorf("unsupported rrType %d: %s", r.Type, dns.TypeToString[r.Type]) + } +} diff --git a/internal/xhttp/xhttp.go b/internal/xhttp/xhttp.go new file mode 100644 index 0000000..c23e5b9 --- /dev/null +++ b/internal/xhttp/xhttp.go @@ -0,0 +1,42 @@ +package xhttp + +import ( + "crypto/tls" + "net" + "net/http" + "time" + + "github.com/buglloc/certifi" +) + +const ( + dialTimeout = 1 * time.Second + requestTimeout = 5 * time.Second + keepAlive = 60 * time.Second +) + +func NewHTTPClient() *http.Client { + return &http.Client{ + Transport: NewTransport(), + Timeout: requestTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} + +func NewTransport() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = NewTLSClientConfig() + transport.DialContext = (&net.Dialer{ + Timeout: dialTimeout, + KeepAlive: keepAlive, + }).DialContext + return transport +} + +func NewTLSClientConfig() *tls.Config { + return &tls.Config{ + RootCAs: certifi.NewCertPool(), + } +}