From 92ecfffa4f139c223b9b185f29150755abc8b64b Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Sat, 14 Oct 2023 23:27:51 -0400 Subject: [PATCH 1/2] add migration logic --- internal/provider/collect.go | 71 +- internal/provider/collect_test.go | 70 +- internal/provider/errors.go | 39 + internal/provider/migration.go | 91 +- internal/provider/misc.go | 39 + internal/provider/provider.go | 138 +- internal/provider/provider_options.go | 40 +- internal/provider/provider_options_test.go | 2 +- internal/provider/provider_test.go | 71 +- internal/provider/run.go | 386 ++++- internal/provider/run_down.go | 53 + internal/provider/run_test.go | 1288 +++++++++++++++++ internal/provider/run_up.go | 96 ++ .../no-versioning/migrations/00001_a.sql | 8 + .../no-versioning/migrations/00002_b.sql | 9 + .../no-versioning/migrations/00003_c.sql | 9 + .../testdata/no-versioning/seed/00001_a.sql | 17 + .../testdata/no-versioning/seed/00002_b.sql | 15 + internal/provider/types.go | 99 ++ internal/sqlparser/parse.go | 7 +- testdata/migrations/00002_posts_table.sql | 2 + 21 files changed, 2323 insertions(+), 227 deletions(-) create mode 100644 internal/provider/errors.go create mode 100644 internal/provider/misc.go create mode 100644 internal/provider/run_down.go create mode 100644 internal/provider/run_test.go create mode 100644 internal/provider/run_up.go create mode 100644 internal/provider/testdata/no-versioning/migrations/00001_a.sql create mode 100644 internal/provider/testdata/no-versioning/migrations/00002_b.sql create mode 100644 internal/provider/testdata/no-versioning/migrations/00003_c.sql create mode 100644 internal/provider/testdata/no-versioning/seed/00001_a.sql create mode 100644 internal/provider/testdata/no-versioning/seed/00002_b.sql create mode 100644 internal/provider/types.go diff --git a/internal/provider/collect.go b/internal/provider/collect.go index 6658c8067..fd7d63e75 100644 --- a/internal/provider/collect.go +++ b/internal/provider/collect.go @@ -4,30 +4,14 @@ import ( "errors" "fmt" "io/fs" + "os" "path/filepath" "sort" + "strconv" "strings" - - "github.com/pressly/goose/v3" ) -// Source represents a single migration source. -// -// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if -// the migration has a corresponding file on disk. It will be empty if the migration was registered -// manually. -type Source struct { - // Type is the type of migration. - Type MigrationType - // Full path to the migration file. - // - // Example: /path/to/migrations/001_create_users_table.sql - Fullpath string - // Version is the version of the migration. - Version int64 -} - -func newSource(t MigrationType, fullpath string, version int64) Source { +func NewSource(t MigrationType, fullpath string, version int64) Source { return Source{ Type: t, Fullpath: fullpath, @@ -41,6 +25,7 @@ type fileSources struct { goSources []Source } +// TODO(mf): remove? func (s *fileSources) lookup(t MigrationType, version int64) *Source { switch t { case TypeGo: @@ -93,7 +78,7 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil // filenames, but still have versioned migrations within the same directory. For // example, a user could have a helpers.go file which contains unexported helper // functions for migrations. - version, err := goose.NumericComponent(base) + version, err := NumericComponent(base) if err != nil { if strict { return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) @@ -110,9 +95,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil } switch filepath.Ext(base) { case ".sql": - sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version)) + sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version)) case ".go": - sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version)) + sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version)) default: // Should never happen since we already filtered out all other file types. return nil, fmt.Errorf("unknown migration type: %s", base) @@ -161,12 +146,15 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration // wholesale as part of migrations. This allows users to build a custom binary that only embeds // the SQL migration files. for version, r := range registerd { - var fullpath string - if s := sources.lookup(TypeGo, version); s != nil { - fullpath = s.Fullpath + fullpath := r.fullpath + if fullpath == "" { + if s := sources.lookup(TypeGo, version); s != nil { + fullpath = s.Fullpath + } } // Ensure there are no duplicate versions. if existing, ok := migrationLookup[version]; ok { + fullpath := r.fullpath if fullpath == "" { fullpath = "manually registered (no source)" } @@ -178,7 +166,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration } m := &migration{ // Note, the fullpath may be empty if the migration was registered manually. - Source: newSource(TypeGo, fullpath, version), + Source: NewSource(TypeGo, fullpath, version), Go: r, } migrations = append(migrations, m) @@ -211,3 +199,34 @@ func unregisteredError(unregistered []string) error { return errors.New(b.String()) } + +type noopFS struct{} + +var _ fs.FS = noopFS{} + +func (f noopFS) Open(name string) (fs.File, error) { + return nil, os.ErrNotExist +} + +// NumericComponent parses the version from the migration file name. +// +// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of +// migration, either .sql or .go. +func NumericComponent(filename string) (int64, error) { + base := filepath.Base(filename) + if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" { + return 0, errors.New("migration file does not have .sql or .go file extension") + } + idx := strings.Index(base, "_") + if idx < 0 { + return 0, errors.New("no filename separator '_' found") + } + n, err := strconv.ParseInt(base[:idx], 10, 64) + if err != nil { + return 0, err + } + if n < 1 { + return 0, errors.New("migration version must be greater than zero") + } + return n, nil +} diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index 401a1ce40..73b2642c5 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - newSource(TypeSQL, "00001_foo.sql", 1), - newSource(TypeSQL, "00002_bar.sql", 2), - newSource(TypeSQL, "00003_baz.sql", 3), - newSource(TypeSQL, "00110_qux.sql", 110), + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00002_bar.sql", 2), + NewSource(TypeSQL, "00003_baz.sql", 3), + NewSource(TypeSQL, "00110_qux.sql", 110), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) { check.Number(t, len(sources.goSources), 0) expected := fileSources{ sqlSources: []Source{ - newSource(TypeSQL, "00001_foo.sql", 1), - newSource(TypeSQL, "00003_baz.sql", 3), + NewSource(TypeSQL, "00001_foo.sql", 1), + NewSource(TypeSQL, "00003_baz.sql", 3), }, } for i := 0; i < len(sources.sqlSources); i++ { @@ -159,15 +159,15 @@ func TestCollectFileSources(t *testing.T) { } } assertDirpath(".", []Source{ - newSource(TypeSQL, "876_a.sql", 876), + NewSource(TypeSQL, "876_a.sql", 876), }) assertDirpath("dir1", []Source{ - newSource(TypeSQL, "101_a.sql", 101), - newSource(TypeSQL, "102_b.sql", 102), - newSource(TypeSQL, "103_c.sql", 103), + NewSource(TypeSQL, "101_a.sql", 101), + NewSource(TypeSQL, "102_b.sql", 102), + NewSource(TypeSQL, "103_c.sql", 103), }) assertDirpath("dir2", []Source{ - newSource(TypeSQL, "201_a.sql", 201), + NewSource(TypeSQL, "201_a.sql", 201), }) assertDirpath("dir3", nil) }) @@ -199,14 +199,14 @@ func TestMerge(t *testing.T) { t.Run("valid", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ - 2: {version: 2}, - 3: {version: 3}, + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 3) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3)) }) t.Run("unregistered_all", func(t *testing.T) { _, err := merge(sources, nil) @@ -217,7 +217,7 @@ func TestMerge(t *testing.T) { }) t.Run("unregistered_some", func(t *testing.T) { _, err := merge(sources, map[int64]*goMigration{ - 2: {version: 2}, + 2: newGoMigration("", nil, nil), }) check.HasError(t, err) check.Contains(t, err.Error(), "error: detected 1 unregistered Go file") @@ -225,9 +225,9 @@ func TestMerge(t *testing.T) { }) t.Run("duplicate_sql", func(t *testing.T) { _, err := merge(sources, map[int64]*goMigration{ - 1: {version: 1}, // duplicate. SQL already exists. - 2: {version: 2}, - 3: {version: 3}, + 1: newGoMigration("", nil, nil), // duplicate. SQL already exists. + 2: newGoMigration("", nil, nil), + 3: newGoMigration("", nil, nil), }) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") @@ -246,17 +246,17 @@ func TestMerge(t *testing.T) { check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ - 3: {version: 3}, + 3: newGoMigration("", nil, nil), // 4 is missing - 6: {version: 6}, + 6: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 5) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5)) - assertMigration(t, migrations[4], newSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5)) + assertMigration(t, migrations[4], NewSource(TypeGo, "", 6)) }) }) t.Run("partial_go_files_on_disk", func(t *testing.T) { @@ -271,17 +271,17 @@ func TestMerge(t *testing.T) { t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*goMigration{ // This is the only Go file on disk. - 2: {version: 2}, + 2: newGoMigration("", nil, nil), // These are not on disk. Explicitly registered. - 3: {version: 3}, - 6: {version: 6}, + 3: newGoMigration("", nil, nil), + 6: newGoMigration("", nil, nil), }) check.NoError(t, err) check.Number(t, len(migrations), 4) - assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1)) - assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2)) - assertMigration(t, migrations[2], newSource(TypeGo, "", 3)) - assertMigration(t, migrations[3], newSource(TypeGo, "", 6)) + assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1)) + assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2)) + assertMigration(t, migrations[2], NewSource(TypeGo, "", 3)) + assertMigration(t, migrations[3], NewSource(TypeGo, "", 6)) }) }) } @@ -291,7 +291,7 @@ func assertMigration(t *testing.T, got *migration, want Source) { check.Equal(t, got.Source, want) switch got.Source.Type { case TypeGo: - check.Equal(t, got.Go.version, want.Version) + check.Bool(t, got.Go != nil, true) case TypeSQL: check.Bool(t, got.SQL == nil, true) default: diff --git a/internal/provider/errors.go b/internal/provider/errors.go new file mode 100644 index 000000000..e8ece3871 --- /dev/null +++ b/internal/provider/errors.go @@ -0,0 +1,39 @@ +package provider + +import ( + "errors" + "fmt" + "path/filepath" +) + +var ( + // ErrVersionNotFound when a migration version is not found. + ErrVersionNotFound = errors.New("version not found") + + // ErrAlreadyApplied when a migration has already been applied. + ErrAlreadyApplied = errors.New("already applied") + + // ErrNoMigrations is returned by [NewProvider] when no migrations are found. + ErrNoMigrations = errors.New("no migrations found") + + // ErrNoNextVersion when the next migration version is not found. + ErrNoNextVersion = errors.New("no next version found") +) + +// PartialError is returned when a migration fails, but some migrations already got applied. +type PartialError struct { + // Applied are migrations that were applied successfully before the error occurred. + Applied []*MigrationResult + // Failed contains the result of the migration that failed. + Failed *MigrationResult + // Err is the error that occurred while running the migration. + Err error +} + +func (e *PartialError) Error() string { + filename := "(file unknown)" + if e.Failed != nil && e.Failed.Source.Fullpath != "" { + filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath)) + } + return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err) +} diff --git a/internal/provider/migration.go b/internal/provider/migration.go index cf98abc3e..87098cf22 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -3,8 +3,8 @@ package provider import ( "context" "database/sql" - "errors" "fmt" + "path/filepath" "github.com/pressly/goose/v3/internal/sqlextended" ) @@ -24,36 +24,83 @@ type migration struct { SQL *sqlMigration } -type MigrationType int +func (m *migration) useTx(direction bool) bool { + switch m.Source.Type { + case TypeSQL: + return m.SQL.UseTx + case TypeGo: + if m.Go == nil { + return false + } + if direction { + return m.Go.up.Run != nil + } + return m.Go.down.Run != nil + } + // This should never happen. + return false +} -const ( - TypeGo MigrationType = iota + 1 - TypeSQL -) +func (m *migration) filename() string { + return filepath.Base(m.Source.Fullpath) +} -func (t MigrationType) String() string { - switch t { - case TypeGo: - return "go" +// run runs the migration inside of a transaction. +func (m *migration) run(ctx context.Context, tx *sql.Tx, direction bool) error { + switch m.Source.Type { case TypeSQL: - return "sql" - default: - // This should never happen. - return fmt.Sprintf("unknown (%d)", t) + if m.SQL == nil { + return fmt.Errorf("tx: sql migration has not been parsed") + } + return m.SQL.run(ctx, tx, direction) + case TypeGo: + return m.Go.run(ctx, tx, direction) } + // This should never happen. + return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) } -func (m *migration) GetSQLStatements(direction bool) ([]string, error) { - if m.Source.Type != TypeSQL { - return nil, fmt.Errorf("expected sql migration, got %s: no sql statements", m.Source.Type) +// runNoTx runs the migration without a transaction. +func (m *migration) runNoTx(ctx context.Context, db *sql.DB, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("db: sql migration has not been parsed") + } + return m.SQL.run(ctx, db, direction) + case TypeGo: + return m.Go.runNoTx(ctx, db, direction) } - if m.SQL == nil { - return nil, errors.New("sql migration has not been parsed") + // This should never happen. + return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +// runConn runs the migration without a transaction using the provided connection. +func (m *migration) runConn(ctx context.Context, conn *sql.Conn, direction bool) error { + switch m.Source.Type { + case TypeSQL: + if m.SQL == nil { + return fmt.Errorf("conn: sql migration has not been parsed") + } + return m.SQL.run(ctx, conn, direction) + case TypeGo: + return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") } - if direction { - return m.SQL.UpStatements, nil + // This should never happen. + return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) +} + +type goMigration struct { + fullpath string + up, down *GoMigration +} + +func newGoMigration(fullpath string, up, down *GoMigration) *goMigration { + return &goMigration{ + fullpath: fullpath, + up: up, + down: down, } - return m.SQL.DownStatements, nil } func (g *goMigration) run(ctx context.Context, tx *sql.Tx, direction bool) error { diff --git a/internal/provider/misc.go b/internal/provider/misc.go new file mode 100644 index 000000000..be84b4622 --- /dev/null +++ b/internal/provider/misc.go @@ -0,0 +1,39 @@ +package provider + +import ( + "context" + "database/sql" + "errors" + "fmt" +) + +type Migration struct { + Version int64 + Source string // path to .sql script or go file + Registered bool + UseTx bool + UpFnContext func(context.Context, *sql.Tx) error + DownFnContext func(context.Context, *sql.Tx) error + + UpFnNoTxContext func(context.Context, *sql.DB) error + DownFnNoTxContext func(context.Context, *sql.DB) error +} + +var registeredGoMigrations = make(map[int64]*Migration) + +func SetGlobalGoMigrations(migrations []*Migration) error { + for _, m := range migrations { + if m == nil { + return errors.New("cannot register nil go migration") + } + if _, ok := registeredGoMigrations[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) + } + registeredGoMigrations[m.Version] = m + } + return nil +} + +func ResetGlobalGoMigrations() { + registeredGoMigrations = make(map[int64]*Migration) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 7d5085069..4174a743a 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -6,21 +6,12 @@ import ( "errors" "fmt" "io/fs" - "os" - "time" + "math" + "sync" - "github.com/pressly/goose/v3" "github.com/pressly/goose/v3/internal/sqladapter" - "github.com/pressly/goose/v3/internal/sqlparser" ) -var ( - // ErrNoMigrations is returned by [NewProvider] when no migrations are found. - ErrNoMigrations = errors.New("no migrations found") -) - -var registeredGoMigrations = make(map[int64]*goose.Migration) - // NewProvider returns a new goose Provider. // // The caller is responsible for matching the database dialect with the database/sql driver. For @@ -36,7 +27,7 @@ var registeredGoMigrations = make(map[int64]*goose.Migration) // Unless otherwise specified, all methods on Provider are safe for concurrent use. // // Experimental: This API is experimental and may change in the future. -func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { +func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } @@ -46,7 +37,9 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if fsys == nil { fsys = noopFS{} } - var cfg config + cfg := config{ + registered: make(map[int64]*goMigration), + } for _, opt := range opts { if err := opt.apply(&cfg); err != nil { return nil, err @@ -54,9 +47,9 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) } // Set defaults after applying user-supplied options so option funcs can check for empty values. if cfg.tableName == "" { - cfg.tableName = defaultTablename + cfg.tableName = DefaultTablename } - store, err := sqladapter.NewStore(dialect, cfg.tableName) + store, err := sqladapter.NewStore(string(dialect), cfg.tableName) if err != nil { return nil, err } @@ -78,11 +71,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) registered := make(map[int64]*goMigration) // Add user-registered Go migrations. for version, m := range cfg.registered { - registered[version] = &goMigration{ - version: version, - up: m.up, - down: m.down, - } + registered[version] = newGoMigration("", m.up, m.down) } // Add init() functions. This is a bit ugly because we need to convert from the old Migration // struct to the new goMigration struct. @@ -90,9 +79,7 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) if _, ok := registered[version]; ok { return nil, fmt.Errorf("go migration with version %d already registered", version) } - g := &goMigration{ - version: version, - } + g := newGoMigration(m.Source, nil, nil) if m == nil { return nil, errors.New("registered migration with nil init function") } @@ -140,16 +127,12 @@ func NewProvider(dialect string, db *sql.DB, fsys fs.FS, opts ...ProviderOption) }, nil } -type noopFS struct{} - -var _ fs.FS = noopFS{} - -func (f noopFS) Open(name string) (fs.File, error) { - return nil, os.ErrNotExist -} - // Provider is a goose migration provider. type Provider struct { + // mu protects all accesses to the provider and must be held when calling operations on the + // database. + mu sync.Mutex + db *sql.DB fsys fs.FS cfg config @@ -157,48 +140,27 @@ type Provider struct { migrations []*migration } -// State represents the state of a migration. -type State string - -const ( - // StateUntracked represents a migration that is in the database, but not on the filesystem. - StateUntracked State = "untracked" - // StatePending represents a migration that is on the filesystem, but not in the database. - StatePending State = "pending" - // StateApplied represents a migration that is in BOTH the database and on the filesystem. - StateApplied State = "applied" -) - -// MigrationStatus represents the status of a single migration. -type MigrationStatus struct { - // State is the state of the migration. - State State - // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or - // [StateUntracked]. - AppliedAt time.Time - // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. - Source *Source -} - // Status returns the status of all migrations, merging the list of migrations from the database and // filesystem. The returned items are ordered by version, in ascending order. func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) { - return nil, errors.New("not implemented") + return p.status(ctx) } // GetDBVersion returns the max version from the database, regardless of the applied order. For // example, if migrations 1,4,2,3 were applied, this method returns 4. If no migrations have been // applied, it returns 0. +// +// TODO(mf): this is not true? func (p *Provider) GetDBVersion(ctx context.Context) (int64, error) { - return 0, errors.New("not implemented") + return p.getDBVersion(ctx) } // ListSources returns a list of all available migration sources the provider is aware of, sorted in // ascending order by version. -func (p *Provider) ListSources() []*Source { - sources := make([]*Source, 0, len(p.migrations)) +func (p *Provider) ListSources() []Source { + sources := make([]Source, 0, len(p.migrations)) for _, m := range p.migrations { - sources = append(sources, &m.Source) + sources = append(sources, m.Source) } return sources } @@ -213,9 +175,6 @@ func (p *Provider) Close() error { return p.db.Close() } -// MigrationResult represents the result of a single migration. -type MigrationResult struct{} - // ApplyVersion applies exactly one migration at the specified version. If there is no source for // the specified version, this method returns [ErrNoCurrentVersion]. If the migration has been // applied already, this method returns [ErrAlreadyApplied]. @@ -223,19 +182,26 @@ type MigrationResult struct{} // When direction is true, the up migration is executed, and when direction is false, the down // migration is executed. func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.apply(ctx, version, direction) } // Up applies all pending migrations. If there are no new migrations to apply, this method returns // empty list and nil error. func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.up(ctx, false, math.MaxInt64) } // UpByOne applies the next available migration. If there are no migrations to apply, this method -// returns [ErrNoNextVersion]. -func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { - return nil, errors.New("not implemented") +// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result. +func (p *Provider) UpByOne(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.up(ctx, true, math.MaxInt64) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil } // UpTo applies all available migrations up to and including the specified version. If there are no @@ -244,13 +210,20 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) { // For instance, if there are three new migrations (9,10,11) and the current database version is 8 // with a requested version of 10, only versions 9 and 10 will be applied. func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") + return p.up(ctx, false, version) } // Down rolls back the most recently applied migration. If there are no migrations to apply, this // method returns [ErrNoNextVersion]. -func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { - return nil, errors.New("not implemented") +func (p *Provider) Down(ctx context.Context) ([]*MigrationResult, error) { + res, err := p.down(ctx, true, 0) + if err != nil { + return nil, err + } + if len(res) == 0 { + return nil, ErrNoNextVersion + } + return res, nil } // DownTo rolls back all migrations down to but not including the specified version. @@ -258,27 +231,8 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) { // For instance, if the current database version is 11, and the requested version is 9, only // migrations 11 and 10 will be rolled back. func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) { - return nil, errors.New("not implemented") -} - -// ParseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it -// will not be parsed again. -// -// Important: This function will mutate SQL migrations and is not safe for concurrent use. -func ParseSQL(fsys fs.FS, debug bool, migrations []*migration) error { - for _, m := range migrations { - // If the migration is a SQL migration, and it has not been parsed, parse it. - if m.Source.Type == TypeSQL && m.SQL == nil { - parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) - if err != nil { - return err - } - m.SQL = &sqlMigration{ - UseTx: parsed.UseTx, - UpStatements: parsed.Up, - DownStatements: parsed.Down, - } - } + if version < 0 { + return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version) } - return nil + return p.down(ctx, false, version) } diff --git a/internal/provider/provider_options.go b/internal/provider/provider_options.go index f3ed15b28..0b7cd7ad6 100644 --- a/internal/provider/provider_options.go +++ b/internal/provider/provider_options.go @@ -10,7 +10,7 @@ import ( ) const ( - defaultTablename = "goose_db_version" + DefaultTablename = "goose_db_version" ) // ProviderOption is a configuration option for a goose provider. @@ -35,9 +35,9 @@ func WithTableName(name string) ProviderOption { } // WithVerbose enables verbose logging. -func WithVerbose() ProviderOption { +func WithVerbose(b bool) ProviderOption { return configFunc(func(c *config) error { - c.verbose = true + c.verbose = b return nil }) } @@ -89,7 +89,7 @@ type GoMigration struct { func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { return configFunc(func(c *config) error { if version < 1 { - return fmt.Errorf("go migration version must be greater than 0") + return errors.New("version must be greater than zero") } if _, ok := c.registered[version]; ok { return fmt.Errorf("go migration with version %d already registered", version) @@ -113,17 +113,33 @@ func WithGoMigration(version int64, up, down *GoMigration) ProviderOption { } } c.registered[version] = &goMigration{ - version: version, - up: up, - down: down, + up: up, + down: down, } return nil }) } -type goMigration struct { - version int64 - up, down *GoMigration +// WithAllowMissing allows the provider to apply missing (out-of-order) migrations. +// +// Example: migrations 1,6 are applied and then version 2,3,5 are introduced. If this option is +// true, then goose will apply 2,3,5 instead of raising an error. The final order of applied +// migrations will be: 1,6,2,3,5. +func WithAllowMissing(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.allowMissing = b + return nil + }) +} + +// WithNoVersioning disables versioning. Disabling versioning allows the ability to apply migrations +// without tracking the versions in the database schema table. Useful for tests, seeding a database +// or running ad-hoc queries. +func WithNoVersioning(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.noVersioning = b + return nil + }) } type config struct { @@ -138,6 +154,10 @@ type config struct { // Locking options lockEnabled bool sessionLocker lock.SessionLocker + + // Feature + noVersioning bool + allowMissing bool } type configFunc func(*config) error diff --git a/internal/provider/provider_options_test.go b/internal/provider/provider_options_test.go index 89a1cda16..2271111ba 100644 --- a/internal/provider/provider_options_test.go +++ b/internal/provider/provider_options_test.go @@ -59,7 +59,7 @@ func TestNewProvider(t *testing.T) { check.NoError(t, err) // Valid dialect, db, fsys, and verbose allowed _, err = provider.NewProvider("sqlite3", db, fsys, - provider.WithVerbose(), + provider.WithVerbose(testing.Verbose()), ) check.NoError(t, err) }) diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index c8b5effe3..ac4ec7e0e 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -1,6 +1,7 @@ package provider_test import ( + "context" "database/sql" "errors" "io/fs" @@ -33,14 +34,68 @@ func TestProvider(t *testing.T) { check.NoError(t, err) sources := p.ListSources() check.Equal(t, len(sources), 2) - // 1 - check.Equal(t, sources[0].Version, int64(1)) - check.Equal(t, sources[0].Fullpath, "001_foo.sql") - check.Equal(t, sources[0].Type, provider.TypeSQL) - // 2 - check.Equal(t, sources[1].Version, int64(2)) - check.Equal(t, sources[1].Fullpath, "002_bar.sql") - check.Equal(t, sources[1].Type, provider.TypeSQL) + check.Equal(t, sources[0], provider.NewSource(provider.TypeSQL, "001_foo.sql", 1)) + check.Equal(t, sources[1], provider.NewSource(provider.TypeSQL, "002_bar.sql", 2)) + + t.Run("duplicate_go", func(t *testing.T) { + // Not parallel because it modifies global state. + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: nil, + DownFnContext: nil, + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, nil, nil), + ) + check.HasError(t, err) + check.Equal(t, err.Error(), "go migration with version 1 already registered") + }) + t.Run("empty_go", func(t *testing.T) { + db := newDB(t) + // explicit + _, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), + ) + check.HasError(t, err) + check.Contains(t, err.Error(), "go migration with version 1 must have an up function") + }) + t.Run("duplicate_up", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + UpFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both UpFnContext and UpFnNoTxContext") + }) + t.Run("duplicate_down", func(t *testing.T) { + err := provider.SetGlobalGoMigrations([]*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + DownFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(ctx context.Context, db *sql.DB) error { return nil }, + }, + }) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + db := newDB(t) + _, err = provider.NewProvider(provider.DialectSQLite3, db, nil) + check.HasError(t, err) + check.Contains(t, err.Error(), "registered migration with both DownFnContext and DownFnNoTxContext") + }) } var ( diff --git a/internal/provider/run.go b/internal/provider/run.go index f5ca25038..893be7d1a 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -3,51 +3,373 @@ package provider import ( "context" "database/sql" + "errors" "fmt" - "path/filepath" + "io/fs" + "sort" + "strings" + "time" + + "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" ) -// Run runs the migration inside of a transaction. -func (m *migration) Run(ctx context.Context, tx *sql.Tx, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("tx: sql migration has not been parsed") +// runMigrations runs migrations sequentially in the given direction. +// +// If the migrations slice is empty, this function returns nil with no error. +func (p *Provider) runMigrations( + ctx context.Context, + conn *sql.Conn, + migrations []*migration, + direction sqlparser.Direction, + byOne bool, +) ([]*MigrationResult, error) { + if len(migrations) == 0 { + return nil, nil + } + var apply []*migration + if byOne { + apply = []*migration{migrations[0]} + } else { + apply = migrations + } + // Lazily parse SQL migrations (if any) in both directions. We do this before running any + // migrations so that we can fail fast if there are any errors and avoid leaving the database in + // a partially migrated state. + + if err := parseSQL(p.fsys, false, apply); err != nil { + return nil, err + } + + // TODO(mf): If we decide to add support for advisory locks at the transaction level, this may + // be a good place to acquire the lock. However, we need to be sure that ALL migrations are safe + // to run in a transaction. + + // + // + // + + // bug(mf): this is a potential deadlock scenario. We're running Go migrations with *sql.DB, but + // are locking the database with *sql.Conn. If the caller sets max open connections to 1, then + // this will deadlock because the Go migration will try to acquire a connection from the pool, + // but the pool is locked. + // + // A potential solution is to expose a third Go register function *sql.Conn. Or continue to use + // *sql.DB and document that the user SHOULD NOT SET max open connections to 1. This is a bit of + // an edge case. if p.opt.LockMode != LockModeNone && p.db.Stats().MaxOpenConnections == 1 { + // for _, m := range apply { + // if m.IsGo() && !m.Go.UseTx { + // return nil, errors.New("potential deadlock detected: cannot run GoMigrationNoTx with max open connections set to 1") + // } + // } + // } + + // Run migrations individually, opening a new transaction for each migration if the migration is + // safe to run in a transaction. + + // Avoid allocating a slice because we may have a partial migration error. 1. Avoid giving the + // impression that N migrations were applied when in fact some were not 2. Avoid the caller + // having to check for nil results + var results []*MigrationResult + for _, m := range apply { + current := &MigrationResult{ + Source: m.Source, + Direction: strings.ToLower(direction.String()), + // TODO(mf): empty set here } - return m.SQL.run(ctx, tx, direction) - case TypeGo: - return m.Go.run(ctx, tx, direction) + + start := time.Now() + if err := p.runIndividually(ctx, conn, direction.ToBool(), m); err != nil { + // TODO(mf): we should also return the pending migrations here. + current.Error = err + current.Duration = time.Since(start) + return nil, &PartialError{ + Applied: results, + Failed: current, + Err: err, + } + } + + current.Duration = time.Since(start) + results = append(results, current) } - // This should never happen. - return fmt.Errorf("tx: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return results, nil } -// RunNoTx runs the migration without a transaction. -func (m *migration) RunNoTx(ctx context.Context, db *sql.DB, direction bool) error { +// runIndividually runs an individual migration, opening a new transaction if the migration is safe +// to run in a transaction. Otherwise, it runs the migration outside of a transaction with the +// supplied connection. +func (p *Provider) runIndividually( + ctx context.Context, + conn *sql.Conn, + direction bool, + m *migration, +) error { + if m.useTx(direction) { + // Run the migration in a transaction. + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := m.run(ctx, tx, direction); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, direction, m.Source.Version) + }) + } + // Run the migration outside of a transaction. switch m.Source.Type { + case TypeGo: + // Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the + // GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open + // connections to 1. See the comment in runMigrations for more details. + if err := m.Go.runNoTx(ctx, p.db, direction); err != nil { + return err + } case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("db: sql migration has not been parsed") + if err := m.runConn(ctx, conn, direction); err != nil { + return err } - return m.SQL.run(ctx, db, direction) - case TypeGo: - return m.Go.runNoTx(ctx, db, direction) } - // This should never happen. - return fmt.Errorf("db: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, conn, direction, m.Source.Version) } -// RunConn runs the migration without a transaction using the provided connection. -func (m *migration) RunConn(ctx context.Context, conn *sql.Conn, direction bool) error { - switch m.Source.Type { - case TypeSQL: - if m.SQL == nil { - return fmt.Errorf("conn: sql migration has not been parsed") +// beginTx begins a transaction and runs the given function. If the function returns an error, the +// transaction is rolled back. Otherwise, the transaction is committed. +// +// If the provider is configured to use versioning, this function also inserts or deletes the +// migration version. +func (p *Provider) beginTx( + ctx context.Context, + conn *sql.Conn, + fn func(tx *sql.Tx) error, +) (retErr error) { + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) } - return m.SQL.run(ctx, conn, direction) - case TypeGo: - return fmt.Errorf("conn: go migrations are not supported with *sql.Conn") + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) { + p.mu.Lock() + conn, err := p.db.Conn(ctx) + if err != nil { + p.mu.Unlock() + return nil, nil, err + } + // cleanup is a function that cleans up the connection, and optionally, the session lock. + cleanup := func() error { + p.mu.Unlock() + return conn.Close() + } + if l := p.cfg.sessionLocker; l != nil && p.cfg.lockEnabled { + if err := l.SessionLock(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + cleanup = func() error { + p.mu.Unlock() + // Use a detached context to unlock the session. This is because the context passed to + // SessionLock may have been canceled, and we don't want to cancel the unlock. + // TODO(mf): use [context.WithoutCancel] added in go1.21 + detachedCtx := context.Background() + return multierr.Append(l.SessionUnlock(detachedCtx, conn), conn.Close()) + } + } + // If versioning is enabled, ensure the version table exists. For ad-hoc migrations, we don't + // need the version table because there is no versioning. + if !p.cfg.noVersioning { + if err := p.ensureVersionTable(ctx, conn); err != nil { + return nil, nil, multierr.Append(err, cleanup()) + } + } + return conn, cleanup, nil +} + +// parseSQL parses all SQL migrations in BOTH directions. If a migration has already been parsed, it +// will not be parsed again. +// +// Important: This function will mutate SQL migrations and is not safe for concurrent use. +func parseSQL(fsys fs.FS, debug bool, migrations []*migration) error { + for _, m := range migrations { + // If the migration is a SQL migration, and it has not been parsed, parse it. + if m.Source.Type == TypeSQL && m.SQL == nil { + parsed, err := sqlparser.ParseAllFromFS(fsys, m.Source.Fullpath, debug) + if err != nil { + return err + } + m.SQL = &sqlMigration{ + UseTx: parsed.UseTx, + UpStatements: parsed.Up, + DownStatements: parsed.Down, + } + } + } + return nil +} + +func (p *Provider) ensureVersionTable(ctx context.Context, conn *sql.Conn) (retErr error) { + // feat(mf): this is where we can check if the version table exists instead of trying to fetch + // from a table that may not exist. https://github.com/pressly/goose/issues/461 + res, err := p.store.GetMigration(ctx, conn, 0) + if err == nil && res != nil { + return nil + } + return p.beginTx(ctx, conn, func(tx *sql.Tx) error { + if err := p.store.CreateVersionTable(ctx, tx); err != nil { + return err + } + if p.cfg.noVersioning { + return nil + } + return p.store.InsertOrDelete(ctx, tx, true, 0) + }) +} + +type missingMigration struct { + versionID int64 + filename string +} + +// findMissingMigrations returns a list of migrations that are missing from the database. A missing +// migration is one that has a version less than the max version in the database. +func findMissingMigrations( + dbMigrations []*sqladapter.ListMigrationsResult, + fsMigrations []*migration, + dbMaxVersion int64, +) []missingMigration { + existing := make(map[int64]bool) + for _, m := range dbMigrations { + existing[m.Version] = true + } + var missing []missingMigration + for _, m := range fsMigrations { + version := m.Source.Version + if !existing[version] && version < dbMaxVersion { + missing = append(missing, missingMigration{ + versionID: version, + filename: m.filename(), + }) + } + } + sort.Slice(missing, func(i, j int) bool { + return missing[i].versionID < missing[j].versionID + }) + return missing +} + +// getMigration returns the migration with the given version. If no migration is found, then +// ErrVersionNotFound is returned. +func (p *Provider) getMigration(version int64) (*migration, error) { + for _, m := range p.migrations { + if m.Source.Version == version { + return m, nil + } + } + return nil, ErrVersionNotFound +} + +func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + m, err := p.getMigration(version) + if err != nil { + return nil, err + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + result, err := p.store.GetMigration(ctx, conn, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + // If the migration has already been applied, return an error, unless the migration is being + // applied in the opposite direction. In that case, we allow the migration to be applied again. + if result != nil && direction { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + + d := sqlparser.DirectionDown + if direction { + d = sqlparser.DirectionUp + } + results, err := p.runMigrations(ctx, conn, []*migration{m}, d, true) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, fmt.Errorf("version %d: %w", version, ErrAlreadyApplied) + } + return results[0], nil +} + +func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // TODO(mf): add support for limit and order. Also would be nice to refactor the list query to + // support limiting the set. + + status := make([]*MigrationStatus, 0, len(p.migrations)) + for _, m := range p.migrations { + migrationStatus := &MigrationStatus{ + Source: m.Source, + State: StatePending, + } + dbResult, err := p.store.GetMigration(ctx, conn, m.Source.Version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + if dbResult != nil { + migrationStatus.State = StateApplied + migrationStatus.AppliedAt = dbResult.Timestamp + } + status = append(status, migrationStatus) + } + + return status, nil +} + +func (p *Provider) getDBVersion(ctx context.Context) (_ int64, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return 0, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return 0, err + } + if len(res) == 0 { + return 0, nil } - // This should never happen. - return fmt.Errorf("conn: failed to run migration %s: neither sql or go", filepath.Base(m.Source.Fullpath)) + return res[0].Version, nil } diff --git a/internal/provider/run_down.go b/internal/provider/run_down.go new file mode 100644 index 000000000..011ba7990 --- /dev/null +++ b/internal/provider/run_down.go @@ -0,0 +1,53 @@ +package provider + +import ( + "context" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) { + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + + if p.cfg.noVersioning { + var downMigrations []*migration + if downByOne { + downMigrations = append(downMigrations, p.migrations[len(p.migrations)-1]) + } else { + downMigrations = p.migrations + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) + } + + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + if dbMigrations[0].Version == 0 { + return nil, nil + } + + var downMigrations []*migration + for _, dbMigration := range dbMigrations { + if dbMigration.Version <= version { + break + } + m, err := p.getMigration(dbMigration.Version) + if err != nil { + return nil, err + } + downMigrations = append(downMigrations, m) + } + return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne) +} diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go new file mode 100644 index 000000000..56644ca04 --- /dev/null +++ b/internal/provider/run_test.go @@ -0,0 +1,1288 @@ +package provider_test + +import ( + "context" + "database/sql" + "errors" + "fmt" + "io/fs" + "math" + "math/rand" + "os" + "path/filepath" + "reflect" + "sort" + "sync" + "testing" + "testing/fstest" + + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/provider" + "github.com/pressly/goose/v3/internal/testdb" + "github.com/pressly/goose/v3/lock" + "golang.org/x/sync/errgroup" +) + +func TestProviderRun(t *testing.T) { + t.Parallel() + + t.Run("closed_db", func(t *testing.T) { + p, db := newProviderWithDB(t) + check.NoError(t, db.Close()) + _, err := p.Up(context.Background()) + check.HasError(t, err) + check.Equal(t, err.Error(), "sql: database is closed") + }) + t.Run("ping_and_close", func(t *testing.T) { + p, _ := newProviderWithDB(t) + t.Cleanup(func() { + check.NoError(t, p.Close()) + }) + check.NoError(t, p.Ping(context.Background())) + }) + t.Run("apply_unknown_version", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.ApplyVersion(context.Background(), 999, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + _, err = p.ApplyVersion(context.Background(), 999, false) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrVersionNotFound), true) + }) + t.Run("run_zero", func(t *testing.T) { + p, _ := newProviderWithDB(t) + _, err := p.UpTo(context.Background(), 0) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + _, err = p.DownTo(context.Background(), -1) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1") + _, err = p.ApplyVersion(context.Background(), 0, true) + check.HasError(t, err) + check.Equal(t, err.Error(), "version must be greater than zero") + }) + t.Run("up_and_down_all", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + const ( + numCount = 7 + ) + sources := p.ListSources() + check.Number(t, len(sources), numCount) + // Ensure only SQL migrations are returned + for _, s := range sources { + check.Equal(t, s.Type, provider.TypeSQL) + } + // Test Up + res, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "up") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "up") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "up") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "up") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "up") + // Test Down + res, err = p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(res), numCount) + assertResult(t, res[0], provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), "down") + assertResult(t, res[1], provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), "down") + assertResult(t, res[2], provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), "down") + assertResult(t, res[3], provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), "down") + assertResult(t, res[4], provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), "down") + assertResult(t, res[5], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "down") + assertResult(t, res[6], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "down") + }) + t.Run("up_and_down_by_one", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations one-by-one. + var counter int + for { + res, err := p.UpByOne(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(counter)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, int64(maxVersion)) + // Reset counter + counter = 0 + // Rollback all migrations one-by-one. + for { + res, err := p.Down(ctx) + counter++ + if counter > maxVersion { + if !errors.Is(err, provider.ErrNoNextVersion) { + t.Fatalf("incorrect error: got:%v want:%v", err, provider.ErrNoNextVersion) + } + break + } + check.NoError(t, err) + check.Number(t, len(res), 1) + check.Number(t, res[0].Source.Version, int64(maxVersion-counter+1)) + } + // Once everything is tested the version should match the highest testdata version + currentVersion, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + }) + t.Run("up_to", func(t *testing.T) { + ctx := context.Background() + p, db := newProviderWithDB(t) + const ( + upToVersion int64 = 2 + ) + results, err := p.UpTo(ctx, upToVersion) + check.NoError(t, err) + check.Number(t, len(results), upToVersion) + assertResult(t, results[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + assertResult(t, results[1], provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), "up") + // Fetch the goose version from DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, upToVersion) + // Validate the version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, upToVersion) + }) + t.Run("sql_connections", func(t *testing.T) { + tt := []struct { + name string + maxOpenConns int + maxIdleConns int + useDefaults bool + }{ + // Single connection ensures goose is able to function correctly when multiple + // connections are not available. + {name: "single_conn", maxOpenConns: 1, maxIdleConns: 1}, + {name: "defaults", useDefaults: true}, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + // Start a new database for each test case. + p, db := newProviderWithDB(t) + if !tc.useDefaults { + db.SetMaxOpenConns(tc.maxOpenConns) + db.SetMaxIdleConns(tc.maxIdleConns) + } + sources := p.ListSources() + check.NumberNotZero(t, len(sources)) + + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + + { + // Apply all up migrations + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), len(sources)) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, p.ListSources()[len(sources)-1].Version) + // Validate the db migration version actually matches what goose claims it is + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, currentVersion) + tables, err := getTableNames(db) + check.NoError(t, err) + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + { + // Apply all down migrations + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), len(sources)) + gotVersion, err := getMaxVersionID(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, gotVersion, 0) + // Should only be left with a single table, the default goose table + tables, err := getTableNames(db) + check.NoError(t, err) + knownTables := []string{provider.DefaultTablename, "sqlite_sequence"} + if !reflect.DeepEqual(tables, knownTables) { + t.Logf("got tables: %v", tables) + t.Logf("known tables: %v", knownTables) + t.Fatal("failed to match tables") + } + } + }) + } + }) + t.Run("apply", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + sources := p.ListSources() + // Apply all migrations in the up direction. + for _, s := range sources { + res, err := p.ApplyVersion(ctx, s.Version, true) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "up") + } + // Apply all migrations in the down direction. + for i := len(sources) - 1; i >= 0; i-- { + s := sources[i] + res, err := p.ApplyVersion(ctx, s.Version, false) + check.NoError(t, err) + // Round-trip the migration result through the database to ensure it's valid. + assertResult(t, res, s, "down") + } + // Try apply version 1 multiple times + _, err := p.ApplyVersion(ctx, 1, true) + check.NoError(t, err) + _, err = p.ApplyVersion(ctx, 1, true) + check.HasError(t, err) + check.Bool(t, errors.Is(err, provider.ErrAlreadyApplied), true) + check.Contains(t, err.Error(), "version 1: already applied") + }) + t.Run("status", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + numCount := len(p.ListSources()) + // Before any migrations are applied, the status should be empty. + status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StatePending, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), true) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), true) + assertStatus(t, status[3], provider.StatePending, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), true) + assertStatus(t, status[4], provider.StatePending, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), true) + assertStatus(t, status[5], provider.StatePending, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), true) + assertStatus(t, status[6], provider.StatePending, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), true) + // Apply all migrations + _, err = p.Up(ctx) + check.NoError(t, err) + status, err = p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), numCount) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00002_posts_table.sql", 2), false) + assertStatus(t, status[2], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00003_comments_table.sql", 3), false) + assertStatus(t, status[3], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00004_insert_data.sql", 4), false) + assertStatus(t, status[4], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00005_posts_view.sql", 5), false) + assertStatus(t, status[5], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00006_empty_up.sql", 6), false) + assertStatus(t, status[6], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00007_empty_up_down.sql", 7), false) + }) + t.Run("tx_partial_errors", func(t *testing.T) { + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + + ctx := context.Background() + db := newDB(t) + mapFS := fstest.MapFS{ + "00001_users_table.sql": newMapFile(` +-- +goose Up +CREATE TABLE owners ( owner_name TEXT NOT NULL ); +`), + "00002_partial_error.sql": newMapFile(` +-- +goose Up +INSERT INTO invalid_table (invalid_table) VALUES ('invalid_value'); +`), + "00003_insert_data.sql": newMapFile(` +-- +goose Up +INSERT INTO owners (owner_name) VALUES ('seed-user-1'); +INSERT INTO owners (owner_name) VALUES ('seed-user-2'); +INSERT INTO owners (owner_name) VALUES ('seed-user-3'); +`), + } + p, err := provider.NewProvider(provider.DialectSQLite3, db, mapFS) + check.NoError(t, err) + _, err = p.Up(ctx) + check.HasError(t, err) + check.Contains(t, err.Error(), "partial migration error (00002_partial_error.sql) (2)") + var expected *provider.PartialError + check.Bool(t, errors.As(err, &expected), true) + // Check Err field + check.Bool(t, expected.Err != nil, true) + check.Contains(t, expected.Err.Error(), "SQL logic error: no such table: invalid_table (1)") + // Check Results field + check.Number(t, len(expected.Applied), 1) + assertResult(t, expected.Applied[0], provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), "up") + // Check Failed field + check.Bool(t, expected.Failed != nil, true) + assertSource(t, expected.Failed.Source, provider.TypeSQL, "00002_partial_error.sql", 2) + check.Bool(t, expected.Failed.Empty, false) + check.Bool(t, expected.Failed.Error != nil, true) + check.Contains(t, expected.Failed.Error.Error(), "SQL logic error: no such table: invalid_table (1)") + check.Equal(t, expected.Failed.Direction, "up") + check.Bool(t, expected.Failed.Duration > 0, true) + + // Ensure the partial error did not affect the database. + count, err := countOwners(db) + check.NoError(t, err) + check.Number(t, count, 0) + + status, err := p.Status(ctx) + check.NoError(t, err) + check.Number(t, len(status), 3) + assertStatus(t, status[0], provider.StateApplied, provider.NewSource(provider.TypeSQL, "00001_users_table.sql", 1), false) + assertStatus(t, status[1], provider.StatePending, provider.NewSource(provider.TypeSQL, "00002_partial_error.sql", 2), true) + assertStatus(t, status[2], provider.StatePending, provider.NewSource(provider.TypeSQL, "00003_insert_data.sql", 3), true) + }) +} + +func TestConcurrentProvider(t *testing.T) { + t.Parallel() + + t.Run("up", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + + ch := make(chan int64) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.UpByOne(ctx) + if err != nil { + t.Error(err) + return + } + if len(res) != 1 { + t.Errorf("expected 1 result, got %d", len(res)) + return + } + ch <- res[0].Source.Version + }() + } + go func() { + wg.Wait() + close(ch) + }() + var versions []int64 + for version := range ch { + versions = append(versions, version) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Number(t, len(versions), maxVersion) + for i := 0; i < maxVersion; i++ { + check.Number(t, versions[i], int64(i+1)) + } + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + }) + t.Run("down", func(t *testing.T) { + ctx := context.Background() + p, _ := newProviderWithDB(t) + maxVersion := len(p.ListSources()) + // Apply all migrations + _, err := p.Up(ctx) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + + ch := make(chan []*provider.MigrationResult) + var wg sync.WaitGroup + for i := 0; i < maxVersion; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + res, err := p.DownTo(ctx, 0) + if err != nil { + t.Error(err) + return + } + ch <- res + }() + } + go func() { + wg.Wait() + close(ch) + }() + var ( + valid [][]*provider.MigrationResult + empty [][]*provider.MigrationResult + ) + for results := range ch { + if len(results) == 0 { + empty = append(empty, results) + continue + } + valid = append(valid, results) + } + // Fail early if any of the goroutines failed. + if t.Failed() { + return + } + check.Equal(t, len(valid), 1) + check.Equal(t, len(empty), maxVersion-1) + // Ensure the valid result is correct. + check.Number(t, len(valid[0]), maxVersion) + }) +} + +func TestNoVersioning(t *testing.T) { + t.Parallel() + + countSeedOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners WHERE owner_name LIKE'seed-user-%'` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + countOwners := func(db *sql.DB) (int, error) { + q := `SELECT count(*)FROM owners` + var count int + if err := db.QueryRow(q).Scan(&count); err != nil { + return 0, err + } + return count, nil + } + ctx := context.Background() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "migrations")) + const ( + // Total owners created by the seed files. + wantSeedOwnerCount = 250 + // These are owners created by migration files. + wantOwnerCount = 4 + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(false), // This is the default. + ) + check.Number(t, len(p.ListSources()), 3) + check.NoError(t, err) + _, err = p.Up(ctx) + check.NoError(t, err) + baseVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, 3) + t.Run("seed-up-down-to-zero", func(t *testing.T) { + fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) + p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + provider.WithVerbose(testing.Verbose()), + provider.WithNoVersioning(true), // Provider with no versioning. + ) + check.NoError(t, err) + check.Number(t, len(p.ListSources()), 2) + + // Run (all) up migrations from the seed dir + { + upResult, err := p.Up(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, wantSeedOwnerCount) + } + // Run (all) down migrations from the seed dir + { + downResult, err := p.DownTo(ctx, 0) + check.NoError(t, err) + check.Number(t, len(downResult), 2) + // Confirm no changes to the versioned schema in the DB + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, baseVersion, currentVersion) + seedOwnerCount, err := countSeedOwners(db) + check.NoError(t, err) + check.Number(t, seedOwnerCount, 0) + } + // The migrations added 4 non-seed owners, they must remain in the database afterwards + ownerCount, err := countOwners(db) + check.NoError(t, err) + check.Number(t, ownerCount, wantOwnerCount) + }) +} + +func TestAllowMissing(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Developer A and B check out the "main" branch which is currently on version 3. Developer A + // mistakenly creates migration 5 and commits. Developer B did not pull the latest changes and + // commits migration 4. Oops -- now the migrations are out of order. + // + // When goose is set to allow missing migrations, then 5 is applied after 4 with no error. + // Otherwise it's expected to be an error. + + t.Run("missing_now_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(false), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + result, err := p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + check.Number(t, result.Source.Version, 5) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + // The database has migrations 1,2,3,5 applied. + + // Developer B is on version 3 (e.g., never pulled the latest changes). Adds migration 4. By + // default goose does not allow missing (out-of-order) migrations, which means halt if a + // missing migration is detected. + _, err = p.Up(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpByOne(ctx) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + + _, err = p.UpTo(ctx, math.MaxInt64) + check.HasError(t, err) + // found 1 missing (out-of-order) migration: [00004_insert_data.sql] + check.Contains(t, err.Error(), "missing (out-of-order) migration") + // Confirm db version is unchanged. + current, err = p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + }) + + t.Run("missing_allowed", func(t *testing.T) { + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + provider.WithAllowMissing(true), + ) + check.NoError(t, err) + + // Create and apply first 3 migrations. + _, err = p.UpTo(ctx, 3) + check.NoError(t, err) + currentVersion, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 3) + + // Developer A - migration 5 (mistakenly applied) + { + _, err = p.ApplyVersion(ctx, 5, true) + check.NoError(t, err) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, 5) + } + // Developer B - migration 4 (missing) and 6 (new) + { + // 4 + upResult, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 4) + // 6 + upResult, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(upResult), 1) + check.Number(t, upResult[0].Source.Version, 6) + + count, err := getGooseVersionCount(db, provider.DefaultTablename) + check.NoError(t, err) + check.Number(t, count, 6) + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + // Expecting max(version_id) to be 8 + check.Number(t, current, 6) + } + + // The applied order in the database is expected to be: + // 1,2,3,5,4,6 + // So migrating down should be the reverse of the applied order: + // 6,4,5,3,2,1 + + expected := []int64{6, 4, 5, 3, 2, 1} + for i, v := range expected { + // TODO(mf): this is returning it by the order it was applied. + current, err := p.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, current, v) + downResult, err := p.Down(ctx) + if i == len(expected)-1 { + check.HasError(t, provider.ErrVersionNotFound) + } else { + check.NoError(t, err) + check.Number(t, len(downResult), 1) + check.Number(t, downResult[0].Source.Version, v) + } + } + }) +} + +func getGooseVersionCount(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("SELECT count(*) FROM %s WHERE version_id > 0", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func TestGoOnly(t *testing.T) { + // Not parallel because it modifies global state. + + countUser := func(db *sql.DB) int { + q := `SELECT count(*)FROM users` + var count int + err := db.QueryRow(q).Scan(&count) + check.NoError(t, err) + return count + } + + t.Run("with_tx", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnContext: newTxFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnContext: newTxFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{Run: newTxFn("DELETE FROM users")}, + ), + ) + check.NoError(t, err) + sources := p.ListSources() + check.Number(t, len(p.ListSources()), 2) + assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) + assertSource(t, sources[1], provider.TypeGo, "", 2) + // Apply migration 1 + res, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + check.Number(t, countUser(db), 0) + check.Bool(t, tableExists(t, db, "users"), true) + // Apply migration 2 + res, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + check.Number(t, countUser(db), 3) + // Rollback migration 2 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + check.Number(t, countUser(db), 0) + // Rollback migration 1 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + // Check table does not exist + check.Bool(t, tableExists(t, db, "users"), false) + }) + t.Run("with_db", func(t *testing.T) { + ctx := context.Background() + register := []*provider.Migration{ + { + Version: 1, Source: "00001_users_table.go", Registered: true, UseTx: true, + UpFnNoTxContext: newDBFn("CREATE TABLE users (id INTEGER PRIMARY KEY)"), + DownFnNoTxContext: newDBFn("DROP TABLE users"), + }, + } + err := provider.SetGlobalGoMigrations(register) + check.NoError(t, err) + t.Cleanup(provider.ResetGlobalGoMigrations) + + db := newDB(t) + p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + provider.WithGoMigration( + 2, + &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, + &provider.GoMigration{RunNoTx: newDBFn("DELETE FROM users")}, + ), + ) + check.NoError(t, err) + sources := p.ListSources() + check.Number(t, len(p.ListSources()), 2) + assertSource(t, sources[0], provider.TypeGo, "00001_users_table.go", 1) + assertSource(t, sources[1], provider.TypeGo, "", 2) + // Apply migration 1 + res, err := p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "up") + check.Number(t, countUser(db), 0) + check.Bool(t, tableExists(t, db, "users"), true) + // Apply migration 2 + res, err = p.UpByOne(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "up") + check.Number(t, countUser(db), 3) + // Rollback migration 2 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "", 2), "down") + check.Number(t, countUser(db), 0) + // Rollback migration 1 + res, err = p.Down(ctx) + check.NoError(t, err) + check.Number(t, len(res), 1) + assertResult(t, res[0], provider.NewSource(provider.TypeGo, "00001_users_table.go", 1), "down") + // Check table does not exist + check.Bool(t, tableExists(t, db, "users"), false) + }) +} + +func TestLockModeAdvisorySession(t *testing.T) { + t.Parallel() + if testing.Short() { + t.Skip("skip long running test") + } + + // The migrations are written in such a way that they cannot be applied concurrently, they will + // fail 99.9999% of the time. This test ensures that the advisory session lock mode works as + // expected. + + // TODO(mf): small improvement here is to use the SAME postgres instance but different databases + // created from a template. This will speed up the test. + + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + + newProvider := func() *provider.Provider { + sessionLocker, err := lock.NewPostgresSessionLocker() + check.NoError(t, err) + p, err := provider.NewProvider(provider.DialectPostgres, db, os.DirFS("../../testdata/migrations"), + provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode. + provider.WithVerbose(testing.Verbose()), + ) + check.NoError(t, err) + return p + } + provider1 := newProvider() + provider2 := newProvider() + + sources := provider1.ListSources() + maxVersion := sources[len(sources)-1].Version + + // Since the lock mode is advisory session, only one of these providers is expected to apply ALL + // the migrations. The other provider should apply NO migrations. The test MUST fail if both + // providers apply migrations. + + t.Run("up", func(t *testing.T) { + var g errgroup.Group + var res1, res2 int + g.Go(func() error { + ctx := context.Background() + results, err := provider1.Up(ctx) + check.NoError(t, err) + res1 = len(results) + currentVersion, err := provider1.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + return nil + }) + g.Go(func() error { + ctx := context.Background() + results, err := provider2.Up(ctx) + check.NoError(t, err) + res2 = len(results) + currentVersion, err := provider2.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + return nil + }) + check.NoError(t, g.Wait()) + // One of the providers should have applied all migrations and the other should have applied + // no migrations, but with no error. + if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Reset the database and run the same test with the advisory lock mode, but apply migrations + // one-by-one. + { + _, err := provider1.DownTo(context.Background(), 0) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + } + t.Run("up_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.UpByOne(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + check.NoError(t, err) + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied up exactly once. + for i := 0; i < len(sources); i++ { + check.Number(t, applied[i], sources[i].Version) + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations in parallel. + { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_to", func(t *testing.T) { + var g errgroup.Group + var res1, res2 int + g.Go(func() error { + ctx := context.Background() + results, err := provider1.DownTo(ctx, 0) + check.NoError(t, err) + res1 = len(results) + currentVersion, err := provider1.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + g.Go(func() error { + ctx := context.Background() + results, err := provider2.DownTo(ctx, 0) + check.NoError(t, err) + res2 = len(results) + currentVersion, err := provider2.GetDBVersion(ctx) + check.NoError(t, err) + check.Number(t, currentVersion, 0) + return nil + }) + check.NoError(t, g.Wait()) + + if res1 == 0 && res2 == 0 { + t.Fatal("both providers applied no migrations") + } + if res1 > 0 && res2 > 0 { + t.Fatal("both providers applied migrations") + } + }) + + // Restore the database state by applying all migrations and run the same test with the advisory + // lock mode, but apply down migrations one-by-one. + { + _, err := provider1.Up(context.Background()) + check.NoError(t, err) + currentVersion, err := provider1.GetDBVersion(context.Background()) + check.NoError(t, err) + check.Number(t, currentVersion, maxVersion) + } + + t.Run("down_by_one", func(t *testing.T) { + var g errgroup.Group + var ( + mu sync.Mutex + applied []int64 + ) + g.Go(func() error { + for { + results, err := provider1.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + g.Go(func() error { + for { + results, err := provider2.Down(context.Background()) + if err != nil { + if errors.Is(err, provider.ErrNoNextVersion) { + return nil + } + return err + } + if len(results) != 1 { + return fmt.Errorf("expected 1 result, got %d", len(results)) + } + check.NoError(t, err) + mu.Lock() + applied = append(applied, results[0].Source.Version) + mu.Unlock() + } + }) + check.NoError(t, g.Wait()) + check.Number(t, len(applied), len(sources)) + sort.Slice(applied, func(i, j int) bool { + return applied[i] < applied[j] + }) + // Each migration should have been applied down exactly once. Since this is sequential the + // applied down migrations should be in reverse order. + for i := len(sources) - 1; i >= 0; i-- { + check.Number(t, applied[i], sources[i].Version) + } + }) +} + +func newDBFn(query string) func(context.Context, *sql.DB) error { + return func(ctx context.Context, db *sql.DB) error { + _, err := db.ExecContext(ctx, query) + return err + } +} + +func newTxFn(query string) func(context.Context, *sql.Tx) error { + return func(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, query) + return err + } +} + +func tableExists(t *testing.T, db *sql.DB, table string) bool { + q := fmt.Sprintf(`SELECT CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS table_exists FROM sqlite_master WHERE type = 'table' AND name = '%s'`, table) + var b string + err := db.QueryRow(q).Scan(&b) + check.NoError(t, err) + return b == "1" +} + +const ( + charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +) + +func randomAlphaNumeric(length int) string { + b := make([]byte, length) + for i := range b { + b[i] = charset[rand.Intn(len(charset))] + } + return string(b) +} + +func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider.Provider, *sql.DB) { + t.Helper() + db := newDB(t) + opts = append( + opts, + provider.WithVerbose(testing.Verbose()), + ) + p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), opts...) + check.NoError(t, err) + return p, db +} + +func newDB(t *testing.T) *sql.DB { + t.Helper() + dbName := fmt.Sprintf("test_%s.db", randomAlphaNumeric(8)) + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), dbName)) + check.NoError(t, err) + return db +} + +func getMaxVersionID(db *sql.DB, gooseTable string) (int64, error) { + var gotVersion int64 + if err := db.QueryRow( + fmt.Sprintf("select max(version_id) from %s", gooseTable), + ).Scan(&gotVersion); err != nil { + return 0, err + } + return gotVersion, nil +} + +func getTableNames(db *sql.DB) ([]string, error) { + rows, err := db.Query(`SELECT name FROM sqlite_master WHERE type='table' ORDER BY name`) + if err != nil { + return nil, err + } + defer rows.Close() + var tables []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + tables = append(tables, name) + } + if err := rows.Err(); err != nil { + return nil, err + } + return tables, nil +} + +func assertPartialError(t *testing.T, got error) { + t.Helper() + var e *provider.PartialError + check.Bool(t, errors.As(got, &e), true) +} + +func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) { + t.Helper() + check.Equal(t, got.State, state) + check.Equal(t, got.Source, source) + check.Bool(t, got.AppliedAt.IsZero(), appliedIsZero) +} + +func assertResult(t *testing.T, got *provider.MigrationResult, source provider.Source, direction string) { + t.Helper() + check.Equal(t, got.Source, source) + check.Equal(t, got.Direction, direction) + check.Equal(t, got.Empty, false) + check.Bool(t, got.Error == nil, true) + check.Bool(t, got.Duration > 0, true) +} + +func assertSource(t *testing.T, got provider.Source, typ provider.MigrationType, name string, version int64) { + t.Helper() + check.Equal(t, got.Type, typ) + check.Equal(t, got.Fullpath, name) + check.Equal(t, got.Version, version) + switch got.Type { + case provider.TypeGo: + check.Equal(t, got.Type.String(), "go") + case provider.TypeSQL: + check.Equal(t, got.Type.String(), "sql") + } +} + +func newMapFile(data string) *fstest.MapFile { + return &fstest.MapFile{ + Data: []byte(data), + } +} + +func newFsys() fs.FS { + return fstest.MapFS{ + "00001_users_table.sql": newMapFile(runMigration1), + "00002_posts_table.sql": newMapFile(runMigration2), + "00003_comments_table.sql": newMapFile(runMigration3), + "00004_insert_data.sql": newMapFile(runMigration4), + "00005_posts_view.sql": newMapFile(runMigration5), + "00006_empty_up.sql": newMapFile(runMigration6), + "00007_empty_up_down.sql": newMapFile(runMigration7), + } +} + +var ( + + // known tables are the tables (including goose table) created by running all migration files. + // If you add a table, make sure to add to this list and keep it in order. + knownTables = []string{ + "comments", + "goose_db_version", + "posts", + "sqlite_sequence", + "users", + } + + runMigration1 = ` +-- +goose Up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL, + email TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- +goose Down +DROP TABLE users; +` + + runMigration2 = ` +-- +goose Up +-- +goose StatementBegin +CREATE TABLE posts ( + id INTEGER PRIMARY KEY, + title TEXT NOT NULL, + content TEXT NOT NULL, + author_id INTEGER NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (author_id) REFERENCES users(id) +); +-- +goose StatementEnd +SELECT 1; +SELECT 2; + +-- +goose Down +DROP TABLE posts; +` + + runMigration3 = ` +-- +goose Up +CREATE TABLE comments ( + id INTEGER PRIMARY KEY, + post_id INTEGER NOT NULL, + user_id INTEGER NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (post_id) REFERENCES posts(id), + FOREIGN KEY (user_id) REFERENCES users(id) +); + +-- +goose Down +DROP TABLE comments; +SELECT 1; +SELECT 2; +SELECT 3; +` + + runMigration4 = ` +-- +goose Up +INSERT INTO users (id, username, email) +VALUES + (1, 'john_doe', 'john@example.com'), + (2, 'jane_smith', 'jane@example.com'), + (3, 'alice_wonderland', 'alice@example.com'); + +INSERT INTO posts (id, title, content, author_id) +VALUES + (1, 'Introduction to SQL', 'SQL is a powerful language for managing databases...', 1), + (2, 'Data Modeling Techniques', 'Choosing the right data model is crucial...', 2), + (3, 'Advanced Query Optimization', 'Optimizing queries can greatly improve...', 1); + +INSERT INTO comments (id, post_id, user_id, content) +VALUES + (1, 1, 3, 'Great introduction! Looking forward to more.'), + (2, 1, 2, 'SQL can be a bit tricky at first, but practice helps.'), + (3, 2, 1, 'You covered normalization really well in this post.'); + +-- +goose Down +DELETE FROM comments; +DELETE FROM posts; +DELETE FROM users; +` + + runMigration5 = ` +-- +goose NO TRANSACTION + +-- +goose Up +CREATE VIEW posts_view AS + SELECT + p.id, + p.title, + p.content, + p.created_at, + u.username AS author + FROM posts p + JOIN users u ON p.author_id = u.id; + +-- +goose Down +DROP VIEW posts_view; +` + + runMigration6 = ` +-- +goose Up +` + + runMigration7 = ` +-- +goose Up +-- +goose Down +` +) diff --git a/internal/provider/run_up.go b/internal/provider/run_up.go new file mode 100644 index 000000000..7ee9c6c4f --- /dev/null +++ b/internal/provider/run_up.go @@ -0,0 +1,96 @@ +package provider + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/pressly/goose/v3/internal/sqlparser" + "go.uber.org/multierr" +) + +func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) { + if version < 1 { + return nil, errors.New("version must be greater than zero") + } + + conn, cleanup, err := p.initialize(ctx) + if err != nil { + return nil, err + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + if len(p.migrations) == 0 { + return nil, nil + } + if p.cfg.noVersioning { + // Short circuit if versioning is disabled and apply all migrations. + return p.runMigrations(ctx, conn, p.migrations, sqlparser.DirectionUp, upByOne) + } + + // optimize(mf): Listing all migrations from the database isn't great. This is only required to + // support the out-of-order (allow missing) feature. For users who don't use this feature, we + // could just query the database for the current version and then apply migrations that are + // greater than that version. + dbMigrations, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return nil, err + } + dbMaxVersion := dbMigrations[0].Version + // lookupAppliedInDB is a map of all applied migrations in the database. + lookupAppliedInDB := make(map[int64]bool) + for _, m := range dbMigrations { + lookupAppliedInDB[m.Version] = true + } + + missingMigrations := findMissingMigrations(dbMigrations, p.migrations, dbMaxVersion) + + // feature(mf): It is very possible someone may want to apply ONLY new migrations and skip + // missing migrations entirely. At the moment this is not supported, but leaving this comment + // because that's where that logic will be handled. + if len(missingMigrations) > 0 && !p.cfg.allowMissing { + var collected []string + for _, v := range missingMigrations { + collected = append(collected, v.filename) + } + msg := "migration" + if len(collected) > 1 { + msg += "s" + } + return nil, fmt.Errorf("found %d missing (out-of-order) %s: [%s]", + len(missingMigrations), msg, strings.Join(collected, ",")) + } + + var migrationsToApply []*migration + if p.cfg.allowMissing { + for _, v := range missingMigrations { + m, err := p.getMigration(v.versionID) + if err != nil { + return nil, err + } + migrationsToApply = append(migrationsToApply, m) + } + } + // filter all migrations with a version greater than the supplied version (min) and less than or + // equal to the requested version (max). + for _, m := range p.migrations { + if lookupAppliedInDB[m.Source.Version] { + continue + } + if m.Source.Version > dbMaxVersion && m.Source.Version <= version { + migrationsToApply = append(migrationsToApply, m) + } + } + + // feat(mf): this is where can (optionally) group multiple migrations to be run in a single + // transaction. The default is to apply each migration sequentially on its own. + // https://github.com/pressly/goose/issues/222 + // + // Note, we can't use a single transaction for all migrations because some may have to be run in + // their own transaction. + + return p.runMigrations(ctx, conn, migrationsToApply, sqlparser.DirectionUp, upByOne) +} diff --git a/internal/provider/testdata/no-versioning/migrations/00001_a.sql b/internal/provider/testdata/no-versioning/migrations/00001_a.sql new file mode 100644 index 000000000..839cb7a7b --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00001_a.sql @@ -0,0 +1,8 @@ +-- +goose Up +CREATE TABLE owners ( + owner_id INTEGER PRIMARY KEY AUTOINCREMENT, + owner_name TEXT NOT NULL +); + +-- +goose Down +DROP TABLE IF EXISTS owners; diff --git a/internal/provider/testdata/no-versioning/migrations/00002_b.sql b/internal/provider/testdata/no-versioning/migrations/00002_b.sql new file mode 100644 index 000000000..bd15ef51c --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00002_b.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('lucas'), ('ocean'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/migrations/00003_c.sql b/internal/provider/testdata/no-versioning/migrations/00003_c.sql new file mode 100644 index 000000000..422fb3068 --- /dev/null +++ b/internal/provider/testdata/no-versioning/migrations/00003_c.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +INSERT INTO owners(owner_name) VALUES ('james'), ('space'); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DELETE FROM owners WHERE owner_name IN ('james', 'space'); +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00001_a.sql b/internal/provider/testdata/no-versioning/seed/00001_a.sql new file mode 100644 index 000000000..64f9ff03c --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00001_a.sql @@ -0,0 +1,17 @@ +-- +goose Up +-- +goose StatementBegin +-- Insert 100 owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 1 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 100 +) +SELECT 'seed-user-' || n FROM numbers; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +-- Delete the previously inserted data. +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%'; +-- +goose StatementEnd diff --git a/internal/provider/testdata/no-versioning/seed/00002_b.sql b/internal/provider/testdata/no-versioning/seed/00002_b.sql new file mode 100644 index 000000000..aafe82752 --- /dev/null +++ b/internal/provider/testdata/no-versioning/seed/00002_b.sql @@ -0,0 +1,15 @@ +-- +goose Up + +-- Insert 150 more owners. +INSERT INTO owners (owner_name) +WITH numbers AS ( + SELECT 101 AS n + UNION ALL + SELECT n + 1 FROM numbers WHERE n < 250 +) +SELECT 'seed-user-' || n FROM numbers; + +-- +goose Down + +-- NOTE: there are 4 migration owners and 100 seed owners, that's why owner_id starts at 105 +DELETE FROM owners WHERE owner_name LIKE 'seed-user-%' AND owner_id BETWEEN 105 AND 254; diff --git a/internal/provider/types.go b/internal/provider/types.go new file mode 100644 index 000000000..21bb18beb --- /dev/null +++ b/internal/provider/types.go @@ -0,0 +1,99 @@ +package provider + +import ( + "fmt" + "time" +) + +// Dialect is the type of database dialect. +type Dialect string + +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" +) + +// MigrationType is the type of migration. +type MigrationType int + +const ( + TypeGo MigrationType = iota + 1 + TypeSQL +) + +func (t MigrationType) String() string { + switch t { + case TypeGo: + return "go" + case TypeSQL: + return "sql" + default: + // This should never happen. + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Source represents a single migration source. +// +// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if +// the migration has a corresponding file on disk. It will be empty if the migration was registered +// manually. +type Source struct { + // Type is the type of migration. + Type MigrationType + // Full path to the migration file. + // + // Example: /path/to/migrations/001_create_users_table.sql + Fullpath string + // Version is the version of the migration. + Version int64 +} + +// MigrationResult is the result of a single migration operation. +// +// Note, the caller is responsible for checking the Error field for any errors that occurred while +// running the migration. If the Error field is not nil, the migration failed. +type MigrationResult struct { + Source Source + Duration time.Duration + Direction string + // Empty is true if the file was valid, but no statements to apply. These are still versioned + // migrations, but typically have no effect on the database. + // + // For SQL migrations, this means there was a valid .sql file but contained no statements. For + // Go migrations, this means the function was nil. + Empty bool + + // Error is any error that occurred while running the migration. + Error error +} + +// State represents the state of a migration. +type State string + +const ( + // StatePending represents a migration that is on the filesystem, but not in the database. + StatePending State = "pending" + // StateApplied represents a migration that is in BOTH the database and on the filesystem. + StateApplied State = "applied" + + // StateUntracked represents a migration that is in the database, but not on the filesystem. + // StateUntracked State = "untracked" +) + +// MigrationStatus represents the status of a single migration. +type MigrationStatus struct { + // State is the state of the migration. + State State + // AppliedAt is the time the migration was applied. Only set if state is [StateApplied] or + // [StateUntracked]. + AppliedAt time.Time + // Source is the migration source. Only set if the state is [StatePending] or [StateApplied]. + Source Source +} diff --git a/internal/sqlparser/parse.go b/internal/sqlparser/parse.go index e993587a6..b42fdde14 100644 --- a/internal/sqlparser/parse.go +++ b/internal/sqlparser/parse.go @@ -1,6 +1,7 @@ package sqlparser import ( + "fmt" "io/fs" "go.uber.org/multierr" @@ -50,5 +51,9 @@ func parse(fsys fs.FS, filename string, direction Direction, debug bool) (_ []st defer func() { retErr = multierr.Append(retErr, r.Close()) }() - return ParseSQLMigration(r, direction, debug) + stmts, useTx, err := ParseSQLMigration(r, direction, debug) + if err != nil { + return nil, false, fmt.Errorf("failed to parse %s: %w", filename, err) + } + return stmts, useTx, nil } diff --git a/testdata/migrations/00002_posts_table.sql b/testdata/migrations/00002_posts_table.sql index 25648ed42..be70a2348 100644 --- a/testdata/migrations/00002_posts_table.sql +++ b/testdata/migrations/00002_posts_table.sql @@ -1,4 +1,5 @@ -- +goose Up +-- +goose StatementBegin CREATE TABLE posts ( id INTEGER PRIMARY KEY, title TEXT NOT NULL, @@ -7,6 +8,7 @@ CREATE TABLE posts ( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (author_id) REFERENCES users(id) ); +-- +goose StatementEnd -- +goose Down DROP TABLE posts; From 11d9c936e0ee38ba378eaa0e825832b8bbcacb31 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Mon, 16 Oct 2023 22:01:47 -0400 Subject: [PATCH 2/2] lint --- internal/provider/provider.go | 2 +- internal/provider/run.go | 2 +- internal/provider/run_test.go | 6 ------ 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 4174a743a..3982ac37b 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -79,10 +79,10 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption if _, ok := registered[version]; ok { return nil, fmt.Errorf("go migration with version %d already registered", version) } - g := newGoMigration(m.Source, nil, nil) if m == nil { return nil, errors.New("registered migration with nil init function") } + g := newGoMigration(m.Source, nil, nil) if m.UpFnContext != nil && m.UpFnNoTxContext != nil { return nil, errors.New("registered migration with both UpFnContext and UpFnNoTxContext") } diff --git a/internal/provider/run.go b/internal/provider/run.go index 893be7d1a..55bef9f32 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -124,7 +124,7 @@ func (p *Provider) runIndividually( // Note, we're using *sql.DB instead of *sql.Conn because it's the contract of the // GoMigrationNoTx function. This may be a deadlock scenario if the caller sets max open // connections to 1. See the comment in runMigrations for more details. - if err := m.Go.runNoTx(ctx, p.db, direction); err != nil { + if err := m.runNoTx(ctx, p.db, direction); err != nil { return err } case TypeSQL: diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index 56644ca04..97e86ed21 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -1117,12 +1117,6 @@ func getTableNames(db *sql.DB) ([]string, error) { return tables, nil } -func assertPartialError(t *testing.T, got error) { - t.Helper() - var e *provider.PartialError - check.Bool(t, errors.As(got, &e), true) -} - func assertStatus(t *testing.T, got *provider.MigrationStatus, state provider.State, source provider.Source, appliedIsZero bool) { t.Helper() check.Equal(t, got.State, state)