Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(DoH): implement DNS-over-HTTPS, fixes #2 #24

Merged
merged 15 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package main

import (
cTls "crypto/tls"
"errors"
"fmt"
"github.com/cottand/grimd/tls"
"github.com/jonboulle/clockwork"
"github.com/pelletier/go-toml/v2"
"log"
"os"
"strings"
)

// BuildVersion returns the build version of grimd, this should be incremented every new release
var BuildVersion = "2.2.1"
var BuildVersion = "1.3.0"

// ConfigVersion returns the version of grimd, this should be incremented every time the config changes so grimd presents a warning
var ConfigVersion = "2.2.1"
var ConfigVersion = "1.3.0"

// Config holds the configuration parameters
type Config struct {
Expand Down Expand Up @@ -40,13 +44,34 @@
APIDebug bool
DoH string
Metrics Metrics `toml:"metrics"`
DnsOverHttpServer DnsOverHttpServer
}

type Metrics struct {
Enabled bool
Path string
}

type DnsOverHttpServer struct {
Enabled bool
Bind string
TimeoutMs int64
TLS TlsConfig
parsedTls *cTls.Config
}

type TlsConfig struct {
certPath, keyPath, caPath string
enabled bool
}

func (c TlsConfig) parsedConfig() (*cTls.Config, error) {
if !c.enabled {
return nil, nil
}
return tls.NewTLSConfig(c.certPath, c.keyPath, c.caPath)
}

var defaultConfig = `
# version this config was generated from
version = "%s"
Expand Down Expand Up @@ -133,21 +158,34 @@
# having been turned off.
reactivationdelay = 300

#Dns over HTTPS provider to use.
# Dns over HTTPS upstream provider to use
DoH = "https://cloudflare-dns.com/dns-query"

# Prometheus metrics - enable
# Prometheus metrics - disabled by default
[Metrics]
enabled = false
path = "/metrics"

[DnsOverHttpServer]
enabled = false
bind = "0.0.0.0:80"
timeoutMs = 5000

# TLS config is not required for DoH if you have some proxy (ie, caddy, nginx, traefik...) manage HTTPS for you
[DnsOverHttpServer.TLS]
enabled = false
certPath = ""
keyPath = ""
# if empty, system CAs will be used
caPath = ""
`

func parseDefaultConfig() Config {
var config Config

err := toml.Unmarshal([]byte(defaultConfig), &config)
if err != nil {
logger.Fatalf("There was an error parsing the default config %v", err)
logger.Fatalf("There was an error parsing the default config: %v", err)
}
config.Version = ConfigVersion

Expand All @@ -157,6 +195,19 @@
// WallClock is the wall clock
var WallClock = clockwork.NewRealClock()

func contextualisedParsingErrorFrom(err error) error {
errString := strings.Builder{}
var derr *toml.DecodeError
_, _ = fmt.Fprint(&errString, "could not load config:", err)
if errors.As(err, &derr) {
errString.WriteByte('\n')
_, _ = fmt.Fprintln(&errString, derr.String())
row, col := derr.Position()
_, _ = fmt.Fprintln(&errString, "error occurred at row", row, "column", col)
}
return errors.New(errString.String())
}

// LoadConfig loads the given config file
func LoadConfig(path string) (*Config, error) {

Expand All @@ -167,9 +218,27 @@
return &config, nil
}

if err := toml.Unmarshal([]byte(path), &config); err != nil {
return nil, fmt.Errorf("could not load config: %s", err)
file, err := os.Open(path)
Fixed Show fixed Hide fixed
if _, err := os.Stat(path); os.IsNotExist(err) {
log.Printf("warning, failed to open config - using defaults")
return &config, nil
}

defer func(file *os.File) {
_ = file.Close()
}(file)

d := toml.NewDecoder(file)

if err := d.Decode(&config); err != nil {
return nil, contextualisedParsingErrorFrom(err)
}

dohTls, err := config.DnsOverHttpServer.TLS.parsedConfig()
if err != nil {
return nil, fmt.Errorf("could not load TLS config: %s", err)
}
config.DnsOverHttpServer.parsedTls = dohTls

if config.Version != ConfigVersion {
if config.Version == "" {
Expand Down
20 changes: 20 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/assert"
"strings"
"testing"
)

Expand Down Expand Up @@ -35,3 +36,22 @@ path = "/voo"
}
assert.Equal(t, true, config.Metrics.Enabled, "expected overridden value for config.bind")
}

func TestFriendlyErrors(t *testing.T) {
config := parseDefaultConfig()

badConfig := `
[metrics]
enabled = 3
`

err := toml.Unmarshal([]byte(badConfig), &config)
if err == nil {
t.Fatalf("expected an error!")
}
err = contextualisedParsingErrorFrom(err)

if !(strings.Contains(err.Error(), "enabled = 3") && strings.Contains(err.Error(), "row 3 column 11")) {
t.Fatalf("expected error string to contain contextual info, but was %v", err.Error())
}
}
Loading
Loading