Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Authentication using TLS Client Certificates / mTLS #191

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Wayne Scott <[email protected]>
Zlatko Čalušić <[email protected]>
cgonzalez <[email protected]>
n0npax <[email protected]>
textaligncenter <[email protected]>
48 changes: 40 additions & 8 deletions cmd/rest-server/main.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package main

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -46,6 +49,8 @@ func init() {
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.MTLS, "mtls", server.MTLS, "turn on client certificate support")
flags.StringVar(&server.CACert, "cacert", server.CACert, "mTLS CA certificate path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
Expand All @@ -57,12 +62,12 @@ func init() {

var version = "0.11.0"

func tlsSettings() (bool, string, string, error) {
func tlsSettings() (bool, bool, string, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
if (!server.TLS && !server.MTLS) && (server.TLSKey != "" || server.TLSCert != "") {
return false, false, "", "", "", errors.New("requires enabled TLS or mTLS")
} else if !server.TLS {
return false, "", "", nil
return false, false, "", "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
Expand All @@ -74,7 +79,11 @@ func tlsSettings() (bool, string, string, error) {
} else {
cert = filepath.Join(server.Path, "public_key")
}
return server.TLS, key, cert, nil

if server.MTLS && server.CACert == "" {
return false, false, "", "", "", errors.New("missing cacert")
}
return server.TLS, server.MTLS, server.CACert, key, cert, nil
}

func runRoot(cmd *cobra.Command, args []string) error {
Expand Down Expand Up @@ -125,7 +134,7 @@ func runRoot(cmd *cobra.Command, args []string) error {
log.Println("Private repositories disabled")
}

enabledTLS, privateKey, publicKey, err := tlsSettings()
enabledTLS, enabledMTLS, caCert, privateKey, publicKey, err := tlsSettings()
if err != nil {
return err
}
Expand All @@ -138,8 +147,31 @@ func runRoot(cmd *cobra.Command, args []string) error {
if !enabledTLS {
err = http.Serve(listener, handler)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
if enabledMTLS {
log.Printf("mTLS enabled, private key %s, pubkey %s, cacert %s", privateKey, publicKey, caCert)
caCertPool := x509.NewCertPool()
caCertPem, err := ioutil.ReadFile(caCert)
if err != nil {
return errors.New("unable to read cacert")
}
caCertPool.AppendCertsFromPEM(caCertPem)

tlsConfig := &tls.Config{
ClientCAs: caCertPool,
ClientAuth: tls.VerifyClientCertIfGiven,
}
tlsConfig.BuildNameToCertificate()

server := &http.Server{
Addr: server.Listen,
TLSConfig: tlsConfig,
}
server.Handler = handler
err = server.ServeTLS(listener, publicKey, privateKey)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
}
}

return err
Expand Down
47 changes: 28 additions & 19 deletions cmd/rest-server/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,37 @@ func TestTLSSettings(t *testing.T) {
type expected struct {
TLSKey string
TLSCert string
CAcert string
Error bool
}
type passed struct {
Path string
TLS bool
TLSKey string
TLSCert string
MTLS bool
CACert string
}

var tests = []struct {
passed passed
expected expected
}{
{passed{TLS: false}, expected{"", "", false}},
{passed{TLS: true}, expected{
filepath.Join(os.TempDir(), "restic/private_key"),
filepath.Join(os.TempDir(), "restic/public_key"),
false,
}},
{passed{
Path: os.TempDir(),
TLS: true,
}, expected{
filepath.Join(os.TempDir(), "private_key"),
filepath.Join(os.TempDir(), "public_key"),
false,
}},
{passed{Path: os.TempDir(), TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
{passed{Path: os.TempDir(), TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{TLS: false}, expected{"", "", "", false}},
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},

{passed{TLS: false, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "/etc/restic/cacert", false}},
{passed{TLS: true, MTLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, MTLS: true}, expected{"/tmp/private_key", "/tmp/public_key", "", false}},
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert", MTLS: true}, expected{"/etc/restic/key", "/etc/restic/cert", "", false}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", "", true}},
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", "", true}},
}

for _, test := range tests {
Expand All @@ -57,7 +57,7 @@ func TestTLSSettings(t *testing.T) {
server.TLSKey = test.passed.TLSKey
server.TLSCert = test.passed.TLSCert

gotTLS, gotKey, gotCert, err := tlsSettings()
gotTLS, gotMTLS, gotCAcert, gotKey, gotCert, err := tlsSettings()
if err != nil && !test.expected.Error {
t.Fatalf("tls_settings returned err (%v)", err)
}
Expand All @@ -71,6 +71,7 @@ func TestTLSSettings(t *testing.T) {
if gotTLS != test.passed.TLS {
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
}

wantKey := test.expected.TLSKey
if gotKey != wantKey {
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
Expand All @@ -81,6 +82,14 @@ func TestTLSSettings(t *testing.T) {
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
}

if gotMTLS != test.passed.MTLS {
t.Errorf("mTLS enabled, want (%v), got (%v)", test.passed.MTLS, gotMTLS)
}

wantCAcert := test.expected.CAcert
if gotCAcert != wantCAcert {
t.Errorf("wrong CACertPath path, want (%v), got (%v)", wantCAcert, gotCAcert)
}
})
}
}
Expand Down
2 changes: 2 additions & 0 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ type Server struct {
Listen string
Log string
CPUProfile string
CACert string
TLSKey string
TLSCert string
TLS bool
MTLS bool
NoAuth bool
AppendOnly bool
PrivateRepos bool
Expand Down
36 changes: 29 additions & 7 deletions mux.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package restserver

import (
"fmt"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -33,12 +32,35 @@ func (s *Server) checkAuth(r *http.Request) (username string, ok bool) {
if s.NoAuth {
return username, true
}
var password string
username, password, ok = r.BasicAuth()
if !ok || !s.htpasswdFile.Validate(username, password) {
return "", false

username, ok = s.validateBasicAuth(r)
if ok {
return username, true
}

username, ok = validateClientCert(r)
if ok {
return username, true
}

return username, false
}

func (s *Server) validateBasicAuth(r *http.Request) (string, bool) {
username, password, ok := r.BasicAuth()
return username, ok && s.htpasswdFile.Validate(username, password)
}

func validateClientCert(r *http.Request) (string, bool) {
if r.TLS != nil {
for _, cert := range r.TLS.PeerCertificates {
username := cert.Subject.CommonName
if username != "" {
return username, true
}
}
}
return username, true
return "", false
}

func (s *Server) wrapMetricsAuth(f http.HandlerFunc) http.HandlerFunc {
Expand All @@ -62,7 +84,7 @@ func NewHandler(server *Server) (http.Handler, error) {
var err error
server.htpasswdFile, err = NewHtpasswdFromFile(filepath.Join(server.Path, ".htpasswd"))
if err != nil {
return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
// return nil, fmt.Errorf("cannot load .htpasswd (use --no-auth to disable): %v", err)
}
}

Expand Down