Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Oct 24, 2023
1 parent 0030c49 commit 1b052ec
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 168 deletions.
9 changes: 8 additions & 1 deletion internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
}
cfg := config{
registered: make(map[int64]*goMigration),
excludes: make(map[string]bool),
}
for _, opt := range opts {
if err := opt.apply(&cfg); err != nil {
Expand Down Expand Up @@ -183,6 +184,9 @@ func (p *Provider) Close() error {
// When direction is true, the up migration is executed, and when direction is false, the down
// migration is executed.
func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bool) (*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.apply(ctx, version, direction)
}

Expand Down Expand Up @@ -215,6 +219,9 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
// For instance, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 1 {
return nil, fmt.Errorf("invalid version: must be greater than zero: %d", version)
}
return p.up(ctx, false, version)
}

Expand All @@ -240,7 +247,7 @@ func (p *Provider) Down(ctx context.Context) (*MigrationResult, error) {
// migrations 11, 10 will be rolled back.
func (p *Provider) DownTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
if version < 0 {
return nil, fmt.Errorf("version must be a number greater than or equal zero: %d", version)
return nil, fmt.Errorf("invalid version: must be a valid number or zero: %d", version)
}
return p.down(ctx, false, version)
}
145 changes: 141 additions & 4 deletions internal/provider/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io/fs"
"sort"
"strings"
"time"

"github.com/pressly/goose/v3/internal/sqladapter"
Expand All @@ -18,6 +19,146 @@ var (
errMissingZeroVersion = errors.New("missing zero version migration")
)

func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
var apply []*migration
if p.cfg.noVersioning {
apply = p.migrations
} else {
// optimize(mf): Listing all migrations from the database isn't great. This is only required to
// support the allow missing (out-of-order) feature. For users that don't use this feature, we
// could just query the database for the current max version and then apply migrations greater
// than that version.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
apply, err = p.resolveUpMigrations(dbMigrations, version)
if err != nil {
return nil, err
}
}
// feat(mf): this is where can (optionally) group multiple migrations to be run in a single
// transaction. The default is to apply each migration sequentially on its own.
// https://github.com/pressly/goose/issues/222
//
// Careful, we can't use a single transaction for all migrations because some may have to be run
// in their own transaction.
return p.runMigrations(ctx, conn, apply, sqlparser.DirectionUp, upByOne)
}

func (p *Provider) resolveUpMigrations(
dbVersions []*sqladapter.ListMigrationsResult,
version int64,
) ([]*migration, error) {
var apply []*migration
var dbMaxVersion int64
// dbAppliedVersions is a map of all applied migrations in the database.
dbAppliedVersions := make(map[int64]bool, len(dbVersions))
for _, m := range dbVersions {
dbAppliedVersions[m.Version] = true
if m.Version > dbMaxVersion {
dbMaxVersion = m.Version
}
}
missingMigrations := findMissingMigrations(dbVersions, p.migrations)
// feat(mf): It is very possible someone may want to apply ONLY new migrations and skip missing
// migrations entirely. At the moment this is not supported, but leaving this comment because
// that's where that logic would be handled.
//
// For example, if db has 1,4 applied and 2,3,5 are new, we would apply only 5 and skip 2,3.
// Not sure if this is a common use case, but it's possible.
if len(missingMigrations) > 0 && !p.cfg.allowMissing {
var collected []string
for _, v := range missingMigrations {
collected = append(collected, v.filename)
}
msg := "migration"
if len(collected) > 1 {
msg += "s"
}
return nil, fmt.Errorf("found %d missing (out-of-order) %s lower than current max (%d): [%s]",
len(missingMigrations), msg, dbMaxVersion, strings.Join(collected, ","),
)
}
for _, v := range missingMigrations {
m, err := p.getMigration(v.versionID)
if err != nil {
return nil, err
}
apply = append(apply, m)
}
// filter all migrations with a version greater than the supplied version (min) and less than or
// equal to the requested version (max). Skip any migrations that have already been applied.
for _, m := range p.migrations {
if dbAppliedVersions[m.Source.Version] {
continue
}
if m.Source.Version > dbMaxVersion && m.Source.Version <= version {
apply = append(apply, m)
}
}
return apply, nil
}

func (p *Provider) down(ctx context.Context, downByOne bool, version int64) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
if err != nil {
return nil, err
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()
if len(p.migrations) == 0 {
return nil, nil
}
if p.cfg.noVersioning {
downMigrations := p.migrations
if downByOne {
last := p.migrations[len(p.migrations)-1]
downMigrations = []*migration{last}
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return nil, err
}
if len(dbMigrations) == 0 {
return nil, errMissingZeroVersion
}
if dbMigrations[0].Version == 0 {
return nil, nil
}
var downMigrations []*migration
for _, dbMigration := range dbMigrations {
if dbMigration.Version <= version {
break
}
m, err := p.getMigration(dbMigration.Version)
if err != nil {
return nil, err
}
downMigrations = append(downMigrations, m)
}
return p.runMigrations(ctx, conn, downMigrations, sqlparser.DirectionDown, downByOne)
}

// runMigrations runs migrations sequentially in the given direction.
//
// If the migrations list is empty, return nil without error.
Expand Down Expand Up @@ -277,10 +418,6 @@ func (p *Provider) getMigration(version int64) (*migration, error) {
}

func (p *Provider) apply(ctx context.Context, version int64, direction bool) (_ *MigrationResult, retErr error) {
if version < 1 {
return nil, errors.New("version must be greater than zero")
}

m, err := p.getMigration(version)
if err != nil {
return nil, err
Expand Down
51 changes: 0 additions & 51 deletions internal/provider/run_down.go

This file was deleted.

6 changes: 3 additions & 3 deletions internal/provider/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func TestProviderRun(t *testing.T) {
p, _ := newProviderWithDB(t)
_, err := p.UpTo(context.Background(), 0)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero")
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
_, err = p.DownTo(context.Background(), -1)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be a number greater than or equal zero: -1")
check.Equal(t, err.Error(), "invalid version: must be a valid number or zero: -1")
_, err = p.ApplyVersion(context.Background(), 0, true)
check.HasError(t, err)
check.Equal(t, err.Error(), "version must be greater than zero")
check.Equal(t, err.Error(), "invalid version: must be greater than zero: 0")
})
t.Run("up_and_down_all", func(t *testing.T) {
ctx := context.Background()
Expand Down
109 changes: 0 additions & 109 deletions internal/provider/run_up.go

This file was deleted.

0 comments on commit 1b052ec

Please sign in to comment.