From b2fa525f49e18303086cc94ede0790917a388004 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Fri, 25 Oct 2024 16:06:54 +0300 Subject: [PATCH 1/4] cert auth: improve logging --- auth/cert.go | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/auth/cert.go b/auth/cert.go index 572e2d2..0a88282 100644 --- a/auth/cert.go +++ b/auth/cert.go @@ -1,6 +1,11 @@ package auth -import "net/http" +import ( + "encoding/hex" + "fmt" + "math/big" + "net/http" +) type CertAuth struct{} @@ -8,9 +13,24 @@ func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, b if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 { http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) return "", false - } else { - return req.TLS.VerifiedChains[0][0].Subject.String(), true } + return fmt.Sprintf( + "Subject: %s, Serial Number: %s", + req.TLS.VerifiedChains[0][0].Subject.String(), + formatSerial(req.TLS.VerifiedChains[0][0].SerialNumber), + ), true } func (_ CertAuth) Stop() {} + +// formatSerial from https://codereview.stackexchange.com/a/165708 +func formatSerial(serial *big.Int) string { + b := serial.Bytes() + buf := make([]byte, 0, 3*len(b)) + x := buf[1*len(b) : 3*len(b)] + hex.Encode(x, b) + for i := 0; i < len(x); i += 2 { + buf = append(buf, x[i], x[i+1], ':') + } + return string(buf[:len(buf)-1]) +} From 222afda727caaac6e8a7b7a5188dab09b8504cde Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sat, 26 Oct 2024 01:35:33 +0300 Subject: [PATCH 2/4] added serial number set implementation --- auth/cert.go | 74 +++++++++++++++ auth/cert_test.go | 227 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 auth/cert_test.go diff --git a/auth/cert.go b/auth/cert.go index 0a88282..4dd07e7 100644 --- a/auth/cert.go +++ b/auth/cert.go @@ -1,8 +1,12 @@ package auth import ( + "bufio" + "bytes" "encoding/hex" + "errors" "fmt" + "io" "math/big" "net/http" ) @@ -32,5 +36,75 @@ func formatSerial(serial *big.Int) string { for i := 0; i < len(x); i += 2 { buf = append(buf, x[i], x[i+1], ':') } + if serial.Sign() == -1 { + return "(Negative)" + string(buf[:len(buf)-1]) + } return string(buf[:len(buf)-1]) } + +type serialNumberKey = [20]byte +type serialNumberSet struct { + sns map[serialNumberKey]struct{} +} + +func normalizeSNBytes(b []byte) serialNumberKey { + var k serialNumberKey + copy( + k[max(len(k)-len(b), 0):], + b[max(len(b)-len(k), 0):], + ) + return k +} + +func (s *serialNumberSet) Has(serial *big.Int) bool { + key := normalizeSNBytes(serial.Bytes()) + if s == nil || s.sns == nil { + return false + } + _, found := s.sns[key] + return found +} + +func newSerialNumberSetFromReader(r io.Reader) (*serialNumberSet, error) { + set := make(map[serialNumberKey]struct{}) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line, _, _ := bytes.Cut(scanner.Bytes(), []byte{'#'}) + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + serial, err := parseSerialBytes(line) + if err != nil { + continue + } + set[normalizeSNBytes(serial)] = struct{}{} + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("unable to load serial number set: %w", err) + } + + return &serialNumberSet{ + sns: set, + }, nil +} + +func parseSerialBytes(serial []byte) ([]byte, error) { + res := make([]byte, (len(serial)+2)/3) + + var i int + for ; i < len(res) && i*3+1 < len(serial); i++ { + if _, err := hex.Decode(res[i:i+1], serial[i*3:i*3+2]); err != nil { + return nil, fmt.Errorf("parseSerialBytes() failed: %w", err) + } + if i*3+2 < len(serial) && serial[i*3+2] != ':' { + return nil, errors.New("missing colon delimiter") + } + } + if i < len(res) { + return nil, errors.New("incomplete serial number string") + } + + return res, nil +} diff --git a/auth/cert_test.go b/auth/cert_test.go new file mode 100644 index 0000000..a1c2d00 --- /dev/null +++ b/auth/cert_test.go @@ -0,0 +1,227 @@ +package auth + +import ( + "bytes" + "fmt" + "math/big" + "strings" + "testing" +) + +func mkbytes(l uint) []byte { + b := make([]byte, l) + for i := uint(0); i < l; i++ { + b[i] = byte(i) + } + return b +} + +var mask *big.Int = big.NewInt(0).Add(big.NewInt(0).Lsh(big.NewInt(1), uint(8*len(serialNumberKey{}))), big.NewInt(-1)) + +func TestNormalizeSNBytes(t *testing.T) { + for i := uint(0); i <= 32; i++ { + t.Run(fmt.Sprintf("%d-bytes", i), func(t *testing.T) { + s := mkbytes(i) + k := normalizeSNBytes(s) + var a, b big.Int + a.SetBytes(s).And(&a, mask) + b.SetBytes(k[:]) + if a.Cmp(&b) != 0 { + t.Fatalf("%d != %d", &a, &b) + } + }) + } +} + +type parseSerialBytesTestcase struct { + input []byte + output []byte + error bool +} + +func TestParseSerialBytes(t *testing.T) { + testcases := []parseSerialBytesTestcase{ + { + input: []byte(""), + output: []byte{}, + }, + { + input: []byte("01:02:03"), + output: []byte{1, 2, 3}, + }, + { + input: []byte("ff"), + output: []byte{255}, + }, + { + input: []byte("ff:f"), + error: true, + }, + { + input: []byte("f"), + error: true, + }, + { + input: []byte("fff"), + error: true, + }, + { + input: []byte("---"), + error: true, + }, + } + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out, err := parseSerialBytes(testcase.input) + if (err != nil) != testcase.error { + t.Fatalf("unexpected error: %v", err) + } + if bytes.Compare(out, testcase.output) != 0 { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} + +type serialNumberSetTestcase struct { + input *big.Int + output bool +} + +func TestSerialNumberSetSmoke(t *testing.T) { + const testFile = ` +01:00:00:00:00 # test +# test 2 +03 +03 + +00 + 01 +02` + testcases := []serialNumberSetTestcase{ + { + input: big.NewInt(1 << 32), + output: true, + }, + { + input: big.NewInt(0), + output: true, + }, + { + input: big.NewInt(1), + output: true, + }, + { + input: big.NewInt(2), + output: true, + }, + { + input: big.NewInt(3), + output: true, + }, + { + input: big.NewInt(4), + output: false, + }, + { + input: big.NewInt(-2), + output: true, + }, + } + s, err := newSerialNumberSetFromReader(strings.NewReader(testFile)) + if err != nil { + t.Fatalf("unable to load test set: %v", err) + } + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out := s.Has(testcase.input) + if out != testcase.output { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} + +func TestSerialNumberSetEmpty(t *testing.T) { + const testFile = "" + testcases := []serialNumberSetTestcase{ + { + input: big.NewInt(0), + output: false, + }, + { + input: big.NewInt(1), + output: false, + }, + { + input: big.NewInt(2), + output: false, + }, + } + s, err := newSerialNumberSetFromReader(strings.NewReader(testFile)) + if err != nil { + t.Fatalf("unable to load test set: %v", err) + } + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out := s.Has(testcase.input) + if out != testcase.output { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} + +func TestSerialNumberSetNullMap(t *testing.T) { + const testFile = "" + testcases := []serialNumberSetTestcase{ + { + input: big.NewInt(0), + output: false, + }, + { + input: big.NewInt(1), + output: false, + }, + { + input: big.NewInt(2), + output: false, + }, + } + s := new(serialNumberSet) + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out := s.Has(testcase.input) + if out != testcase.output { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} + +func TestSerialNumberSetNull(t *testing.T) { + const testFile = "" + testcases := []serialNumberSetTestcase{ + { + input: big.NewInt(0), + output: false, + }, + { + input: big.NewInt(1), + output: false, + }, + { + input: big.NewInt(2), + output: false, + }, + } + s := (*serialNumberSet)(nil) + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out := s.Has(testcase.input) + if out != testcase.output { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} From fc6f5e7b786a6f9ef10508dbe3182ac5b8616c02 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sun, 27 Oct 2024 15:19:16 +0200 Subject: [PATCH 3/4] make use of cert blacklist --- auth/auth.go | 2 +- auth/basic.go | 18 ++++---- auth/cert.go | 116 +++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 122 insertions(+), 14 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 5e7be25..617b693 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -26,7 +26,7 @@ func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) { case "basicfile": return NewBasicFileAuth(url, logger) case "cert": - return CertAuth{}, nil + return NewCertAuth(url, logger) case "none": return NoAuth{}, nil default: diff --git a/auth/basic.go b/auth/basic.go index 1592290..9644004 100644 --- a/auth/basic.go +++ b/auth/basic.go @@ -52,13 +52,13 @@ func NewBasicFileAuth(param_url *url.URL, logger *clog.CondLogger) (*BasicAuth, return nil, fmt.Errorf("unable to load initial password list: %w", err) } - reloadIntervalOption := values.Get("reload") - reloadInterval, err := time.ParseDuration(reloadIntervalOption) - if err != nil { - reloadInterval = 0 - } - if reloadInterval == 0 { - reloadInterval = 15 * time.Second + reloadInterval := 15 * time.Second + if reloadIntervalOption := values.Get("reload"); reloadIntervalOption != "" { + parsedInterval, err := time.ParseDuration(reloadIntervalOption) + if err != nil { + logger.Warning("unable to parse reload interval: %v. using default value.", err) + } + reloadInterval = parsedInterval } if reloadInterval > 0 { go auth.reloadLoop(reloadInterval) @@ -108,7 +108,9 @@ func (auth *BasicAuth) reloadLoop(interval time.Duration) { case <-auth.stopChan: return case <-ticker.C: - auth.reload() + if err := auth.reload(); err != nil { + auth.logger.Error("reload failed: %v", err) + } } } } diff --git a/auth/cert.go b/auth/cert.go index 4dd07e7..0928c43 100644 --- a/auth/cert.go +++ b/auth/cert.go @@ -9,23 +9,129 @@ import ( "io" "math/big" "net/http" + "net/url" + "sync" + "sync/atomic" + "time" + + clog "github.com/SenseUnit/dumbproxy/log" ) -type CertAuth struct{} +type serialNumberSetFile struct { + file *serialNumberSet + modTime time.Time +} + +type CertAuth struct { + blacklist atomic.Pointer[serialNumberSetFile] + blacklistFilename string + logger *clog.CondLogger + stopOnce sync.Once + stopChan chan struct{} +} + +func NewCertAuth(param_url *url.URL, logger *clog.CondLogger) (*CertAuth, error) { + values, err := url.ParseQuery(param_url.RawQuery) + if err != nil { + return nil, err + } + + auth := &CertAuth{ + blacklistFilename: values.Get("blacklist"), + logger: logger, + stopChan: make(chan struct{}), + } + auth.blacklist.Store(new(serialNumberSetFile)) -func (_ CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) { + if auth.blacklistFilename != "" { + if err := auth.reload(); err != nil { + return nil, fmt.Errorf("unable to load initial certificate blacklist: %w", err) + } + } + + reloadInterval := 15 * time.Second + if reloadIntervalOption := values.Get("reload"); reloadIntervalOption != "" { + parsedInterval, err := time.ParseDuration(reloadIntervalOption) + if err != nil { + logger.Warning("unable to parse reload interval: %v. using default value.", err) + } + reloadInterval = parsedInterval + } + if reloadInterval > 0 { + go auth.reloadLoop(reloadInterval) + } + + return auth, nil +} + +func (auth *CertAuth) Validate(wr http.ResponseWriter, req *http.Request) (string, bool) { if req.TLS == nil || len(req.TLS.VerifiedChains) < 1 || len(req.TLS.VerifiedChains[0]) < 1 { http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) return "", false } + eeCert := req.TLS.VerifiedChains[0][0] + if auth.blacklist.Load().file.Has(eeCert.SerialNumber) { + http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest) + return "", false + } return fmt.Sprintf( "Subject: %s, Serial Number: %s", - req.TLS.VerifiedChains[0][0].Subject.String(), - formatSerial(req.TLS.VerifiedChains[0][0].SerialNumber), + eeCert.Subject.String(), + formatSerial(eeCert.SerialNumber), ), true } -func (_ CertAuth) Stop() {} +func (auth *CertAuth) Stop() { + auth.stopOnce.Do(func() { + close(auth.stopChan) + }) +} + +func (auth *CertAuth) reload() error { + var oldModTime time.Time + if oldBL := auth.blacklist.Load(); oldBL != nil { + oldModTime = oldBL.modTime + } + + f, modTime, err := openIfModified(auth.blacklistFilename, oldModTime) + if err != nil { + return err + } + if f == nil { + // no changes since last modTime + return nil + } + + auth.logger.Info("reloading certificate blacklist from %q...", auth.blacklistFilename) + newBlacklistSet, err := newSerialNumberSetFromReader(f) + if err != nil { + return err + } + + newBlacklist := &serialNumberSetFile{ + file: newBlacklistSet, + modTime: modTime, + } + auth.blacklist.Store(newBlacklist) + auth.logger.Info("blacklist file reloaded.") + + return nil +} + +func (auth *CertAuth) reloadLoop(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-auth.stopChan: + return + case <-ticker.C: + if err := auth.reload(); err != nil { + auth.logger.Error("reload failed: %v", err) + } + } + } +} // formatSerial from https://codereview.stackexchange.com/a/165708 func formatSerial(serial *big.Int) string { From 517abaa28435057ba070cc84b324b4a29cedc5a7 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sun, 27 Oct 2024 15:31:39 +0200 Subject: [PATCH 4/4] upd doc --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index c21dbc7..39f72db 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,8 @@ Authentication parameters are passed as URI via `-auth` parameter. Scheme of URI * `hidden_domain` - same as in `static` provider * `reload` - interval for conditional password file reload, if it was modified since last load. Use negative duration to disable autoreload. Default: `15s`. * `cert` - use mutual TLS authentication with client certificates. In order to use this auth provider server must listen sockert in TLS mode (`-cert` and `-key` options) and client CA file must be specified (`-cacert`). Example: `cert://`. + * `blacklist` - location of file with list of serial numbers of blocked certificates, one per each line in form of hex-encoded colon-separated bytes. Example: `ab:01:02:03`. Empty lines and comments starting with `#` are ignored. + * `reload` - interval for certificate blacklist file reload, if it was modified since last load. Use negative duration to disable autoreload. Default: `15s`. ## Synopsis