diff --git a/README.md b/README.md index d2fd3c1..707876b 100644 --- a/README.md +++ b/README.md @@ -30,23 +30,21 @@ Examples: dserve Serves the current directory over http at :9011 dserve -local Serves the current directory on localhost:9011 dserve -dir ~/dir Serves the directory ~/dir over http - dserve -secure Serves the current directory with basicauth using sample .basicauth.json - dserve -secure -basicauth myauth.json - Serves the current directory with basicauth using config file myauth.json + dserve -basic "guest:Pass1234" + Serves the current directory with basicauth (only use this over https) Flags: - -basicauth string - file to be used for basicauth json config (default ".basicauth.json") -dir string the directory to serve, defaults to current directory (default "./") - -local - whether to serve on all address or on localhost, default all addresses + -local bool + whether to only serve on localhost -port int the port to serve at, defaults 9011 (default 9011) - -secure - whether to create a basic_auth secured secure/ directory, default false -timeout duration http server read timeout, write timeout will be double this (default 3m0s) + -basicauth string + enable HTTP basic authentication, arguments should be USERNAME:PASSWORD + example: dserve -basicauth "admin:passw0rd" ``` diff --git a/main.go b/main.go index 40f4dfc..4c45871 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,8 @@ package main import ( - "encoding/base64" - "encoding/json" "flag" "fmt" - "io/ioutil" "log" "net/http" "os" @@ -17,8 +14,7 @@ var ( dir = flag.String("dir", "./", "the directory to serve, defaults to current directory") port = flag.Int("port", 9011, "the port to serve at, defaults 9011") local = flag.Bool("local", false, "whether to serve on all address or on localhost, default all addresses") - secure = flag.Bool("secure", false, "whether to create a basic_auth secured secure/ directory, default false") - basicauth = flag.String("basicauth", ".basicauth.json", "file to be used for basicauth json config") + basicauth = flag.String("basicauth", "", "enable basic authentication") timeout = flag.Duration("timeout", time.Minute*3, "http server read timeout, write timeout will be double this") ) @@ -27,14 +23,13 @@ var usage = func() { Usage: dserve - dserve [flags].. [directory] + dserve [flags].. Examples: dserve Serves the current directory over http at :9011 dserve -local Serves the current directory on localhost:9011 dserve -dir ~/dir Serves the directory ~/dir over http - dserve -secure Serves the current directory with basicauth using sample .basicauth.json - dserve -secure -basicauth myauth.json + dserve -basicauth admin:Passw0rd Serves the current directory with basicauth using config file myauth.json Flags: @@ -46,7 +41,9 @@ Flags: func main() { flag.Usage = usage flag.Parse() + log.SetPrefix("dserve: ") + if err := os.Chdir(*dir); err != nil { log.Fatal(err) } @@ -54,38 +51,30 @@ func main() { if *local { addr = "localhost" } - if *secure { - if err := authInit(); err != nil { - fmt.Printf("Basic Auth credentials %s missing: edit and rename %s.sample\n", - *basicauth, *basicauth) - os.Exit(1) - } + if err := authInit(*basicauth); err != nil { + fmt.Print("invalid basicauth flag value: value should be USERNAME:PASSWORD, e.g. dserve -basicauth admin:passw0rd") + os.Exit(1) } listenAddr := fmt.Sprintf("%s:%d", addr, *port) fmt.Printf("Launching dserve http server %s on %s\n", *dir, listenAddr) - if err := Serve(listenAddr, *secure, *timeout); err != nil { + if err := Serve(listenAddr, *timeout); err != nil { log.Fatalf("Server crashed: %v", err) } } // Serve launches HTTP server serving on listenAddr and servers a basic_auth secured directory at secure/static -func Serve(listenAddr string, secureDir bool, timeout time.Duration) error { +func Serve(listenAddr string, timeout time.Duration) error { mux := http.NewServeMux() fs := hideRootDotfiles(http.FileServer(http.Dir("."))) - switch secureDir { - case true: - if err := authInit(); err != nil { - return fmt.Errorf("failed to initialize basic auth: %v", err) - } - fmt.Printf("BasicAuth enabled using credentials in %s\n", *basicauth) - mux.Handle("/", BASICAUTH(fs)) - default: - mux.Handle("/", fs) + if creds != nil { + fs = BASICAUTH(fs) } + mux.Handle("/", fs) + svr := &http.Server{ Addr: listenAddr, Handler: mux, @@ -127,62 +116,32 @@ type AuthCreds struct { Password string `json:"password"` } -// authInit initializes the secure directory -func authInit() error { - // get creds - err := func() error { - // Read the securepass.json creds - _, err := getCreds() - return err - }() - if err != nil { - sample := &AuthCreds{Username: "example", Password: "pass123"} - d, err := json.MarshalIndent(sample, "", " ") - if err != nil { - log.Print(err) - return fmt.Errorf("internal error") - } - if err := ioutil.WriteFile(fmt.Sprintf("%s", *basicauth), d, 0644); err != nil { - return fmt.Errorf("unable to create sample file %s", *basicauth) - } +var creds *AuthCreds // written during initialization + +// authInit initializes basicauth +func authInit(bAuth string) error { + if bAuth == "" { + return nil + } + i := strings.Index(bAuth, ":") + if i < 3 && i < len(bAuth)-1 { + return fmt.Errorf("invalid basicauth flag value") + } + creds = &AuthCreds{ + Username: bAuth[:i], + Password: bAuth[i+1:], } return nil } // validBasicAuth checks the basicauth authentication credentials func validBasicAuth(r *http.Request) bool { - creds, err := getCreds() - if err != nil { + if creds == nil { return false } - s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(s) != 2 { - return false - } - b, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return false - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { - return false - } - return pair[0] == creds.Username && pair[1] == creds.Password -} - -// getCreds gets the current http basic credentials -func getCreds() (*AuthCreds, error) { - creds := &AuthCreds{} - sp, err := ioutil.ReadFile(*basicauth) - if err != nil { - return creds, err - } - err = json.Unmarshal(sp, &creds) - if err != nil { - return creds, err - } - if creds.Username == "" && creds.Password == "" { - return creds, fmt.Errorf("no username and password in %s", *basicauth) + u, p, ok := r.BasicAuth() + if !ok { + return ok } - return creds, nil + return u == creds.Username && p == creds.Password } diff --git a/main_test.go b/main_test.go index 44626b5..33ad425 100644 --- a/main_test.go +++ b/main_test.go @@ -11,15 +11,32 @@ import ( var fakeFSHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "fs") }) func TestValidBasicAuth(t *testing.T) { - username, password := "tester", "pass123" - req, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) + creds = &AuthCreds{ + Username: "test", + Password: "want12345", } - req.Header.Set("Authorization", - fmt.Sprintf("%s %s", username, base64.StdEncoding.EncodeToString([]byte(password)))) - if valid := validBasicAuth(req); valid { - t.Errorf("validbasicauth got %v want %v", valid, false) + var tests = []struct { + username string + password string + valid bool + }{ + {"tester", "abc", false}, + {"test", "want12345", true}, + } + for _, test := range tests { + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("failed to create request: %v", err) + continue + } + auth := []byte(test.username + ":" + test.password) + req.Header.Set("Authorization", + fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString(auth))) + + valid := validBasicAuth(req) + if test.valid != valid { + t.Errorf("expected valid %t, got %t", test.valid, valid) + } } } diff --git a/release.sh b/release.sh index 5e96f23..7383c5f 100755 --- a/release.sh +++ b/release.sh @@ -1,13 +1,8 @@ #!/bin/bash -set -e +set -e +# confirm goreleaser exists goreleaser --help 2>&1 >/dev/null -if [ $? -gt 0 ] -then - echo "FAILED: goreleaser binary not in path, you should:" - echo -e "\tgo get -u github.com/goreleaser/goreleaser" - exit 1 -fi go vet ./...