Skip to content

Commit

Permalink
Fix web server tunnels issue error for closed listener
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Apr 28, 2024
1 parent fa4bde5 commit 154a912
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 39 deletions.
13 changes: 9 additions & 4 deletions commands/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,14 @@ func (g *start) Check() error {

}

func teardown() {
router.TearDown(false)
func teardown(force bool) {
router.TearDown(force)
// Tear down Unix socket
server.TearDown()

ui.Teardown()
webserver.Teardown()

}

func clusterState(noIptables bool, errorChan chan<- error) func(string) {
Expand All @@ -125,7 +128,7 @@ func clusterState(noIptables bool, errorChan chan<- error) func(string) {
if !wasDead {
log.Println("Tearing down node")

teardown()
teardown(false)

log.Println("Tear down complete")

Expand Down Expand Up @@ -216,7 +219,9 @@ func (g *start) Run() error {
log.Printf("%s starting, Ctrl + C to stop", wagType)

err = <-errorChan
teardown()

teardown(true)

if err != nil && !strings.Contains(err.Error(), "ignore me I am signal") {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion docker-test-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
},
"Clustering": {
"ClusterState": "new",
"ETCDLogLevel": "error",
"ListenAddresses": [
"https://172.20.0.3:2380"
],
"ETCDLogLevel": "info",
"TLSManagerListenURL": "https://container2:3434"
},
"DatabaseLocation": "devices.db",
Expand Down
7 changes: 0 additions & 7 deletions internal/data/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ var (
clusterHealthListeners = map[string]func(string){}

EventsQueue = queue.NewQueue(40)

checkState chan bool
exit chan bool
)

func RegisterEventListener[T any](path string, isPrefix bool, f func(key string, current, previous T, et EventType) error) (string, error) {
Expand Down Expand Up @@ -179,10 +176,6 @@ func checkClusterHealth() {
for {

select {

case <-exit:
notifyClusterHealthListeners("dead")
return
case <-etcdServer.Server.LeaderChangedNotify():
notifyHealthy()

Expand Down
25 changes: 15 additions & 10 deletions internal/data/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,24 @@ func Load(path, joinToken string, testing bool) error {

log.Println("Successfully connected to etcd")

if doMigration {
// This will be kept for 2 major releases with reduced support.
// It is a no-op if a migration has already taken place
err = migrateFromSql(db)
if !etcdServer.Server.IsLearner() {

if doMigration {
// This will be kept for 2 major releases with reduced support.
// It is a no-op if a migration has already taken place
err = migrateFromSql(db)
if err != nil {
return err
}
}

// This will stay, so that the config can be used to easily spin up a new wag instance.
// After first run this will be a no-op
err = loadInitialSettings()
if err != nil {
return err
}
}

// This will stay, so that the config can be used to easily spin up a new wag instance.
// After first run this will be a no-op
err = loadInitialSettings()
if err != nil {
return err
}

go checkClusterHealth()
Expand Down Expand Up @@ -459,6 +463,7 @@ func migrateFromSql(database *sql.DB) error {

func TearDown() {
if etcdServer != nil {

etcd.Close()
etcdServer.Close()

Expand Down
64 changes: 51 additions & 13 deletions internal/webserver/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

var (
tunnelHTTPServ *http.Server
tunnelTLSServ *http.Server

publicHTTPServ *http.Server
publicTLSServ *http.Server
)

func Teardown() {

if tunnelHTTPServ != nil {
tunnelHTTPServ.Close()
}

if tunnelTLSServ != nil {
tunnelTLSServ.Close()
}

if publicHTTPServ != nil {
publicHTTPServ.Close()
}

if publicTLSServ != nil {
publicTLSServ.Close()
}

log.Println("Stopped MFA portal")
}

func Start(errChan chan<- error) error {
//https://blog.cloudflare.com/exposing-go-on-the-internet/
tlsConfig := &tls.Config{
Expand Down Expand Up @@ -61,7 +90,7 @@ func Start(errChan chan<- error) error {

go func() {

srv := &http.Server{
publicTLSServ = &http.Server{
Addr: config.Values.Webserver.Public.ListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
Expand All @@ -70,7 +99,9 @@ func Start(errChan chan<- error) error {
Handler: setSecurityHeaders(public),
}

errChan <- fmt.Errorf("TLS webserver public listener failed: %v", srv.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath))
if err := publicTLSServ.ListenAndServeTLS(config.Values.Webserver.Public.CertPath, config.Values.Webserver.Public.KeyPath); err != nil && err != http.ErrServerClosed {
errChan <- fmt.Errorf("TLS webserver public listener failed: %v", err)
}
}()

if config.Values.NumberProxies == 0 {
Expand All @@ -79,7 +110,7 @@ func Start(errChan chan<- error) error {
address, port, err := net.SplitHostPort(config.Values.Webserver.Public.ListenAddress)

if err != nil {
errChan <- fmt.Errorf("Malformed listen address for public listener: %v", err)
errChan <- fmt.Errorf("malformed listen address for public listener: %v", err)
return
}

Expand All @@ -89,29 +120,31 @@ func Start(errChan chan<- error) error {
port = ""
}

srv := &http.Server{
publicHTTPServ = &http.Server{
Addr: address + ":80",
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: setSecurityHeaders(setRedirectHandler(port)),
}

log.Printf("Creating redirection from 80/tcp to TLS webserver public listener failed: %v", srv.ListenAndServe())
log.Printf("Creating redirection from 80/tcp to TLS webserver public listener failed: %v", publicHTTPServ.ListenAndServe())
}()
}

} else {
go func() {
srv := &http.Server{
publicHTTPServ = &http.Server{
Addr: config.Values.Webserver.Public.ListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: setSecurityHeaders(public),
}

errChan <- fmt.Errorf("webserver public listener failed: %v", srv.ListenAndServe())
if err := publicHTTPServ.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errChan <- fmt.Errorf("HTTP webserver public listener failed: %v", err)
}
}()
}

Expand Down Expand Up @@ -153,16 +186,18 @@ func Start(errChan chan<- error) error {

go func() {

srv := &http.Server{
tunnelTLSServ = &http.Server{
Addr: tunnelListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
TLSConfig: tlsConfig,
Handler: setSecurityHeaders(tunnel),
}
if err := tunnelTLSServ.ListenAndServeTLS(config.Values.Webserver.Tunnel.CertPath, config.Values.Webserver.Tunnel.KeyPath); err != nil && err != http.ErrServerClosed {
errChan <- fmt.Errorf("TLS webserver tunnel listener failed: %v", err)
}

errChan <- fmt.Errorf("TLS webserver tunnel listener failed: %v", srv.ListenAndServeTLS(config.Values.Webserver.Tunnel.CertPath, config.Values.Webserver.Tunnel.KeyPath))
}()

if config.Values.NumberProxies == 0 {
Expand All @@ -173,28 +208,31 @@ func Start(errChan chan<- error) error {
port = ""
}

srv := &http.Server{
tunnelHTTPServ = &http.Server{
Addr: config.Values.Wireguard.ServerAddress.String() + ":80",
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: setSecurityHeaders(setRedirectHandler(port)),
}

log.Printf("HTTP redirect to TLS webserver tunnel listener failed: %v", srv.ListenAndServe())
log.Printf("HTTP redirect to TLS webserver tunnel listener failed: %v", tunnelHTTPServ.ListenAndServe())
}()
}
} else {
go func() {
srv := &http.Server{
tunnelHTTPServ = &http.Server{
Addr: tunnelListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: setSecurityHeaders(tunnel),
}

errChan <- fmt.Errorf("webserver tunnel listener failed: %v", srv.ListenAndServe())
if err := tunnelHTTPServ.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errChan <- fmt.Errorf("webserver tunnel listener failed: %v", err)
}

}()
}

Expand Down
32 changes: 28 additions & 4 deletions ui/ui_webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ var (
WagVersion string

LogQueue = queue.NewQueue(40)

HTTPSServer *http.Server
HTTPServer *http.Server
)

func renderDefaults(w http.ResponseWriter, r *http.Request, model interface{}, content ...string) error {
Expand Down Expand Up @@ -347,7 +350,7 @@ func StartWebServer(errs chan<- error) error {

go func() {

srv := &http.Server{
HTTPSServer = &http.Server{
Addr: config.Values.ManagementUI.ListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
Expand All @@ -356,19 +359,24 @@ func StartWebServer(errs chan<- error) error {
Handler: setSecurityHeaders(allRoutes),
}

errs <- fmt.Errorf("TLS management listener failed: %v", srv.ListenAndServeTLS(config.Values.ManagementUI.CertPath, config.Values.ManagementUI.KeyPath))
if err := HTTPSServer.ListenAndServeTLS(config.Values.ManagementUI.CertPath, config.Values.ManagementUI.KeyPath); err != nil && err != http.ErrServerClosed {
errs <- fmt.Errorf("TLS management listener failed: %v", err)
}

}()
} else {
go func() {
srv := &http.Server{
HTTPServer = &http.Server{
Addr: config.Values.ManagementUI.ListenAddress,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
Handler: setSecurityHeaders(allRoutes),
}
if err := HTTPServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
errs <- fmt.Errorf("webserver management listener failed: %v", HTTPServer.ListenAndServe())
}

errs <- fmt.Errorf("webserver management listener failed: %v", srv.ListenAndServe())
}()
}
}()
Expand All @@ -378,6 +386,22 @@ func StartWebServer(errs chan<- error) error {
return nil
}

func Teardown() {

if HTTPServer != nil {
HTTPServer.Close()
}

if HTTPSServer != nil {
HTTPSServer.Close()
}

if config.Values.ManagementUI.Enabled {
log.Println("Stopped MFA portal")
}

}

func changePassword(w http.ResponseWriter, r *http.Request) {

_, u := sessionManager.GetSessionFromRequest(r)
Expand Down

0 comments on commit 154a912

Please sign in to comment.