Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): add migration logic with tests #617

Merged
merged 2 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions internal/provider/collect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,14 @@ import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strconv"
"strings"

"github.com/pressly/goose/v3"
)

// Source represents a single migration source.
//
// For SQL migrations, Fullpath will always be set. For Go migrations, Fullpath will will be set if
// the migration has a corresponding file on disk. It will be empty if the migration was registered
// manually.
type Source struct {
// Type is the type of migration.
Type MigrationType
// Full path to the migration file.
//
// Example: /path/to/migrations/001_create_users_table.sql
Fullpath string
// Version is the version of the migration.
Version int64
}

func newSource(t MigrationType, fullpath string, version int64) Source {
func NewSource(t MigrationType, fullpath string, version int64) Source {
return Source{
Type: t,
Fullpath: fullpath,
Expand All @@ -41,6 +25,7 @@ type fileSources struct {
goSources []Source
}

// TODO(mf): remove?
func (s *fileSources) lookup(t MigrationType, version int64) *Source {
switch t {
case TypeGo:
Expand Down Expand Up @@ -93,7 +78,7 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := goose.NumericComponent(base)
version, err := NumericComponent(base)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
Expand All @@ -110,9 +95,9 @@ func collectFileSources(fsys fs.FS, strict bool, excludes map[string]bool) (*fil
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, newSource(TypeSQL, fullpath, version))
sources.sqlSources = append(sources.sqlSources, NewSource(TypeSQL, fullpath, version))
case ".go":
sources.goSources = append(sources.goSources, newSource(TypeGo, fullpath, version))
sources.goSources = append(sources.goSources, NewSource(TypeGo, fullpath, version))
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("unknown migration type: %s", base)
Expand Down Expand Up @@ -161,12 +146,15 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
// wholesale as part of migrations. This allows users to build a custom binary that only embeds
// the SQL migration files.
for version, r := range registerd {
var fullpath string
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath
fullpath := r.fullpath
if fullpath == "" {
if s := sources.lookup(TypeGo, version); s != nil {
fullpath = s.Fullpath
}
}
// Ensure there are no duplicate versions.
if existing, ok := migrationLookup[version]; ok {
fullpath := r.fullpath
if fullpath == "" {
fullpath = "manually registered (no source)"
}
Expand All @@ -178,7 +166,7 @@ func merge(sources *fileSources, registerd map[int64]*goMigration) ([]*migration
}
m := &migration{
// Note, the fullpath may be empty if the migration was registered manually.
Source: newSource(TypeGo, fullpath, version),
Source: NewSource(TypeGo, fullpath, version),
Go: r,
}
migrations = append(migrations, m)
Expand Down Expand Up @@ -211,3 +199,34 @@ func unregisteredError(unregistered []string) error {

return errors.New(b.String())
}

type noopFS struct{}

var _ fs.FS = noopFS{}

func (f noopFS) Open(name string) (fs.File, error) {
return nil, os.ErrNotExist
}

// NumericComponent parses the version from the migration file name.
//
// XXX_descriptivename.ext where XXX specifies the version number and ext specifies the type of
// migration, either .sql or .go.
func NumericComponent(filename string) (int64, error) {
base := filepath.Base(filename)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("migration file does not have .sql or .go file extension")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no filename separator '_' found")
}
n, err := strconv.ParseInt(base[:idx], 10, 64)
if err != nil {
return 0, err
}
if n < 1 {
return 0, errors.New("migration version must be greater than zero")
}
return n, nil
}
70 changes: 35 additions & 35 deletions internal/provider/collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00002_bar.sql", 2),
newSource(TypeSQL, "00003_baz.sql", 3),
newSource(TypeSQL, "00110_qux.sql", 110),
NewSource(TypeSQL, "00001_foo.sql", 1),
NewSource(TypeSQL, "00002_bar.sql", 2),
NewSource(TypeSQL, "00003_baz.sql", 3),
NewSource(TypeSQL, "00110_qux.sql", 110),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
Expand All @@ -74,8 +74,8 @@ func TestCollectFileSources(t *testing.T) {
check.Number(t, len(sources.goSources), 0)
expected := fileSources{
sqlSources: []Source{
newSource(TypeSQL, "00001_foo.sql", 1),
newSource(TypeSQL, "00003_baz.sql", 3),
NewSource(TypeSQL, "00001_foo.sql", 1),
NewSource(TypeSQL, "00003_baz.sql", 3),
},
}
for i := 0; i < len(sources.sqlSources); i++ {
Expand Down Expand Up @@ -159,15 +159,15 @@ func TestCollectFileSources(t *testing.T) {
}
}
assertDirpath(".", []Source{
newSource(TypeSQL, "876_a.sql", 876),
NewSource(TypeSQL, "876_a.sql", 876),
})
assertDirpath("dir1", []Source{
newSource(TypeSQL, "101_a.sql", 101),
newSource(TypeSQL, "102_b.sql", 102),
newSource(TypeSQL, "103_c.sql", 103),
NewSource(TypeSQL, "101_a.sql", 101),
NewSource(TypeSQL, "102_b.sql", 102),
NewSource(TypeSQL, "103_c.sql", 103),
})
assertDirpath("dir2", []Source{
newSource(TypeSQL, "201_a.sql", 201),
NewSource(TypeSQL, "201_a.sql", 201),
})
assertDirpath("dir3", nil)
})
Expand Down Expand Up @@ -199,14 +199,14 @@ func TestMerge(t *testing.T) {

t.Run("valid", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
2: {version: 2},
3: {version: 3},
2: newGoMigration("", nil, nil),
3: newGoMigration("", nil, nil),
})
check.NoError(t, err)
check.Number(t, len(migrations), 3)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "00003_baz.go", 3))
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "00003_baz.go", 3))
})
t.Run("unregistered_all", func(t *testing.T) {
_, err := merge(sources, nil)
Expand All @@ -217,17 +217,17 @@ func TestMerge(t *testing.T) {
})
t.Run("unregistered_some", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
2: {version: 2},
2: newGoMigration("", nil, nil),
})
check.HasError(t, err)
check.Contains(t, err.Error(), "error: detected 1 unregistered Go file")
check.Contains(t, err.Error(), "00003_baz.go")
})
t.Run("duplicate_sql", func(t *testing.T) {
_, err := merge(sources, map[int64]*goMigration{
1: {version: 1}, // duplicate. SQL already exists.
2: {version: 2},
3: {version: 3},
1: newGoMigration("", nil, nil), // duplicate. SQL already exists.
2: newGoMigration("", nil, nil),
3: newGoMigration("", nil, nil),
})
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
Expand All @@ -246,17 +246,17 @@ func TestMerge(t *testing.T) {
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
3: {version: 3},
3: newGoMigration("", nil, nil),
// 4 is missing
6: {version: 6},
6: newGoMigration("", nil, nil),
})
check.NoError(t, err)
check.Number(t, len(migrations), 5)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeSQL, "00002_bar.sql", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeSQL, "00005_baz.sql", 5))
assertMigration(t, migrations[4], newSource(TypeGo, "", 6))
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeSQL, "00002_bar.sql", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "", 3))
assertMigration(t, migrations[3], NewSource(TypeSQL, "00005_baz.sql", 5))
assertMigration(t, migrations[4], NewSource(TypeGo, "", 6))
})
})
t.Run("partial_go_files_on_disk", func(t *testing.T) {
Expand All @@ -271,17 +271,17 @@ func TestMerge(t *testing.T) {
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*goMigration{
// This is the only Go file on disk.
2: {version: 2},
2: newGoMigration("", nil, nil),
// These are not on disk. Explicitly registered.
3: {version: 3},
6: {version: 6},
3: newGoMigration("", nil, nil),
6: newGoMigration("", nil, nil),
})
check.NoError(t, err)
check.Number(t, len(migrations), 4)
assertMigration(t, migrations[0], newSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], newSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], newSource(TypeGo, "", 3))
assertMigration(t, migrations[3], newSource(TypeGo, "", 6))
assertMigration(t, migrations[0], NewSource(TypeSQL, "00001_foo.sql", 1))
assertMigration(t, migrations[1], NewSource(TypeGo, "00002_bar.go", 2))
assertMigration(t, migrations[2], NewSource(TypeGo, "", 3))
assertMigration(t, migrations[3], NewSource(TypeGo, "", 6))
})
})
}
Expand All @@ -291,7 +291,7 @@ func assertMigration(t *testing.T, got *migration, want Source) {
check.Equal(t, got.Source, want)
switch got.Source.Type {
case TypeGo:
check.Equal(t, got.Go.version, want.Version)
check.Bool(t, got.Go != nil, true)
case TypeSQL:
check.Bool(t, got.SQL == nil, true)
default:
Expand Down
39 changes: 39 additions & 0 deletions internal/provider/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package provider

import (
"errors"
"fmt"
"path/filepath"
)

var (
// ErrVersionNotFound when a migration version is not found.
ErrVersionNotFound = errors.New("version not found")

// ErrAlreadyApplied when a migration has already been applied.
ErrAlreadyApplied = errors.New("already applied")

// ErrNoMigrations is returned by [NewProvider] when no migrations are found.
ErrNoMigrations = errors.New("no migrations found")

// ErrNoNextVersion when the next migration version is not found.
ErrNoNextVersion = errors.New("no next version found")
)

// PartialError is returned when a migration fails, but some migrations already got applied.
type PartialError struct {
// Applied are migrations that were applied successfully before the error occurred.
Applied []*MigrationResult
// Failed contains the result of the migration that failed.
Failed *MigrationResult
// Err is the error that occurred while running the migration.
Err error
}

func (e *PartialError) Error() string {
filename := "(file unknown)"
if e.Failed != nil && e.Failed.Source.Fullpath != "" {
filename = fmt.Sprintf("(%s)", filepath.Base(e.Failed.Source.Fullpath))
}
return fmt.Sprintf("partial migration error %s (%d): %v", filename, e.Failed.Source.Version, e.Err)
}
Loading