diff --git a/AUTHORS b/AUTHORS index bf565389..cdfef861 100644 --- a/AUTHORS +++ b/AUTHORS @@ -14,3 +14,4 @@ Wayne Scott Zlatko Čalušić cgonzalez n0npax +textaligncenter <67056612+textaligncenter@users.noreply.github.com> diff --git a/cmd/rest-server/main.go b/cmd/rest-server/main.go index 4287b97a..384f9816 100644 --- a/cmd/rest-server/main.go +++ b/cmd/rest-server/main.go @@ -1,8 +1,11 @@ package main import ( + "crypto/tls" + "crypto/x509" "errors" "fmt" + "io/ioutil" "log" "net/http" "os" @@ -46,6 +49,8 @@ func init() { flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support") flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path") flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path") + flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support") + flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path") flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication") flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload, "do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device") @@ -57,12 +62,12 @@ func init() { var version = "0.11.0" -func tlsSettings() (bool, string, string, error) { +func tlsSettings() (bool, bool, string, string, string, error) { var key, cert string - if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") { - return false, "", "", errors.New("requires enabled TLS") + if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") { + return false, false, "", "", "", errors.New("requires enabled TLS or mTLS") } else if !server.TLS { - return false, "", "", nil + return false, false, "", "", "", nil } if server.TLSKey != "" { key = server.TLSKey @@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) { } else { cert = filepath.Join(server.Path, "public_key") } - return server.TLS, key, cert, nil + + if server.MTLS && server.CACert == "" { + return false, false, "", "", "", errors.New("missing cacert") + } + return server.TLS, server.MTLS, server.CACert, key, cert, nil } func runRoot(cmd *cobra.Command, args []string) error { @@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error { log.Println("Private repositories disabled") } - enabledTLS, privateKey, publicKey, err := tlsSettings() + enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings() if err != nil { return err } @@ -138,8 +147,31 @@ func runRoot(cmd *cobra.Command, args []string) error { if !enabledTLS { err = http.Serve(listener, handler) } else { - log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) - err = http.ServeTLS(listener, handler, publicKey, privateKey) + if enabledMTLS { + log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert) + caCertPool := x509.NewCertPool() + caCertPem, err := ioutil.ReadFile(caCert) + if err != nil { + return errors.New("unable to read cacert") + } + caCertPool.AppendCertsFromPEM(caCertPem) + + tlsConfig := &tls.Config{ + ClientCAs: caCertPool, + ClientAuth: tls.VerifyClientCertIfGiven, + } + tlsConfig.BuildNameToCertificate() + + server := &http.Server{ + Addr: server.Listen, + TLSConfig: tlsConfig, + } + server.Handler = handler + err = server.ServeTLS(listener, publicKey, privateKey) + } else { + log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey) + err = http.ServeTLS(listener, handler, publicKey, privateKey) + } } return err diff --git a/cmd/rest-server/main_test.go b/cmd/rest-server/main_test.go index 6f6bc86d..752dc779 100644 --- a/cmd/rest-server/main_test.go +++ b/cmd/rest-server/main_test.go @@ -13,6 +13,7 @@ func TestTLSSettings(t *testing.T) { type expected struct { TLSKey string TLSCert string + CAcert string Error bool } type passed struct { @@ -20,30 +21,29 @@ func TestTLSSettings(t *testing.T) { TLS bool TLSKey string TLSCert string + MTLS bool + CACert string } var tests = []struct { passed passed expected expected }{ - {passed{TLS: false}, expected{"", "", false}}, - {passed{TLS: true}, expected{ - filepath.Join(os.TempDir(), "restic/private_key"), - filepath.Join(os.TempDir(), "restic/public_key"), - false, - }}, - {passed{ - Path: os.TempDir(), - TLS: true, - }, expected{ - filepath.Join(os.TempDir(), "private_key"), - filepath.Join(os.TempDir(), "public_key"), - false, - }}, - {passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}}, - {passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}}, - {passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}}, - {passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}}, + {passed{TLS: false}, expected{"", "", "", false}}, + {passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}}, + {passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}}, + {passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}}, + + {passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}}, + {passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}}, + {passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}}, + {passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}}, + {passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}}, } for _, test := range tests { @@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) { server.TLSKey = test.passed.TLSKey server.TLSCert = test.passed.TLSCert - gotTLS, gotKey, gotCert, err := tlsSettings() + gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings() if err != nil && !test.expected.Error { t.Fatalf("tls_settings returned err (%v)", err) } @@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) { if gotTLS != test.passed.TLS { t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS) } + wantKey := test.expected.TLSKey if gotKey != wantKey { t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey) @@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) { t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert) } + if gotMTLS != test.passed.MTLS { + t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS) + } + + wantCAcert := test.expected.CAcert + if gotCAcert != wantCAcert { + t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert) + } }) } } diff --git a/handlers.go b/handlers.go index 9df6adf8..fd411300 100644 --- a/handlers.go +++ b/handlers.go @@ -18,9 +18,11 @@ type Server struct { Listen string Log string CPUProfile string + CACert string TLSKey string TLSCert string TLS bool + MTLS bool NoAuth bool AppendOnly bool PrivateRepos bool diff --git a/mux.go b/mux.go index 6b4ad4c5..b781c0c9 100644 --- a/mux.go +++ b/mux.go @@ -1,7 +1,6 @@ package restserver import ( - "fmt" "log" "net/http" "os" @@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) { if s.NoAuth { return username, true } - var password string - username, password, ok = r.BasicAuth() - if !ok || !s.htpasswdFile.Validate(username, password) { - return "", false + + username, ok = s.validateBasicAuth(r) + if ok { + return username, true + } + + username, ok = validateClientCert(r) + if ok { + return username, true + } + + return username, false +} + +func (s *Server) validateBasicAuth(r *http.Request) (string, bool) { + username, password, ok := r.BasicAuth() + return username, ok && s.htpasswdFile.Validate(username, password) +} + +func validateClientCert(r *http.Request) (string, bool) { + if r.TLS != nil { + for _, cert := range r.TLS.PeerCertificates { + username := cert.Subject.CommonName + if username != "" { + return username, true + } + } } - return username, true + return "", false } func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc { @@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) { var err error server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd")) if err != nil { - return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err) + // return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err) } }