Skip to content

Commit

Permalink
Merge pull request #7004 from dolthub/aaron/sql-server-startup-rigor
Browse files Browse the repository at this point in the history
[no-release-notes] cmd/dolt/commands/sqlserver: Restructure the start up sequence for sql-server.
  • Loading branch information
reltuk authored Nov 16, 2023
2 parents 184a317 + f1b915a commit 820686b
Show file tree
Hide file tree
Showing 11 changed files with 1,154 additions and 453 deletions.
727 changes: 449 additions & 278 deletions go/cmd/dolt/commands/sqlserver/server.go

Large diffs are not rendered by default.

66 changes: 32 additions & 34 deletions go/cmd/dolt/commands/sqlserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)

//TODO: server tests need to expose a higher granularity for server interactions:
Expand Down Expand Up @@ -60,7 +61,7 @@ var (
)

func TestServerArgs(t *testing.T) {
serverController := NewServerController()
controller := svcs.NewController()
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
Expand All @@ -75,16 +76,16 @@ func TestServerArgs(t *testing.T) {
"-t", "5",
"-l", "info",
"-r",
}, dEnv, serverController)
}, dEnv, controller)
}()
err = serverController.WaitForStart()
err = controller.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
controller.Stop()
err = controller.WaitForStop()
assert.NoError(t, err)
}

Expand All @@ -110,22 +111,22 @@ listener:
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
serverController := NewServerController()
controller := svcs.NewController()
go func() {

dEnv.FS.WriteFile("config.yaml", []byte(yamlConfig), os.ModePerm)
startServer(context.Background(), "0.0.0", "dolt sql-server", []string{
"--config", "config.yaml",
}, dEnv, serverController)
}, dEnv, controller)
}()
err = serverController.WaitForStart()
err = controller.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
controller.Stop()
err = controller.WaitForStop()
assert.NoError(t, err)
}

Expand All @@ -145,18 +146,15 @@ func TestServerBadArgs(t *testing.T) {
}

for _, test := range tests {
test := test
t.Run(strings.Join(test, " "), func(t *testing.T) {
serverController := NewServerController()
go func(serverController *ServerController) {
startServer(context.Background(), "test", "dolt sql-server", test, env, serverController)
}(serverController)

// In the event that a test fails, we need to prevent a test from hanging due to a running server
err := serverController.WaitForStart()
require.Error(t, err)
serverController.StopServer()
err = serverController.WaitForClose()
assert.NoError(t, err)
controller := svcs.NewController()
go func() {
startServer(context.Background(), "test", "dolt sql-server", test, env, controller)
}()
if !assert.Error(t, controller.WaitForStart()) {
controller.Stop()
}
})
}
}
Expand Down Expand Up @@ -186,8 +184,8 @@ func TestServerGoodParams(t *testing.T) {

for _, test := range tests {
t.Run(ConfigInfo(test), func(t *testing.T) {
sc := NewServerController()
go func(config ServerConfig, sc *ServerController) {
sc := svcs.NewController()
go func(config ServerConfig, sc *svcs.Controller) {
_, _ = Serve(context.Background(), "0.0.0", config, sc, env)
}(test, sc)
err := sc.WaitForStart()
Expand All @@ -196,8 +194,8 @@ func TestServerGoodParams(t *testing.T) {
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
sc.StopServer()
err = sc.WaitForClose()
sc.Stop()
err = sc.WaitForStop()
assert.NoError(t, err)
})
}
Expand All @@ -212,8 +210,8 @@ func TestServerSelect(t *testing.T) {

serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15300)

sc := NewServerController()
defer sc.StopServer()
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, env)
}()
Expand Down Expand Up @@ -261,7 +259,7 @@ func TestServerSelect(t *testing.T) {

// If a port is already in use, throw error "Port XXXX already in use."
func TestServerFailsIfPortInUse(t *testing.T) {
serverController := NewServerController()
controller := svcs.NewController()
server := &http.Server{
Addr: ":15200",
Handler: http.DefaultServeMux,
Expand All @@ -287,10 +285,10 @@ func TestServerFailsIfPortInUse(t *testing.T) {
"-t", "5",
"-l", "info",
"-r",
}, dEnv, serverController)
}, dEnv, controller)
}()

err = serverController.WaitForStart()
err = controller.WaitForStart()
require.Error(t, err)
server.Close()
wg.Wait()
Expand All @@ -311,8 +309,8 @@ func TestServerSetDefaultBranch(t *testing.T) {

serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15302)

sc := NewServerController()
defer sc.StopServer()
sc := svcs.NewController()
defer sc.Stop()
go func() {
_, _ = Serve(context.Background(), "0.0.0", serverConfig, sc, dEnv)
}()
Expand Down Expand Up @@ -470,7 +468,7 @@ func TestReadReplica(t *testing.T) {
dsess.InitPersistedSystemVars(multiSetup.GetEnv(readReplicaDbName))

// start server as read replica
sc := NewServerController()
sc := svcs.NewController()
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15303)

// set socket to nil to force tcp
Expand All @@ -482,7 +480,7 @@ func TestReadReplica(t *testing.T) {
require.NoError(t, err)
}()
require.NoError(t, sc.WaitForStart())
defer sc.StopServer()
defer sc.Stop()

replicatedTable := "new_table"
multiSetup.CreateTable(ctx, sourceDbName, replicatedTable)
Expand Down
98 changes: 0 additions & 98 deletions go/cmd/dolt/commands/sqlserver/servercontroller.go

This file was deleted.

13 changes: 7 additions & 6 deletions go/cmd/dolt/commands/sqlserver/sqlclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)

const (
Expand Down Expand Up @@ -107,7 +108,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri

apr := cli.ParseArgsOrDie(ap, args, help)
var serverConfig ServerConfig
var serverController *ServerController
var svcsController *svcs.Controller
var err error

cli.Println(color.YellowString("WARNING: This command is being deprecated and is not recommended for general use.\n" +
Expand Down Expand Up @@ -149,11 +150,11 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
}
cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))

serverController = NewServerController()
svcsController = svcs.NewController()
go func() {
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv)
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, svcsController, dEnv)
}()
err = serverController.WaitForStart()
err = svcsController.WaitForStart()
if err != nil {
cli.PrintErrln(err.Error())
return 1
Expand Down Expand Up @@ -376,8 +377,8 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
cli.PrintErrln(err.Error())
}
if apr.Contains(sqlClientDualFlag) {
serverController.StopServer()
err = serverController.WaitForClose()
svcsController.Stop()
err = svcsController.WaitForStop()
if err != nil {
cli.PrintErrln(err.Error())
}
Expand Down
31 changes: 15 additions & 16 deletions go/cmd/dolt/commands/sqlserver/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
)

const (
Expand Down Expand Up @@ -186,12 +187,20 @@ func (cmd SqlServerCmd) RequiresRepo() bool {

// Exec executes the command
func (cmd SqlServerCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
controller := NewServerController()
controller := svcs.NewController()
newCtx, cancelF := context.WithCancel(context.Background())
go func() {
<-ctx.Done()
controller.StopServer()
cancelF()
// Here we only forward along the SIGINT if the server starts
// up successfully. If the service does not start up
// successfully, or if WaitForStart() blocks indefinitely, then
// startServer() should have returned an error and we do not
// need to Stop the running server or deal with our canceled
// parent context.
if controller.WaitForStart() == nil {
<-ctx.Done()
controller.Stop()
cancelF()
}
}()
return startServer(newCtx, cmd.VersionStr, commandStr, args, dEnv, controller)
}
Expand All @@ -208,7 +217,7 @@ func validateSqlServerArgs(apr *argparser.ArgParseResults) error {
return nil
}

func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, serverController *ServerController) int {
func startServer(ctx context.Context, versionStr, commandStr string, args []string, dEnv *env.DoltEnv, controller *svcs.Controller) int {
ap := SqlServerCmd{}.ArgParser()
help, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, sqlServerDocs, ap))

Expand All @@ -222,29 +231,19 @@ func startServer(ctx context.Context, versionStr, commandStr string, args []stri
}
serverConfig, err := GetServerConfig(dEnv.FS, apr)
if err != nil {
if serverController != nil {
serverController.StopServer()
serverController.serverStopped(err)
}

cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
cli.PrintErrln(err.Error())
return 1
}
if err = SetupDoltConfig(dEnv, apr, serverConfig); err != nil {
if serverController != nil {
serverController.StopServer()
serverController.serverStopped(err)
}

cli.PrintErrln(color.RedString("Failed to start server. Bad Configuration"))
cli.PrintErrln(err.Error())
return 1
}

cli.PrintErrf("Starting server with Config %v\n", ConfigInfo(serverConfig))

if startError, closeError := Serve(ctx, versionStr, serverConfig, serverController, dEnv); startError != nil || closeError != nil {
if startError, closeError := Serve(ctx, versionStr, serverConfig, controller, dEnv); startError != nil || closeError != nil {
if startError != nil {
cli.PrintErrln(startError)
}
Expand Down
Loading

0 comments on commit 820686b

Please sign in to comment.