Skip to content

Commit

Permalink
Merge pull request #273 from ae-govau/gracefulshutdown
Browse files Browse the repository at this point in the history
fix: shutdown gracefully on TERM or INT signals
  • Loading branch information
MichaelEischer authored Feb 8, 2024
2 parents b28bea1 + ce15402 commit 13083ac
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 87 deletions.
7 changes: 7 additions & 0 deletions changelog/unreleased/pull-273
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Change: Server is now shutdown cleanly on TERM or INT signals

Server now listens for TERM and INT signals and cleanly closes down the http.Server and listener.

This is particularly useful when listening on a unix socket, as the server will remove the socket file from it shuts down.

https://github.com/restic/rest-server/pull/273
199 changes: 118 additions & 81 deletions cmd/rest-server/main.go
Original file line number Diff line number Diff line change
@@ -1,159 +1,196 @@
package main

import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
"sync"
"syscall"

restserver "github.com/restic/rest-server"
"github.com/spf13/cobra"
)

// cmdRoot is the base command when no other command has been specified.
var cmdRoot = &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
RunE: runRoot,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
}
type restServerApp struct {
CmdRoot *cobra.Command
Server restserver.Server
CpuProfile string

var server = restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
listenerAddressMu sync.Mutex
listenerAddress net.Addr // set after startup
}

var (
cpuProfile string
)

func init() {
flags := cmdRoot.Flags()
flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file")
flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages")
flags.StringVar(&server.Listen, "listen", server.Listen, "listen address")
flags.StringVar(&server.Log, "log", server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&server.Path, "path", server.Path, "data directory")
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&server.HtpasswdPath, "htpasswd-file", server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&server.NoVerifyUpload, "no-verify-upload", server.NoVerifyUpload,
// cmdRoot is the base command when no other command has been specified.
func newRestServerApp() *restServerApp {
rv := &restServerApp{
CmdRoot: &cobra.Command{
Use: "rest-server",
Short: "Run a REST server for use with restic",
SilenceErrors: true,
SilenceUsage: true,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 0 {
return fmt.Errorf("rest-server expects no arguments - unknown argument: %s", args[0])
}
return nil
},
Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
},
Server: restserver.Server{
Path: filepath.Join(os.TempDir(), "restic"),
Listen: ":8000",
},
}
rv.CmdRoot.RunE = rv.runRoot
flags := rv.CmdRoot.Flags()

flags.StringVar(&rv.CpuProfile, "cpu-profile", rv.CpuProfile, "write CPU profile to file")
flags.BoolVar(&rv.Server.Debug, "debug", rv.Server.Debug, "output debug messages")
flags.StringVar(&rv.Server.Listen, "listen", rv.Server.Listen, "listen address")
flags.StringVar(&rv.Server.Log, "log", rv.Server.Log, "write HTTP requests in the combined log format to the specified `filename` (use \"-\" for logging to stdout)")
flags.Int64Var(&rv.Server.MaxRepoSize, "max-size", rv.Server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&rv.Server.Path, "path", rv.Server.Path, "data directory")
flags.BoolVar(&rv.Server.TLS, "tls", rv.Server.TLS, "turn on TLS support")
flags.StringVar(&rv.Server.TLSCert, "tls-cert", rv.Server.TLSCert, "TLS certificate path")
flags.StringVar(&rv.Server.TLSKey, "tls-key", rv.Server.TLSKey, "TLS key path")
flags.BoolVar(&rv.Server.NoAuth, "no-auth", rv.Server.NoAuth, "disable .htpasswd authentication")
flags.StringVar(&rv.Server.HtpasswdPath, "htpasswd-file", rv.Server.HtpasswdPath, "location of .htpasswd file (default: \"<data directory>/.htpasswd)\"")
flags.BoolVar(&rv.Server.NoVerifyUpload, "no-verify-upload", rv.Server.NoVerifyUpload,
"do not verify the integrity of uploaded data. DO NOT enable unless the rest-server runs on a very low-power device")
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&server.PrometheusNoAuth, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVar(&rv.Server.AppendOnly, "append-only", rv.Server.AppendOnly, "enable append only mode")
flags.BoolVar(&rv.Server.PrivateRepos, "private-repos", rv.Server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&rv.Server.Prometheus, "prometheus", rv.Server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&rv.Server.PrometheusNoAuth, "prometheus-no-auth", rv.Server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")

return rv
}

var version = "0.12.1-dev"

func tlsSettings() (bool, string, string, error) {
func (app *restServerApp) tlsSettings() (bool, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
if !app.Server.TLS && (app.Server.TLSKey != "" || app.Server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
} else if !server.TLS {
} else if !app.Server.TLS {
return false, "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
if app.Server.TLSKey != "" {
key = app.Server.TLSKey
} else {
key = filepath.Join(server.Path, "private_key")
key = filepath.Join(app.Server.Path, "private_key")
}
if server.TLSCert != "" {
cert = server.TLSCert
if app.Server.TLSCert != "" {
cert = app.Server.TLSCert
} else {
cert = filepath.Join(server.Path, "public_key")
cert = filepath.Join(app.Server.Path, "public_key")
}
return server.TLS, key, cert, nil
return app.Server.TLS, key, cert, nil
}

func runRoot(cmd *cobra.Command, args []string) error {
// returns the address that the app is listening on.
// returns nil if the application hasn't finished starting yet
func (app *restServerApp) ListenerAddress() net.Addr {
app.listenerAddressMu.Lock()
defer app.listenerAddressMu.Unlock()
return app.listenerAddress
}

func (app *restServerApp) runRoot(cmd *cobra.Command, args []string) error {
log.SetFlags(0)

log.Printf("Data directory: %s", server.Path)
log.Printf("Data directory: %s", app.Server.Path)

if cpuProfile != "" {
f, err := os.Create(cpuProfile)
if app.CpuProfile != "" {
f, err := os.Create(app.CpuProfile)
if err != nil {
return err
}
defer f.Close()

if err := pprof.StartCPUProfile(f); err != nil {
return err
}
log.Println("CPU profiling enabled")
defer pprof.StopCPUProfile()

// clean profiling shutdown on sigint
sigintCh := make(chan os.Signal, 1)
go func() {
for range sigintCh {
pprof.StopCPUProfile()
log.Println("Stopped CPU profiling")
err := f.Close()
if err != nil {
log.Printf("error closing CPU profile file: %v", err)
}
os.Exit(130)
}
}()
signal.Notify(sigintCh, syscall.SIGINT)
log.Println("CPU profiling enabled")
defer log.Println("Stopped CPU profiling")
}

if server.NoAuth {
if app.Server.NoAuth {
log.Println("Authentication disabled")
} else {
log.Println("Authentication enabled")
}

handler, err := restserver.NewHandler(&server)
handler, err := restserver.NewHandler(&app.Server)
if err != nil {
log.Fatalf("error: %v", err)
}

if server.PrivateRepos {
if app.Server.PrivateRepos {
log.Println("Private repositories enabled")
} else {
log.Println("Private repositories disabled")
}

enabledTLS, privateKey, publicKey, err := tlsSettings()
enabledTLS, privateKey, publicKey, err := app.tlsSettings()
if err != nil {
return err
}

listener, err := findListener(server.Listen)
listener, err := findListener(app.Server.Listen)
if err != nil {
return fmt.Errorf("unable to listen: %w", err)
}

if !enabledTLS {
err = http.Serve(listener, handler)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = http.ServeTLS(listener, handler, publicKey, privateKey)
// set listener address, this is useful for tests
app.listenerAddressMu.Lock()
app.listenerAddress = listener.Addr()
app.listenerAddressMu.Unlock()

srv := &http.Server{
Handler: handler,
}

// run server in background
go func() {
if !enabledTLS {
err = srv.Serve(listener)
} else {
log.Printf("TLS enabled, private key %s, pubkey %v", privateKey, publicKey)
err = srv.ServeTLS(listener, publicKey, privateKey)
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("listen and serve returned err: %v", err)
}
}()

// wait until done
<-app.CmdRoot.Context().Done()

// gracefully shutdown server
if err := srv.Shutdown(context.Background()); err != nil {
return fmt.Errorf("server shutdown returned an err: %w", err)
}

return err
log.Println("shutdown cleanly")
return nil
}

func main() {
if err := cmdRoot.Execute(); err != nil {
// create context to be notified on interrupt or term signal so that we can shutdown cleanly
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

if err := newRestServerApp().CmdRoot.ExecuteContext(ctx); err != nil {
log.Fatalf("error: %v", err)
}
}
Loading

0 comments on commit 13083ac

Please sign in to comment.