diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 7681537..bee3354 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,4 +1,4 @@ -name: Run Go tests +name: Go tests on: [ push ] jobs: Build-and-test: diff --git a/.github/workflows/mysql.yml b/.github/workflows/mysql.yml new file mode 100644 index 0000000..831e891 --- /dev/null +++ b/.github/workflows/mysql.yml @@ -0,0 +1,38 @@ +name: MySQL tests +on: [ push ] + +jobs: + Test-mysql-integration: + runs-on: ubuntu-latest + + services: + mysql: + image: mysql:8 + env: + MYSQL_DATABASE: libschematest + MYSQL_ROOT_PASSWORD: mysql + options: >- + --health-cmd "mysqladmin ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 3306:3306 + + steps: + - name: Check out repository code + uses: actions/checkout@v2 + + - name: Set up Go + uses: actions/setup-go@v2 + with: + go-version: 1.16 + + - name: Build + run: go build -v ./... + + - name: Test + env: + LIBSCHEMA_MYSQL_TEST_DSN: "root:mysql@tcp(127.0.0.1:3306)/libschematest?tls=false" + run: go test ./lsmysql/... -v + diff --git a/api.go b/api.go index 602d789..92bf794 100644 --- a/api.go +++ b/api.go @@ -42,11 +42,13 @@ type MigrationOption func(Migration) // Migration defines a single database defintion update. type MigrationBase struct { - Name MigrationName - async bool - rawAfter []MigrationName - order int // overall desired ordring across all libraries, ignores runAfter - status MigrationStatus + Name MigrationName + async bool + rawAfter []MigrationName + order int // overall desired ordring across all libraries, ignores runAfter + status MigrationStatus + skipIf func() (bool, error) + skipRemainingIf func() (bool, error) } func (m MigrationBase) Copy() MigrationBase { @@ -66,9 +68,8 @@ type Migration interface { // MigrationStatus tracks if a migration is complete or not. type MigrationStatus struct { - Done bool - Partial string // for Mysql, the string represents the portion of multiple commands that have completed - Error string // If an attempt was made but failed, this will be set + Done bool + Error string // If an attempt was made but failed, this will be set } // Database tracks all of the migrations for a specific database. @@ -203,6 +204,18 @@ func After(lib, migration string) MigrationOption { } } +func SkipIf(pred func() (bool, error)) MigrationOption { + return func(m Migration) { + m.Base().skipIf = pred + } +} + +func SkipRemainingIf(pred func() (bool, error)) MigrationOption { + return func(m Migration) { + m.Base().skipRemainingIf = pred + } +} + func (d *Database) DB() *sql.DB { return d.db } @@ -242,6 +255,10 @@ func (m *MigrationBase) SetStatus(status MigrationStatus) { m.status = status } +func (m *MigrationBase) HasSkipIf() bool { + return m.skipIf != nil +} + func (n MigrationName) String() string { return n.Library + ": " + n.Name } diff --git a/apply.go b/apply.go index c068b4e..678c7c6 100644 --- a/apply.go +++ b/apply.go @@ -212,15 +212,16 @@ func (d *Database) migrate(ctx context.Context) (err error) { go d.asyncMigrate(ctx) return nil } - err = d.doOneMigration(ctx, m) - if err != nil { + var stop bool + stop, err = d.doOneMigration(ctx, m) + if err != nil || stop { return err } } return nil } -func (d *Database) doOneMigration(ctx context.Context, m Migration) error { +func (d *Database) doOneMigration(ctx context.Context, m Migration) (bool, error) { if d.Options.DebugLogging { d.log.Debug("Starting migration", map[string]interface{}{ "database": d.name, @@ -228,11 +229,29 @@ func (d *Database) doOneMigration(ctx context.Context, m Migration) error { "name": m.Base().Name.Name, }) } + if m.Base().skipIf != nil { + skip, err := m.Base().skipIf() + if err != nil { + return false, errors.Wrapf(err, "SkipIf %s", m.Base().Name) + } + if skip { + return false, nil + } + } + if m.Base().skipRemainingIf != nil { + skip, err := m.Base().skipRemainingIf() + if err != nil { + return false, errors.Wrapf(err, "SkipRemainingIf %s", m.Base().Name) + } + if skip { + return true, nil + } + } err := d.driver.DoOneMigration(ctx, d.log, d, m) if err != nil && d.Options.OnMigrationFailure != nil { d.Options.OnMigrationFailure(m.Base().Name, err) } - return err + return false, err } func (d *Database) lastUnfinishedSynchrnous() int { @@ -279,8 +298,9 @@ func (d *Database) asyncMigrate(ctx context.Context) { if m.Base().Status().Done { continue } - err = d.doOneMigration(ctx, m) - if err != nil { + var stop bool + stop, err = d.doOneMigration(ctx, m) + if err != nil || stop { return } } diff --git a/go.mod b/go.mod index 202cde2..4065f5a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/muir/libschema go 1.16 require ( - github.com/lib/pq v1.10.3 - github.com/muir/testinglogur v0.0.0-20210705185900-bc47cbaaadca + github.com/go-sql-driver/mysql v1.5.0 + github.com/lib/pq v1.10.4 + github.com/muir/sqltoken v0.0.4 + github.com/muir/testinglogur v0.0.1 github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.7.1 ) diff --git a/go.sum b/go.sum index 837d6d7..1dee2cd 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,29 @@ +github.com/alvaroloes/enumer v1.1.2/go.mod h1:FxrjvuXoDAx9isTJrv4c+T410zFi0DtXIT0m65DJ+Wo= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= -github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/muir/testinglogur v0.0.0-20210705185900-bc47cbaaadca h1:umBSRx6i2/+1gbab8wlghfL7vPhBGr8ZwlKlo1nRg04= -github.com/muir/testinglogur v0.0.0-20210705185900-bc47cbaaadca/go.mod h1:18iL5fVrQ2hu0NeXKtEE9pS5jgdaNTgqWHNl+p33g6M= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk= +github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/muir/sqltoken v0.0.4 h1:SioNnG90ZYXmlfnPaUxUdNC1dFkhKL64pDeS+wXZ8k8= +github.com/muir/sqltoken v0.0.4/go.mod h1:6hPsZxszMpYyNf12og4f4VShFo/Qipz6Of0cn5KGAAU= +github.com/muir/testinglogur v0.0.1 h1:k0lztrKzttiH5Pjtzj7S4tXXXBgUaxqTtVKXK4ndiI8= +github.com/muir/testinglogur v0.0.1/go.mod h1:18iL5fVrQ2hu0NeXKtEE9pS5jgdaNTgqWHNl+p33g6M= +github.com/pascaldekloe/name v0.0.0-20180628100202-0fd16699aae1/go.mod h1:eD5JxqMiuNYyFNmyY9rkJ/slN8y59oEu4Ei7F8OoKWQ= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190524210228-3d17549cdc6b/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= diff --git a/lsmysql/README.md b/lsmysql/README.md new file mode 100644 index 0000000..56daf3e --- /dev/null +++ b/lsmysql/README.md @@ -0,0 +1,97 @@ + +# libschema/lsmysql - mysql support for libschema + +[![GoDoc](https://godoc.org/github.com/muir/libschema?status.png)](https://pkg.go.dev/github.com/muir/libschema/lsmysql) + +Install: + + go get github.com/muir/libschema + +--- + +## DDL Transactions + +MySQL and MariaDB do not support DDL (Data Definition Language) transactions like +`CREATE TABLE`. Such commands cause the current transaction to switch to `autocommit` +mode. + +The consequence of this is that it is not possible for a schema migration tool, +like libschema, to track if a migration has been applied or not by tracking the status +of a transaction. + +When working with MySQL and MariaDB, schema-changing migrations should be done +separately from data-changing migrations. Schema-changing transactions that are +idempotent are safe and require no special handling. + +Schema-changing transactions that are not idempotent need to be guarded with conditionals +so that they're skipped if they've already been applied. + +Fortunately, `IF EXISTS` and `IF NOT EXISTS` clauses can be most of the DDL statements. + +### Conditionals + +The DDL statements missing `IF EXISTS` and `IF NOT EXISTS` include: + +```sql +ALTER TABLE ... + ADD CONSTRAINT + ALTER COLUMN SET SET DEFAULT + ALTER COLUMN SET DROP DEFAULT + ADD FULLTEXT + ADD SPATIAL + ADD PERIOD FOR SYSTEM TIME + ADD {INDEX|KEY} index_name [NOT] INVISIBLE + DROP PRIMARY KEY + RENAME COLUMN + RENAME INDEX + RENAME KEY + DISCARD TABLESPACE + IMPORT TABLESPACE + COALESCE PARTITION + REORGANIZE PARTITION + EXCHANGE PARTITION + REMOVE PARTITIONING + DISABLE KEYS + ENABLE KEYS +``` + +To help make these conditional, the lsmysql provides some helper functions to easily +check the current database state. + +For example: + +```go +schema := libschema.NewSchema(ctx, libschema.Options{}) + +sqlDB, err := sql.Open("mysql", "....") + +database, mysql, err := lsmysql.New(logger, "main-db", schema, sqlDB) + +database.Migrations("MyLibrary", + lsmysql.Script("createUserTable", ` + CREATE TABLE users ( + name text, + id bigint, + PRIMARY KEY (id) + ) ENGINE=InnoDB` + }), + lsmysql.Script("dropUserPK", ` + ALTER TABLE users + DROP PRIMARY KEY`, + libschema.SkipIf(func() (bool, error) { + hasPK, err := mysql.HasPrimaryKey("users") + return !hasPK, err + })), + ) +``` + +### Some notes on MySQL + +While most identifiers (table names, etc) can be `"`quoted`"`, you cannot use quotes around +a schema (database really) name with `CREATE SCHEMA`. + +MySQL does not support schemas. A schema is just a synonym for `DATABASE` in the MySQL world. +This means that it is easier to put migrations tracking table in the same schema (database) as +the rest of the tables. It also means that to run migration unit tests, the DSN for testing +has to give access to a user that can create and drop databases. + diff --git a/lsmysql/check.go b/lsmysql/check.go new file mode 100644 index 0000000..0b79ac4 --- /dev/null +++ b/lsmysql/check.go @@ -0,0 +1,53 @@ +package lsmysql + +import ( + "regexp" + "strings" + + "github.com/muir/sqltoken" +) + +type CheckResult string + +const ( + Safe CheckResult = "safe" + DataAndDDL = "dataAndDDL" + NonIdempotentDDL = "nonIdempotentDDL" +) + +var ifExistsRE = regexp.MustCompile(`(?i)\bIF (?:NOT )?EXISTS\b`) + +// CheckScript attempts to validate that an SQL command does not do +// both schema changes (DDL) and data changes. +func CheckScript(s string) CheckResult { + var seenDDL int + var seenData int + var idempotent int + ts := sqltoken.TokenizeMySQL(s) + for _, cmd := range ts.Strip().CmdSplit() { + word := strings.ToLower(cmd[0].Text) + switch word { + case "alter", "rename", "create", "drop", "comment": + seenDDL++ + if ifExistsRE.MatchString(cmd.String()) { + idempotent++ + } + case "truncate": + seenDDL++ + idempotent++ + case "use", "set": + // neither + case "values", "table", "select": + // doesn't modify anything + case "call", "delete", "do", "handler", "import", "insert", "load", "replace", "update", "with": + seenData++ + } + } + if seenDDL > 0 && seenData > 0 { + return DataAndDDL + } + if seenDDL > idempotent { + return NonIdempotentDDL + } + return Safe +} diff --git a/lsmysql/mysql.go b/lsmysql/mysql.go new file mode 100644 index 0000000..7c3a1d3 --- /dev/null +++ b/lsmysql/mysql.go @@ -0,0 +1,355 @@ +package lsmysql + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "strings" + + "github.com/muir/libschema" + "github.com/pkg/errors" +) + +// MySQL is a libschema.Driver for connecting to MySQL-like databases that +// have the following characteristics: +// * CANNOT do DDL commands inside transactions +// * Support UPSERT using INSERT ... ON DUPLICATE KEY UPDATE +// * uses /* -- and # for comments +// +// Because mysql DDL commands cause transactions to autocommit, tracking the schema changes in +// a secondary table (like libschema does) is inherently unsafe. The MySQL driver will +// record that it is about to attempt a migration and it will record if that attempts succeeds +// or fails, but if the program terminates mid-transaction, it is beyond the scope of libschema +// to determine if the transaction succeeded or failed. Such transactions will be retried. +// For this reason, it is reccomend that DDL commands be written such that they are idempotent. +type MySQL struct { + lockTx *sql.Tx + lockStr string + db *sql.DB +} + +// New creates a libschema.Database with a mysql driver built in. +func New(log libschema.MyLogger, name string, schema *libschema.Schema, db *sql.DB) (*libschema.Database, *MySQL, error) { + m := &MySQL{db: db} + d, err := schema.NewDatabase(log, name, db, m) + return d, m, err +} + +type mmigration struct { + libschema.MigrationBase + script func(context.Context, libschema.MyLogger, *sql.Tx) string + computed func(context.Context, libschema.MyLogger, *sql.Tx) error +} + +func (m *mmigration) Copy() libschema.Migration { + return &mmigration{ + MigrationBase: m.MigrationBase.Copy(), + script: m.script, + computed: m.computed, + } +} + +func (m *mmigration) Base() *libschema.MigrationBase { + return &m.MigrationBase +} + +// Script creates a libschema.Migration from a SQL string +func Script(name string, sqlText string, opts ...libschema.MigrationOption) libschema.Migration { + return Generate(name, func(_ context.Context, _ libschema.MyLogger, _ *sql.Tx) string { + return sqlText + }, opts...) +} + +// Generate creates a libschema.Migration from a function that returns a SQL string +func Generate( + name string, + generator func(context.Context, libschema.MyLogger, *sql.Tx) string, + opts ...libschema.MigrationOption) libschema.Migration { + return mmigration{ + MigrationBase: libschema.MigrationBase{ + Name: libschema.MigrationName{ + Name: name, + }, + }, + script: generator, + }.applyOpts(opts) +} + +// Computed creates a libschema.Migration from a Go function to run +// the migration directly. +func Computed( + name string, + action func(context.Context, libschema.MyLogger, *sql.Tx) error, + opts ...libschema.MigrationOption) libschema.Migration { + return mmigration{ + MigrationBase: libschema.MigrationBase{ + Name: libschema.MigrationName{ + Name: name, + }, + }, + computed: action, + }.applyOpts(opts) +} + +func (m mmigration) applyOpts(opts []libschema.MigrationOption) libschema.Migration { + lsm := libschema.Migration(&m) + for _, opt := range opts { + opt(lsm) + } + return lsm +} + +// DoOneMigration applies a single migration. +// It is expected to be called by libschema. +func (p *MySQL) DoOneMigration(ctx context.Context, log libschema.MyLogger, d *libschema.Database, m libschema.Migration) (err error) { + // TODO: DRY + defer func() { + if err == nil { + m.Base().SetStatus(libschema.MigrationStatus{ + Done: true, + }) + } + }() + tx, err := d.DB().BeginTx(ctx, d.Options.MigrationTxOptions) + if err != nil { + return errors.Wrapf(err, "Begin Tx for migration %s", m.Base().Name) + } + if d.Options.SchemaOverride != "" { + if !simpleIdentifierRE.MatchString(d.Options.SchemaOverride) { + return errors.Errorf("Options.SchemaOverride must be a simple identifier, not '%s'", d.Options.SchemaOverride) + } + _, err := tx.Exec(`USE ` + d.Options.SchemaOverride) + if err != nil { + return errors.Wrapf(err, "Set search path to %s for %s", d.Options.SchemaOverride, m.Base().Name) + } + } + defer func() { + if err != nil { + tx.Rollback() + } else { + err = errors.Wrapf(tx.Commit(), "Commit migration %s", m.Base().Name) + } + return + }() + pm := m.(*mmigration) + if pm.script != nil { + script := pm.script(ctx, log, tx) + switch CheckScript(script) { + case Safe: + case DataAndDDL: + err = errors.New("Migration combines DDL (Data Definition Language [schema changes]) and data manipulation") + case NonIdempotentDDL: + if !m.Base().HasSkipIf() { + err = errors.New("Unconditional migration has non-idempotent DDL (Data Definition Language [schema changes])") + } + } + if err == nil { + _, err = tx.Exec(script) + } + err = errors.Wrap(err, script) + } else { + err = pm.computed(ctx, log, tx) + } + if err != nil { + err = errors.Wrapf(err, "Problem with migration %s", m.Base().Name) + tx.Rollback() + ntx, txerr := d.DB().BeginTx(ctx, d.Options.MigrationTxOptions) + if txerr != nil { + return errors.Wrapf(err, "Tx for saving status for %s also failed with %s", m.Base().Name, txerr) + } + tx = ntx + } + txerr := p.saveStatus(log, tx, d, m, err == nil, err) + if txerr != nil { + if err == nil { + err = txerr + } else { + err = errors.Wrapf(err, "Save status for %s also failed: %s", m.Base().Name, txerr) + } + } + return +} + +// CreateSchemaTableIfNotExists creates the migration tracking table for libschema. +// It is expected to be called by libschema. +func (p *MySQL) CreateSchemaTableIfNotExists(ctx context.Context, _ libschema.MyLogger, d *libschema.Database) error { + schema, tableName, err := trackingSchemaTable(d) + if err != nil { + return err + } + if schema != "" { + _, err := d.DB().ExecContext(ctx, fmt.Sprintf(` + CREATE SCHEMA IF NOT EXISTS %s + `, schema)) + if err != nil { + return errors.Wrapf(err, "Could not create libschema schema '%s'", schema) + } + } + _, err = d.DB().ExecContext(ctx, fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + library varchar(255) NOT NULL, + migration varchar(255) NOT NULL, + done boolean NOT NULL, + error text NOT NULL, + updated_at timestamp DEFAULT now(), + PRIMARY KEY (library, migration) + ) ENGINE = InnoDB`, tableName)) + if err != nil { + return errors.Wrapf(err, "Could not create libschema migrations table '%s'", tableName) + } + return nil +} + +var simpleIdentifierRE = regexp.MustCompile(`\A[A-Za-z][A-Za-z0-9_]*\z`) + +// When MySQL is in ANSI_QUOTES mode, it allows "table_name" quotes but when +// it is not then it does not. There is no prefect option: in ANSI_QUOTES +// mode, you could have a table called `table` (eg: `CREATE TABLE "table"`) but +// if you're not in ANSI_QUOTES mode then you cannot. We're going to assume +// that we're not in ANSI_QUOTES mode because we cannot assume that we are. +func trackingSchemaTable(d *libschema.Database) (string, string, error) { + tableName := d.Options.TrackingTable + s := strings.Split(tableName, ".") + switch len(s) { + case 2: + schema := s[0] + if !simpleIdentifierRE.MatchString(schema) { + return "", "", errors.Errorf("Tracking table schema name must be a simple identifier, not '%s'", schema) + } + table := s[1] + if !simpleIdentifierRE.MatchString(table) { + return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", table) + } + return schema, schema + "." + table, nil + case 1: + if !simpleIdentifierRE.MatchString(tableName) { + return "", "", errors.Errorf("Tracking table table name must be a simple identifier, not '%s'", tableName) + } + return "", tableName, nil + default: + return "", "", errors.Errorf("Tracking table '%s' is not valid", tableName) + } +} + +// trackingTable returns the schema+table reference for the migration tracking table. +// The name is already quoted properly for use as a save postgres identifier. +// TODO: DRY +func trackingTable(d *libschema.Database) string { + _, table, _ := trackingSchemaTable(d) + return table +} + +func (p *MySQL) saveStatus(log libschema.MyLogger, tx *sql.Tx, d *libschema.Database, m libschema.Migration, done bool, migrationError error) error { + var estr string + if migrationError != nil { + estr = migrationError.Error() + } + log.Info("Saving migration status", map[string]interface{}{ + "migration": m.Base().Name, + "done": done, + "error": migrationError, + }) + q := fmt.Sprintf(` + REPLACE INTO %s (library, migration, done, error, updated_at) + VALUES (?, ?, ?, ?, now())`, trackingTable(d)) + _, err := tx.Exec(q, m.Base().Name.Library, m.Base().Name.Name, done, estr) + if err != nil { + return errors.Wrapf(err, "Save status for %s", m.Base().Name) + } + return nil +} + +// LockMigrationsTable locks the migration tracking table for exclusive use by the +// migrations running now. +// It is expected to be called by libschema. +// In MySQL, locks are _not_ tied to transactions so closing the transaction +// does not release the lock. We'll use a transaction just to make sure that +// we're using the same connection. If LockMigrationsTable succeeds, be sure to +// call UnlockMigrationsTable. +func (p *MySQL) LockMigrationsTable(ctx context.Context, _ libschema.MyLogger, d *libschema.Database) error { + _, tableName, err := trackingSchemaTable(d) + if err != nil { + return err + } + if p.lockTx != nil { + return errors.Errorf("libschema migrations table, '%s' already locked", tableName) + } + tx, err := d.DB().BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return errors.Wrap(err, "Could not start transaction: %s") + } + p.lockStr = "libschema_" + tableName + var gotLock int + err = tx.QueryRow(`SELECT GET_LOCK(?, -1)`, p.lockStr).Scan(&gotLock) + if err != nil { + return errors.Wrapf(err, "Could not get lock for libschema migrations") + } + p.lockTx = tx + return nil +} + +// UnlockMigrationsTable unlocks the migration tracking table. +// It is expected to be called by libschema. +func (p *MySQL) UnlockMigrationsTable(_ libschema.MyLogger) error { + if p.lockTx == nil { + return errors.Errorf("libschema migrations table, not locked") + } + defer func() { + _ = p.lockTx.Rollback() + p.lockTx = nil + }() + _, err := p.lockTx.Exec(`SELECT RELEASE_LOCK(?)`, p.lockStr) + if err != nil { + return errors.Wrap(err, "Could not release explicit lock for schema migrations") + } + return nil +} + +// LoadStatus loads the current status of all migrations from the migration tracking table. +// It is expected to be called by libschema. +func (p *MySQL) LoadStatus(ctx context.Context, _ libschema.MyLogger, d *libschema.Database) ([]libschema.MigrationName, error) { + // TODO: DRY + tableName := trackingTable(d) + rows, err := d.DB().QueryContext(ctx, fmt.Sprintf(` + SELECT library, migration, done + FROM %s`, tableName)) + if err != nil { + return nil, errors.Wrap(err, "Cannot query migration status") + } + defer rows.Close() + var unknowns []libschema.MigrationName + for rows.Next() { + var ( + name libschema.MigrationName + status libschema.MigrationStatus + ) + err := rows.Scan(&name.Library, &name.Name, &status.Done) + if err != nil { + return nil, errors.Wrap(err, "Cannot scan migration status") + } + if m, ok := d.Lookup(name); ok { + m.Base().SetStatus(status) + } else if status.Done { + unknowns = append(unknowns, name) + } + } + return unknowns, nil +} + +// IsMigrationSupported checks to see if a migration is well-formed. Absent a code change, this +// should always return nil. +// It is expected to be called by libschema. +func (p *MySQL) IsMigrationSupported(d *libschema.Database, _ libschema.MyLogger, migration libschema.Migration) error { + m, ok := migration.(*mmigration) + if !ok { + return fmt.Errorf("Non-postgres migration %s registered with postgres migrations", migration.Base().Name) + } + if m.script != nil { + return nil + } + if m.computed != nil { + return nil + } + return errors.Errorf("Migration %s is not supported", m.Name) +} diff --git a/lsmysql/mysql_test.go b/lsmysql/mysql_test.go new file mode 100644 index 0000000..2ba9ded --- /dev/null +++ b/lsmysql/mysql_test.go @@ -0,0 +1,258 @@ +package lsmysql_test + +import ( + "context" + "database/sql" + "fmt" + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/muir/libschema" + "github.com/muir/libschema/lsmysql" + "github.com/muir/libschema/lstesting" + "github.com/muir/testinglogur" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Since MySQL does not support schemas (it treats them like databases), +// LIBSCHEMA_MYSQL_TEST_DSN has to give access to a user that can create +// and destroy databases. + +func TestMysqlHappyPath(t *testing.T) { + dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN") + if dsn == "" { + t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql") + } + + var actions []string + + options, cleanup := lstesting.FakeSchema(t, "") + + options.ErrorOnUnknownMigrations = true + options.OnMigrationFailure = func(name libschema.MigrationName, err error) { + actions = append(actions, fmt.Sprintf("FAIL %s: %s", name, err)) + } + options.OnMigrationsStarted = func() { + actions = append(actions, "START") + } + options.OnMigrationsComplete = func(err error) { + if err != nil { + actions = append(actions, "COMPLETE: "+err.Error()) + } else { + actions = append(actions, "COMPLETE") + } + } + options.DebugLogging = true + s := libschema.New(context.Background(), options) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err, "open database") + defer db.Close() + defer cleanup(db) + + dbase, _, err := lsmysql.New(testinglogur.Get(t), "test", s, db) + require.NoError(t, err, "libschema NewDatabase") + + dbase.Migrations("L1", + lsmysql.Generate("T1", func(_ context.Context, _ libschema.MyLogger, _ *sql.Tx) string { + actions = append(actions, "MIGRATE: L1.T1") + return `CREATE TABLE IF NOT EXISTS T1 (id text) ENGINE = InnoDB` + }), + lsmysql.Script("T2pre", ` + INSERT INTO T1 (id) VALUES ('T2');`), + lsmysql.Script("T2pre2", ` + INSERT INTO T3 (id) VALUES ('T2');`, + libschema.After("L2", "T3")), + lsmysql.Computed("T2", func(_ context.Context, _ libschema.MyLogger, tx *sql.Tx) error { + actions = append(actions, "MIGRATE: L1.T2") + _, err := tx.Exec(` + CREATE TABLE IF NOT EXISTS T2 (id text) ENGINE = InnoDB`) + return err + }), + lsmysql.Generate("PT1", func(_ context.Context, _ libschema.MyLogger, _ *sql.Tx) string { + actions = append(actions, "MIGRATE: L1.PT1") + return ` + INSERT INTO T1 (id) VALUES ('PT1'); + ` + }), + lsmysql.Script("PT1p1", ` + INSERT INTO T2 (id) VALUES ('PT1');`), + lsmysql.Script("PT1p2", ` + INSERT INTO T3 (id) VALUES ('PT1');`), + ) + + dbase.Migrations("L2", + lsmysql.Script("T3pre", ` + INSERT INTO T1 (id) VALUES ('T3');`), + lsmysql.Generate("T3", func(_ context.Context, _ libschema.MyLogger, _ *sql.Tx) string { + actions = append(actions, "MIGRATE: L2.T3") + return ` + CREATE TABLE IF NOT EXISTS T3 (id text) ENGINE = InnoDB` + }), + lsmysql.Script("T4pre1", ` + INSERT INTO T1 (id) VALUES ('T4');`), + lsmysql.Script("T4pre2", ` + INSERT INTO T2 (id) VALUES ('T4');`), + lsmysql.Script("T4pre3", ` + INSERT INTO T3 (id) VALUES ('T4');`), + lsmysql.Generate("T4", func(_ context.Context, _ libschema.MyLogger, _ *sql.Tx) string { + actions = append(actions, "MIGRATE: L2.T4") + return ` + CREATE TABLE IF NOT EXISTS T4 (id text) ENGINE = InnoDB` + }), + ) + + err = s.Migrate(context.Background()) + assert.NoError(t, err) + + assert.Equal(t, []string{ + "START", + "MIGRATE: L1.T1", + "MIGRATE: L2.T3", + "MIGRATE: L1.T2", + "MIGRATE: L1.PT1", + "MIGRATE: L2.T4", + "COMPLETE", + }, actions) + + rows, err := db.Query(` + SELECT table_name + FROM information_schema.tables + WHERE table_schema = ? + ORDER BY table_name`, options.SchemaOverride) + require.NoError(t, err, "query for list of tables") + defer rows.Close() + var names []string + for rows.Next() { + var name string + assert.NoError(t, rows.Scan(&name)) + names = append(names, name) + } + assert.Equal(t, []string{"T1", "T2", "T3", "T4", "tracking_table"}, names, "table names") +} + +func TestMysqlNotAllowed(t *testing.T) { + dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN") + if dsn == "" { + t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql") + } + + cases := []struct { + name string + migration string + errorText string + }{ + { + name: "combines", + migration: `CREATE TABLE T1 (id text) ENGINE=InnoDB; INSERT INTO T1 (id) VALUES ('x');`, + errorText: "Migration combines DDL", + }, + { + name: "unconditional", + migration: `CREATE TABLE T1 (id text) ENGINE=InnoDB`, + errorText: "Unconditional migration has non-idempotent", + }, + } + + for _, tc := range cases { + options, cleanup := lstesting.FakeSchema(t, "") + options.DebugLogging = true + s := libschema.New(context.Background(), options) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err, "open database") + defer db.Close() + defer cleanup(db) + + dbase, _, err := lsmysql.New(testinglogur.Get(t), "test", s, db) + require.NoError(t, err, "libschema NewDatabase") + + dbase.Migrations("L1", lsmysql.Script("x", tc.migration)) + + err = s.Migrate(context.Background()) + if assert.Error(t, err, tc.name) { + assert.Contains(t, err.Error(), tc.errorText, tc.name) + } + } +} + +func TestSkipFunctions(t *testing.T) { + dsn := os.Getenv("LIBSCHEMA_MYSQL_TEST_DSN") + if dsn == "" { + t.Skip("Set $LIBSCHEMA_MYSQL_TEST_DSN to test libschema/lsmysql") + } + + options, cleanup := lstesting.FakeSchema(t, "") + options.DebugLogging = true + s := libschema.New(context.Background(), options) + + db, err := sql.Open("mysql", dsn) + require.NoError(t, err, "open database") + defer db.Close() + defer cleanup(db) + + dbase, m, err := lsmysql.New(testinglogur.Get(t), "test", s, db) + require.NoError(t, err, "libschema NewDatabase") + + dbase.Migrations("T", + lsmysql.Script("setup1", ` + CREATE TABLE IF NOT EXISTS users ( + id varchar(255), + level integer DEFAULT 37, + PRIMARY KEY (id) + ) ENGINE=InnoDB`), + lsmysql.Script("setup2", ` + CREATE TABLE IF NOT EXISTS accounts ( + id varchar(255) + ) ENGINE=InnoDB`), + lsmysql.Script("setup3", ` + ALTER TABLE users + ADD CONSTRAINT hi_level + CHECK (level > 10) ENFORCED`, + libschema.SkipIf(func() (bool, error) { + t, _, err := m.GetTableConstraint("users", "hi_level") + return t != "", err + })), + ) + + err = s.Migrate(context.Background()) + assert.NoError(t, err) + + hasPK, err := m.HasPrimaryKey("users") + if assert.NoError(t, err, "users has pk") { + assert.True(t, hasPK, "users has pk") + } + hasPK, err = m.HasPrimaryKey("accounts") + if assert.NoError(t, err, "accounts has pk") { + assert.False(t, hasPK, "accounts has pk") + } + dflt, err := m.ColumnDefault("users", "id") + if assert.NoError(t, err, "user id default") { + assert.Nil(t, dflt, "user id default") + } + dflt, err = m.ColumnDefault("users", "level") + if assert.NoError(t, err, "user level default") { + if assert.NotNil(t, dflt, "user level default") { + assert.Equal(t, "37", *dflt, "user id default") + } + } + exists, err := m.DoesColumnExist("users", "foo") + if assert.NoError(t, err, "users has foo") { + assert.False(t, exists, "users has foo") + } + exists, err = m.DoesColumnExist("users", "level") + if assert.NoError(t, err, "users has level") { + assert.True(t, exists, "users has level") + } + typ, enf, err := m.GetTableConstraint("users", "hi_level") + if assert.NoError(t, err, "users hi_level constraint") { + assert.Equal(t, "CHECK", typ, "users hi_level constraint") + assert.True(t, enf, "users hi_level constraint") + } +} + +func pointerToString(s string) *string { + return &s +} diff --git a/lsmysql/skip.go b/lsmysql/skip.go new file mode 100644 index 0000000..d74fe00 --- /dev/null +++ b/lsmysql/skip.go @@ -0,0 +1,101 @@ +package lsmysql + +import ( + "database/sql" + + "github.com/pkg/errors" +) + +// ColumnDefault returns the default value for a column. If there +// is no default value, then nil is returned. +func (m *MySQL) ColumnDefault(table, column string) (*string, error) { + database, err := m.DatabaseName() + if err != nil { + return nil, err + } + var dflt *string + err = m.db.QueryRow(` + SELECT column_default + FROM information_schema.columns + WHERE table_schema = ? + AND table_name = ? + AND column_name = ?`, + database, table, column).Scan(&dflt) + return dflt, errors.Wrapf(err, "get default for %s.%s", table, column) +} + +// HasPrimaryKey returns true if the table has a primary key +func (m *MySQL) HasPrimaryKey(table string) (bool, error) { + return m.TableHasIndex(table, "PRIMARY") +} + +// TableHasIndex returns true if there is an index matching the +// name given. +func (m *MySQL) TableHasIndex(table, indexName string) (bool, error) { + database, err := m.DatabaseName() + if err != nil { + return false, err + } + var count int + err = m.db.QueryRow(` + SELECT COUNT(*) + FROM information_schema.statistics + WHERE table_schema = ? + AND table_name = ? + AND index_name = ?`, + database, table, indexName).Scan(&count) + return count != 0, errors.Wrapf(err, "has table index %s.%s", table, indexName) + +} + +// DoesColumnExist returns true if the column exists +func (m *MySQL) DoesColumnExist(table, column string) (bool, error) { + database, err := m.DatabaseName() + if err != nil { + return false, err + } + var count int + err = m.db.QueryRow(` + SELECT COUNT(*) + FROM information_schema.columns + WHERE table_schema = ? + AND table_name = ? + AND column_name = ?`, + database, table, column).Scan(&count) + return count != 0, errors.Wrapf(err, "get column exist %s.%s", table, column) +} + +// GetTableConstraints returns the type of constraint and if it is enforced. +func (m *MySQL) GetTableConstraint(table, constraintName string) (string, bool, error) { + database, err := m.DatabaseName() + if err != nil { + return "", false, err + } + var typ *string + var enforced *string + err = m.db.QueryRow(` + SELECT constraint_type, enforced + FROM information_schema.table_constraints + WHERE constraint_schema = ? + AND table_name = ? + AND constraint_name = ?`, + database, table, constraintName).Scan(&typ, &enforced) + if err == sql.ErrNoRows { + return "", false, nil + } + return asString(typ), asString(enforced) == "YES", errors.Wrapf(err, "get table constraint %s.%s", table, constraintName) +} + +// DatabaseName returns the name of the current database (aka schema for MySQL) +func (m *MySQL) DatabaseName() (string, error) { + var database string + err := m.db.QueryRow(`SELECT DATABASE()`).Scan(&database) + return database, errors.Wrap(err, "select database()") +} + +func asString(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/lspostgres/postgres.go b/lspostgres/postgres.go index cfeafba..bd64321 100644 --- a/lspostgres/postgres.go +++ b/lspostgres/postgres.go @@ -65,7 +65,7 @@ func Generate( }.applyOpts(opts) } -// Computed creates a libschema.Migration from a Go function to run to do +// Computed creates a libschema.Migration from a Go function to run // the migration directly. func Computed( name string, diff --git a/lspostgres/postgres_test.go b/lspostgres/postgres_test.go index f6df6d6..42285bd 100644 --- a/lspostgres/postgres_test.go +++ b/lspostgres/postgres_test.go @@ -23,7 +23,7 @@ func TestPostgresMigrations(t *testing.T) { var actions []string - options, cleanup := lstesting.FakeSchema(t) + options, cleanup := lstesting.FakeSchema(t, "CASCADE") options.ErrorOnUnknownMigrations = true options.OnMigrationFailure = func(name libschema.MigrationName, err error) { diff --git a/lstesting/testing.go b/lstesting/testing.go index ffd2ca7..9e9aa2a 100644 --- a/lstesting/testing.go +++ b/lstesting/testing.go @@ -28,13 +28,13 @@ type T interface { // FakeSchema generates an Options config with a fake random schema name // that begins with "lstest_". It also returns a function to // remove that schema -- the function should work with Postgres and Mysql. -func FakeSchema(t T) (libschema.Options, func(db *sql.DB)) { +func FakeSchema(t T, cascade string) (libschema.Options, func(db *sql.DB)) { schemaName := "lstest_" + RandomString(15) return libschema.Options{ TrackingTable: schemaName + ".tracking_table", SchemaOverride: schemaName, }, func(db *sql.DB) { - _, err := db.Exec(`DROP SCHEMA IF EXISTS ` + schemaName + ` CASCADE`) + _, err := db.Exec(`DROP SCHEMA IF EXISTS ` + schemaName + ` ` + cascade) t.Logf("DROPPED %s", schemaName) assert.NoError(t, err, "drop schema") }