Skip to content

Commit

Permalink
Merge pull request #75 from SenseUnit/imp/cert_auth
Browse files Browse the repository at this point in the history
Improved certificate authentication
  • Loading branch information
Snawoot authored Oct 27, 2024
2 parents 3f04887 + 517abaa commit cbf4bdc
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 15 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ Authentication parameters are passed as URI via `-auth` parameter. Scheme of URI
* `hidden_domain` - same as in `static` provider
* `reload` - interval for conditional password file reload, if it was modified since last load. Use negative duration to disable autoreload. Default: `15s`.
* `cert` - use mutual TLS authentication with client certificates. In order to use this auth provider server must listen sockert in TLS mode (`-cert` and `-key` options) and client CA file must be specified (`-cacert`). Example: `cert://`.
* `blacklist` - location of file with list of serial numbers of blocked certificates, one per each line in form of hex-encoded colon-separated bytes. Example: `ab:01:02:03`. Empty lines and comments starting with `#` are ignored.
* `reload` - interval for certificate blacklist file reload, if it was modified since last load. Use negative duration to disable autoreload. Default: `15s`.

## Synopsis

Expand Down
2 changes: 1 addition & 1 deletion auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) {
case "basicfile":
return NewBasicFileAuth(url, logger)
case "cert":
return CertAuth{}, nil
return NewCertAuth(url, logger)
case "none":
return NoAuth{}, nil
default:
Expand Down
18 changes: 10 additions & 8 deletions auth/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ func NewBasicFileAuth(param_url *url.URL, logger *clog.CondLogger) (*BasicAuth,
return nil, fmt.Errorf("unable to load initial password list: %w", err)
}

reloadIntervalOption := values.Get("reload")
reloadInterval, err := time.ParseDuration(reloadIntervalOption)
if err != nil {
reloadInterval = 0
}
if reloadInterval == 0 {
reloadInterval = 15 * time.Second
reloadInterval := 15 * time.Second
if reloadIntervalOption := values.Get("reload"); reloadIntervalOption != "" {
parsedInterval, err := time.ParseDuration(reloadIntervalOption)
if err != nil {
logger.Warning("unable to parse reload interval: %v. using default value.", err)
}
reloadInterval = parsedInterval
}
if reloadInterval > 0 {
go auth.reloadLoop(reloadInterval)
Expand Down Expand Up @@ -108,7 +108,9 @@ func (auth *BasicAuth) reloadLoop(interval time.Duration) {
case <-auth.stopChan:
return
case <-ticker.C:
auth.reload()
if err := auth.reload(); err != nil {
auth.logger.Error("reload failed: %v", err)
}
}
}
}
Expand Down
212 changes: 206 additions & 6 deletions auth/cert.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,216 @@
package auth

import "net/http"
import (
"bufio"
"bytes"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"

type CertAuth struct{}
clog "github.com/SenseUnit/dumbproxy/log"
)

func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
type serialNumberSetFile struct {
file *serialNumberSet
modTime time.Time
}

type CertAuth struct {
blacklist atomic.Pointer[serialNumberSetFile]
blacklistFilename string
logger *clog.CondLogger
stopOnce sync.Once
stopChan chan struct{}
}

func NewCertAuth(param_url *url.URL, logger *clog.CondLogger) (*CertAuth, error) {
values, err := url.ParseQuery(param_url.RawQuery)
if err != nil {
return nil, err
}

auth := &CertAuth{
blacklistFilename: values.Get("blacklist"),
logger: logger,
stopChan: make(chan struct{}),
}
auth.blacklist.Store(new(serialNumberSetFile))

if auth.blacklistFilename != "" {
if err := auth.reload(); err != nil {
return nil, fmt.Errorf("unable to load initial certificate blacklist: %w", err)
}
}

reloadInterval := 15 * time.Second
if reloadIntervalOption := values.Get("reload"); reloadIntervalOption != "" {
parsedInterval, err := time.ParseDuration(reloadIntervalOption)
if err != nil {
logger.Warning("unable to parse reload interval: %v. using default value.", err)
}
reloadInterval = parsedInterval
}
if reloadInterval > 0 {
go auth.reloadLoop(reloadInterval)
}

return auth, nil
}

func (auth *CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) {
if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return "", false
} else {
return req.TLS.VerifiedChains[0][0].Subject.String(), true
}
eeCert := req.TLS.VerifiedChains[0][0]
if auth.blacklist.Load().file.Has(eeCert.SerialNumber) {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return "", false
}
return fmt.Sprintf(
"Subject: %s, Serial Number: %s",
eeCert.Subject.String(),
formatSerial(eeCert.SerialNumber),
), true
}

func (auth *CertAuth) Stop() {
auth.stopOnce.Do(func() {
close(auth.stopChan)
})
}

func (auth *CertAuth) reload() error {
var oldModTime time.Time
if oldBL := auth.blacklist.Load(); oldBL != nil {
oldModTime = oldBL.modTime
}

f, modTime, err := openIfModified(auth.blacklistFilename, oldModTime)
if err != nil {
return err
}
if f == nil {
// no changes since last modTime
return nil
}

auth.logger.Info("reloading certificate blacklist from %q...", auth.blacklistFilename)
newBlacklistSet, err := newSerialNumberSetFromReader(f)
if err != nil {
return err
}

newBlacklist := &serialNumberSetFile{
file: newBlacklistSet,
modTime: modTime,
}
auth.blacklist.Store(newBlacklist)
auth.logger.Info("blacklist file reloaded.")

return nil
}

func (auth *CertAuth) reloadLoop(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-auth.stopChan:
return
case <-ticker.C:
if err := auth.reload(); err != nil {
auth.logger.Error("reload failed: %v", err)
}
}
}
}

// formatSerial from https://codereview.stackexchange.com/a/165708
func formatSerial(serial *big.Int) string {
b := serial.Bytes()
buf := make([]byte, 0, 3*len(b))
x := buf[1*len(b) : 3*len(b)]
hex.Encode(x, b)
for i := 0; i < len(x); i += 2 {
buf = append(buf, x[i], x[i+1], ':')
}
if serial.Sign() == -1 {
return "(Negative)" + string(buf[:len(buf)-1])
}
return string(buf[:len(buf)-1])
}

type serialNumberKey = [20]byte
type serialNumberSet struct {
sns map[serialNumberKey]struct{}
}

func normalizeSNBytes(b []byte) serialNumberKey {
var k serialNumberKey
copy(
k[max(len(k)-len(b), 0):],
b[max(len(b)-len(k), 0):],
)
return k
}

func (_ CertAuth) Stop() {}
func (s *serialNumberSet) Has(serial *big.Int) bool {
key := normalizeSNBytes(serial.Bytes())
if s == nil || s.sns == nil {
return false
}
_, found := s.sns[key]
return found
}

func newSerialNumberSetFromReader(r io.Reader) (*serialNumberSet, error) {
set := make(map[serialNumberKey]struct{})
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line, _, _ := bytes.Cut(scanner.Bytes(), []byte{'#'})
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
serial, err := parseSerialBytes(line)
if err != nil {
continue
}
set[normalizeSNBytes(serial)] = struct{}{}
}

if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("unable to load serial number set: %w", err)
}

return &serialNumberSet{
sns: set,
}, nil
}

func parseSerialBytes(serial []byte) ([]byte, error) {
res := make([]byte, (len(serial)+2)/3)

var i int
for ; i < len(res) && i*3+1 < len(serial); i++ {
if _, err := hex.Decode(res[i:i+1], serial[i*3:i*3+2]); err != nil {
return nil, fmt.Errorf("parseSerialBytes() failed: %w", err)
}
if i*3+2 < len(serial) && serial[i*3+2] != ':' {
return nil, errors.New("missing colon delimiter")
}
}
if i < len(res) {
return nil, errors.New("incomplete serial number string")
}

return res, nil
}
Loading

0 comments on commit cbf4bdc

Please sign in to comment.