diff --git a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go index 10415505e1..18aba69e95 100644 --- a/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go +++ b/go/libraries/doltcore/sqle/dsess/autoincrement_tracker.go @@ -16,10 +16,12 @@ package dsess import ( "context" + "errors" "io" "math" "strings" "sync" + "time" "github.com/dolthub/go-mysql-server/sql" gmstypes "github.com/dolthub/go-mysql-server/sql/types" @@ -48,6 +50,8 @@ type AutoIncrementTracker struct { sequences *sync.Map // map[string]uint64 mm *mutexmap.MutexMap lockMode LockMode + init chan struct{} + initErr error } var _ globalstate.AutoIncrementTracker = &AutoIncrementTracker{} @@ -61,8 +65,9 @@ func NewAutoIncrementTracker(ctx context.Context, dbName string, roots ...doltdb dbName: dbName, sequences: &sync.Map{}, mm: mutexmap.NewMutexMap(), + init: make(chan struct{}), } - ait.InitWithRoots(ctx, roots...) + ait.runInitWithRootsAsync(ctx, roots...) return &ait, nil } @@ -76,13 +81,22 @@ func loadAutoIncValue(sequences *sync.Map, tableName string) uint64 { } // Current returns the next value to be generated in the auto increment sequence for the table named -func (a *AutoIncrementTracker) Current(tableName string) uint64 { - return loadAutoIncValue(a.sequences, tableName) +func (a *AutoIncrementTracker) Current(tableName string) (uint64, error) { + err := a.waitForInit() + if err != nil { + return 0, err + } + return loadAutoIncValue(a.sequences, tableName), nil } // Next returns the next auto increment value for the table named using the provided value from an insert (which may // be null or 0, in which case it will be generated from the sequence). func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, error) { + err := a.waitForInit() + if err != nil { + return 0, err + } + tbl = strings.ToLower(tbl) given, err := CoerceAutoIncrementValue(insertVal) @@ -113,6 +127,10 @@ func (a *AutoIncrementTracker) Next(tbl string, insertVal interface{}) (uint64, } func (a *AutoIncrementTracker) CoerceAutoIncrementValue(val interface{}) (uint64, error) { + err := a.waitForInit() + if err != nil { + return 0, err + } return CoerceAutoIncrementValue(val) } @@ -140,6 +158,11 @@ func CoerceAutoIncrementValue(val interface{}) (uint64, error) { // table. Otherwise, the update is silently disregarded. So far this matches the MySQL behavior, but Dolt uses the // maximum value for this table across all branches. func (a *AutoIncrementTracker) Set(ctx *sql.Context, tableName string, table *doltdb.Table, ws ref.WorkingSetRef, newAutoIncVal uint64) (*doltdb.Table, error) { + err := a.waitForInit() + if err != nil { + return nil, err + } + tableName = strings.ToLower(tableName) release := a.mm.Lock(tableName) @@ -338,16 +361,27 @@ func getMaxIndexValue(ctx context.Context, indexData durable.Index) (uint64, err } // AddNewTable initializes a new table with an auto increment column to the tracker, as necessary -func (a *AutoIncrementTracker) AddNewTable(tableName string) { +func (a *AutoIncrementTracker) AddNewTable(tableName string) error { + err := a.waitForInit() + if err != nil { + return err + } + tableName = strings.ToLower(tableName) // only initialize the sequence for this table if no other branch has such a table a.sequences.LoadOrStore(tableName, uint64(1)) + return nil } // DropTable drops the table with the name given. // To establish the new auto increment value, callers must also pass all other working sets in scope that may include // a table with the same name, omitting the working set that just deleted the table named. func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error { + err := a.waitForInit() + if err != nil { + return err + } + tableName = strings.ToLower(tableName) release := a.mm.Lock(tableName) @@ -389,6 +423,11 @@ func (a *AutoIncrementTracker) DropTable(ctx *sql.Context, tableName string, wse } func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName string) (func(), error) { + err := a.waitForInit() + if err != nil { + return nil, err + } + _, i, _ := sql.SystemVariables.GetGlobal("innodb_autoinc_lock_mode") lockMode := LockMode(i.(int64)) if lockMode == LockMode_Interleaved { @@ -398,7 +437,23 @@ func (a *AutoIncrementTracker) AcquireTableLock(ctx *sql.Context, tableName stri return a.mm.Lock(tableName), nil } -func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error { +func (a *AutoIncrementTracker) waitForInit() error { + select { + case <-a.init: + return a.initErr + case <-time.After(5 * time.Minute): + return errors.New("failed to initialize autoincrement tracker") + } +} + +func (a *AutoIncrementTracker) runInitWithRootsAsync(ctx context.Context, roots ...doltdb.Rootish) { + go func() { + defer close(a.init) + a.initErr = a.initWithRoots(ctx, roots...) + }() +} + +func (a *AutoIncrementTracker) initWithRoots(ctx context.Context, roots ...doltdb.Rootish) error { eg, egCtx := errgroup.WithContext(ctx) eg.SetLimit(128) @@ -435,3 +490,13 @@ func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltd return eg.Wait() } + +func (a *AutoIncrementTracker) InitWithRoots(ctx context.Context, roots ...doltdb.Rootish) error { + err := a.waitForInit() + if err != nil { + return err + } + a.init = make(chan struct{}) + a.runInitWithRootsAsync(ctx, roots...) + return a.waitForInit() +} diff --git a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go index 3c92f6e0a0..87be501d97 100644 --- a/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go +++ b/go/libraries/doltcore/sqle/globalstate/auto_increment_tracker.go @@ -27,11 +27,11 @@ import ( // interface here because implementations need to reach into session state, requiring a dependency on this package. type AutoIncrementTracker interface { // Current returns the current auto increment value for the given table. - Current(tableName string) uint64 + Current(tableName string) (uint64, error) // Next returns the next auto increment value for the given table, and increments the current value. Next(tbl string, insertVal interface{}) (uint64, error) // AddNewTable adds a new table to the tracker, initializing the auto increment value to 1. - AddNewTable(tableName string) + AddNewTable(tableName string) error // DropTable removes a table from the tracker. DropTable(ctx *sql.Context, tableName string, wses ...*doltdb.WorkingSet) error // CoerceAutoIncrementValue coerces the given value to a uint64, returning an error if it can't be done. diff --git a/go/libraries/doltcore/sqle/writer/noms_write_session.go b/go/libraries/doltcore/sqle/writer/noms_write_session.go index c0e721fea1..a538774735 100644 --- a/go/libraries/doltcore/sqle/writer/noms_write_session.go +++ b/go/libraries/doltcore/sqle/writer/noms_write_session.go @@ -179,7 +179,10 @@ func (s *nomsWriteSession) flush(ctx *sql.Context) (*doltdb.WorkingSet, error) { // Update the auto increment value for the table if a tracker was provided // TODO: the table probably needs an autoincrement tracker no matter what if schema.HasAutoIncrement(ed.Schema()) { - v := s.aiTracker.Current(name) + v, err := s.aiTracker.Current(name) + if err != nil { + return err + } tbl, err = tbl.SetAutoIncrementValue(ctx, v) if err != nil { return err diff --git a/go/libraries/doltcore/sqle/writer/prolly_write_session.go b/go/libraries/doltcore/sqle/writer/prolly_write_session.go index a73a789699..0e86aeb356 100644 --- a/go/libraries/doltcore/sqle/writer/prolly_write_session.go +++ b/go/libraries/doltcore/sqle/writer/prolly_write_session.go @@ -165,7 +165,10 @@ func (s *prollyWriteSession) flush(ctx *sql.Context, autoIncSet bool, manualAuto // override was specified (e.g. if the next value was set explicitly) if schema.HasAutoIncrement(wr.sch) { // TODO: need schema name for auto increment - autoIncVal := s.aiTracker.Current(name.Name) + autoIncVal, err := s.aiTracker.Current(name.Name) + if err != nil { + return err + } override, hasManuallySetAi := manualAutoIncrementsSettings[name.Name] if hasManuallySetAi { autoIncVal = override