Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add goose provider with unimplemented methods #596

Merged
merged 17 commits into from
Oct 7, 2023
14 changes: 14 additions & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ import (
"github.com/pressly/goose/v3/internal/dialect"
)

// Dialect is the type of database dialect.
type Dialect string

const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
)

func init() {
store, _ = dialect.NewStore(dialect.Postgres)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/ory/dockertest/v3 v3.10.0
github.com/vertica/vertica-sql-go v1.3.3
github.com/ziutek/mymysql v1.5.4
go.uber.org/multierr v1.11.0
modernc.org/sqlite v1.25.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s=
go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4=
go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs=
go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
Expand Down
49 changes: 49 additions & 0 deletions internal/sqladapter/sqladapter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Package sqladapter provides an interface for interacting with a SQL database.
//
// All supported database dialects must implement the Store interface.
package sqladapter

import (
"context"
"time"

"github.com/pressly/goose/v3/internal/sqlextended"
)

// Store is the interface that wraps the basic methods for a database dialect.
//
// A dialect is a set of SQL statements that are specific to a database.
//
// By defining a store interface, we can support multiple databases with a single codebase.
//
// The underlying implementation does not modify the error. It is the callers responsibility to
// assert for the correct error, such as sql.ErrNoRows.
type Store interface {
// CreateVersionTable creates the version table within a transaction. This table is used to
// record applied migrations.
CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn, tablename string) error

// InsertOrDelete inserts or deletes a version id from the version table.
InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error

// GetMigration retrieves a single migration by version id.
//
// Returns the raw sql error if the query fails. It is the callers responsibility
// to assert for the correct error, such as sql.ErrNoRows.
GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id.
//
// If there are no migrations, an empty slice is returned with no error.
ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error)
}

type GetMigrationResult struct {
IsApplied bool
Timestamp time.Time
}

type ListMigrationsResult struct {
Version int64
IsApplied bool
}
111 changes: 111 additions & 0 deletions internal/sqladapter/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package sqladapter

import (
"context"
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
"github.com/pressly/goose/v3/internal/sqlextended"
)

var _ Store = (*store)(nil)

type store struct {
tablename string
querier dialectquery.Querier
}

// NewStore returns a new Store backed by the given dialect.
//
// The dialect must match one of the supported dialects defined in dialect.go.
func NewStore(dialect string, table string) (Store, error) {
if table == "" {
return nil, errors.New("table must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
var querier dialectquery.Querier
switch dialect {
case "clickhouse":
querier = &dialectquery.Clickhouse{}
case "mssql":
querier = &dialectquery.Sqlserver{}
case "mysql":
querier = &dialectquery.Mysql{}
case "postgres":
querier = &dialectquery.Postgres{}
case "redshift":
querier = &dialectquery.Redshift{}
case "sqlite3":
querier = &dialectquery.Sqlite3{}
case "tidb":
querier = &dialectquery.Tidb{}
case "vertica":
querier = &dialectquery.Vertica{}
default:
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: table,
querier: querier,
}, nil
}

func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn, tablename string) error {
q := s.querier.CreateTable(s.tablename)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", tablename, err)
}
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", version, err)
}
return nil
}
q := s.querier.DeleteVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
return nil
}

func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
result := new(GetMigrationResult)
if err := db.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
&result.IsApplied,
); err != nil {
return nil, fmt.Errorf("failed to get migration %d: %w", version, err)
}
return result, nil
}

func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
}
defer rows.Close()

var migrations []*ListMigrationsResult
for rows.Next() {
result := new(ListMigrationsResult)
if err := rows.Scan(&result.Version, &result.IsApplied); err != nil {
return nil, fmt.Errorf("failed to scan list migrations result: %w", err)
}
migrations = append(migrations, result)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrations, nil
}
Loading
Loading