From 3137c761b9cdabef24f3df06d7ed04a1f8c03d6c Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Wed, 25 Dec 2024 21:55:00 +0200 Subject: [PATCH] sql/mysql: suffix oss driver with oss --- sql/mysql/{diff.go => diff_oss.go} | 6 + sql/mysql/{diff_test.go => diff_oss_test.go} | 2 + sql/mysql/driver.go | 482 ----------------- sql/mysql/driver_oss.go | 501 +++++++++++++++--- .../{driver_test.go => driver_oss_test.go} | 2 + sql/mysql/{inspect.go => inspect_oss.go} | 34 +- .../{inspect_test.go => inspect_oss_test.go} | 2 + sql/mysql/{migrate.go => migrate_oss.go} | 31 +- .../{migrate_test.go => migrate_oss_test.go} | 2 + sql/mysql/{sqlspec.go => sqlspec_oss.go} | 21 +- .../{sqlspec_test.go => sqlspec_oss_test.go} | 2 + 11 files changed, 467 insertions(+), 618 deletions(-) rename sql/mysql/{diff.go => diff_oss.go} (99%) rename sql/mysql/{diff_test.go => diff_oss_test.go} (99%) delete mode 100644 sql/mysql/driver.go rename sql/mysql/{driver_test.go => driver_oss_test.go} (99%) rename sql/mysql/{inspect.go => inspect_oss.go} (97%) rename sql/mysql/{inspect_test.go => inspect_oss_test.go} (99%) rename sql/mysql/{migrate.go => migrate_oss.go} (97%) rename sql/mysql/{migrate_test.go => migrate_oss_test.go} (99%) rename sql/mysql/{sqlspec.go => sqlspec_oss.go} (97%) rename sql/mysql/{sqlspec_test.go => sqlspec_oss_test.go} (99%) diff --git a/sql/mysql/diff.go b/sql/mysql/diff_oss.go similarity index 99% rename from sql/mysql/diff.go rename to sql/mysql/diff_oss.go index 0d662a537de..69a8b179f07 100644 --- a/sql/mysql/diff.go +++ b/sql/mysql/diff_oss.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( @@ -723,3 +725,7 @@ func (d *diff) defaultCharset(attrs *[]schema.Attr) error { } return nil } + +func (*diff) ViewAttrChanges(_, _ *schema.View) []schema.Change { + return nil // Not implemented. +} diff --git a/sql/mysql/diff_test.go b/sql/mysql/diff_oss_test.go similarity index 99% rename from sql/mysql/diff_test.go rename to sql/mysql/diff_oss_test.go index 2780cb61cd6..ce03c52001a 100644 --- a/sql/mysql/diff_test.go +++ b/sql/mysql/diff_oss_test.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( diff --git a/sql/mysql/driver.go b/sql/mysql/driver.go deleted file mode 100644 index d9e89167631..00000000000 --- a/sql/mysql/driver.go +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright 2021-present The Atlas Authors. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package mysql - -import ( - "context" - "database/sql" - "fmt" - "net/url" - "strings" - "time" - - "ariga.io/atlas/sql/internal/sqlx" - "ariga.io/atlas/sql/migrate" - "ariga.io/atlas/sql/mysql/internal/mysqlversion" - "ariga.io/atlas/sql/schema" - "ariga.io/atlas/sql/sqlclient" -) - -type ( - // Driver represents a MySQL driver for introspecting database schemas, - // generating diff between schema elements and apply migrations changes. - Driver struct { - *conn - schema.Differ - schema.Inspector - migrate.PlanApplier - } - - // database connection and its information. - conn struct { - schema.ExecQuerier - // The schema was set in the path (schema connection). - schema string - // System variables that are set on `Open`. - mysqlversion.V - collate string - charset string - lcnames int - } -) - -var _ interface { - migrate.Snapshoter - migrate.StmtScanner - migrate.CleanChecker - schema.TypeParseFormatter -} = (*Driver)(nil) - -// DriverName and DriverMaria holds the names used for registration. -const ( - DriverName = "mysql" - DriverMaria = "mariadb" -) - -func init() { - sqlclient.Register( - DriverName, - opener(DriverName), - sqlclient.RegisterDriverOpener(Open), - sqlclient.RegisterCodec(codec, codec), - sqlclient.RegisterFlavours("mysql+unix"), - sqlclient.RegisterURLParser(parser{}), - ) - sqlclient.Register( - DriverMaria, - opener(DriverMaria), - sqlclient.RegisterDriverOpener(Open), - sqlclient.RegisterCodec(mariaCodec, mariaCodec), - sqlclient.RegisterFlavours("mariadb+unix", "maria", "maria+unix"), - sqlclient.RegisterURLParser(parser{}), - ) -} - -// Open opens a new MySQL driver. -func Open(db schema.ExecQuerier) (migrate.Driver, error) { - c := &conn{ExecQuerier: db} - rows, err := db.QueryContext(context.Background(), variablesQuery) - if err != nil { - return nil, fmt.Errorf("mysql: query system variables: %w", err) - } - if err := sqlx.ScanOne(rows, &c.V, &c.collate, &c.charset, &c.lcnames); err != nil { - return nil, fmt.Errorf("mysql: scan system variables: %w", err) - } - if c.TiDB() { - return &Driver{ - conn: c, - Differ: &sqlx.Diff{DiffDriver: &tdiff{diff{conn: c}}}, - Inspector: &tinspect{inspect{c}}, - PlanApplier: &tplanApply{planApply{c}}, - }, nil - } - return &Driver{ - conn: c, - Differ: &sqlx.Diff{DiffDriver: &diff{conn: c}}, - Inspector: &inspect{c}, - PlanApplier: &planApply{c}, - }, nil -} - -// opener for the given driver name. -func opener(name string) sqlclient.OpenerFunc { - return func(_ context.Context, u *url.URL) (*sqlclient.Client, error) { - ur := parser{}.ParseURL(u) - db, err := sql.Open(DriverName, ur.DSN) - if err != nil { - return nil, err - } - drv, err := Open(db) - if err != nil { - if cerr := db.Close(); cerr != nil { - err = fmt.Errorf("%w: %v", err, cerr) - } - return nil, err - } - drv.(*Driver).schema = ur.Schema - return &sqlclient.Client{ - Name: name, - DB: db, - URL: ur, - Driver: drv, - }, nil - } -} - -// NormalizeRealm returns the normal representation of the given database. -func (d *Driver) NormalizeRealm(ctx context.Context, r *schema.Realm) (*schema.Realm, error) { - return (&sqlx.DevDriver{Driver: d}).NormalizeRealm(ctx, r) -} - -// NormalizeSchema returns the normal representation of the given database. -func (d *Driver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) { - return (&sqlx.DevDriver{Driver: d}).NormalizeSchema(ctx, s) -} - -// Lock implements the schema.Locker interface. -func (d *Driver) Lock(ctx context.Context, name string, timeout time.Duration) (schema.UnlockFunc, error) { - conn, err := sqlx.SingleConn(ctx, d.ExecQuerier) - if err != nil { - return nil, err - } - if err := acquire(ctx, conn, name, timeout); err != nil { - conn.Close() - return nil, err - } - return func() error { - defer conn.Close() - rows, err := conn.QueryContext(ctx, "SELECT RELEASE_LOCK(?)", name) - if err != nil { - return err - } - switch released, err := sqlx.ScanNullBool(rows); { - case err != nil: - return err - case !released.Valid || !released.Bool: - return fmt.Errorf("sql/mysql: failed releasing a named lock %q", name) - } - return nil - }, nil -} - -// Snapshot implements migrate.Snapshoter. -func (d *Driver) Snapshot(ctx context.Context) (migrate.RestoreFunc, error) { - // If the connection is bound to a schema, we can restore the state if the schema has no tables. - s, err := d.InspectSchema(ctx, "", nil) - if err != nil && !schema.IsNotExistError(err) { - return nil, err - } - // If a schema was found, it has to have no tables attached to be considered clean. - if s != nil { - if len(s.Tables) > 0 { - return nil, &migrate.NotCleanError{ - State: schema.NewRealm(s), - Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name), - } - } - return d.SchemaRestoreFunc(s), nil - } - // Otherwise, the database can not have any schema. - realm, err := d.InspectRealm(ctx, nil) - if err != nil { - return nil, err - } - if len(realm.Schemas) > 0 { - return nil, &migrate.NotCleanError{State: realm, Reason: fmt.Sprintf("found schema %q", realm.Schemas[0].Name)} - } - return d.RealmRestoreFunc(realm), nil -} - -// SchemaRestoreFunc returns a function that restores the given schema to its desired state. -func (d *Driver) SchemaRestoreFunc(desired *schema.Schema) migrate.RestoreFunc { - return func(ctx context.Context) error { - current, err := d.InspectSchema(ctx, desired.Name, nil) - if err != nil { - return err - } - changes, err := d.SchemaDiff(current, desired) - if err != nil { - return err - } - return d.ApplyChanges(ctx, changes) - } -} - -// RealmRestoreFunc returns a function that restores the given realm to its desired state. -func (d *Driver) RealmRestoreFunc(desired *schema.Realm) migrate.RestoreFunc { - return func(ctx context.Context) error { - current, err := d.InspectRealm(ctx, nil) - if err != nil { - return err - } - changes, err := d.RealmDiff(current, desired) - if err != nil { - return err - } - return d.ApplyChanges(ctx, changes) - } -} - -// CheckClean implements migrate.CleanChecker. -func (d *Driver) CheckClean(ctx context.Context, revT *migrate.TableIdent) error { - if revT == nil { // accept nil values - revT = &migrate.TableIdent{} - } - s, err := d.InspectSchema(ctx, "", nil) - if err != nil && !schema.IsNotExistError(err) { - return err - } - if s != nil { - if len(s.Tables) == 0 || (revT.Schema == "" || s.Name == revT.Schema) && len(s.Tables) == 1 && s.Tables[0].Name == revT.Name { - return nil - } - return &migrate.NotCleanError{ - State: schema.NewRealm(s), - Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name), - } - } - r, err := d.InspectRealm(ctx, nil) - if err != nil { - return err - } - switch n := len(r.Schemas); { - case n > 1: - return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found multiple schemas: %d", len(r.Schemas))} - case n == 1 && r.Schemas[0].Name != revT.Schema: - return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found schema %q", r.Schemas[0].Name)} - case n == 1 && len(r.Schemas[0].Tables) > 1: - return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found multiple tables: %d", len(r.Schemas[0].Tables))} - case n == 1 && len(r.Schemas[0].Tables) == 1 && r.Schemas[0].Tables[0].Name != revT.Name: - return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found table %q", r.Schemas[0].Tables[0].Name)} - } - return nil -} - -// Version returns the version of the connected database. -func (d *Driver) Version() string { - return string(d.conn.V) -} - -// FormatType converts schema type to its column form in the database. -func (*Driver) FormatType(t schema.Type) (string, error) { - return FormatType(t) -} - -// ParseType returns the schema.Type value represented by the given string. -func (*Driver) ParseType(s string) (schema.Type, error) { - return ParseType(s) -} - -// StmtBuilder is a helper method used to build statements with MySQL formatting. -func (*Driver) StmtBuilder(opts migrate.PlanOptions) *sqlx.Builder { - return &sqlx.Builder{ - QuoteOpening: '`', - QuoteClosing: '`', - Schema: opts.SchemaQualifier, - Indent: opts.Indent, - } -} - -// ScanStmts implements migrate.StmtScanner. -func (*Driver) ScanStmts(input string) ([]*migrate.Stmt, error) { - return (&migrate.Scanner{ - ScannerOptions: migrate.ScannerOptions{ - MatchBegin: true, - BackslashEscapes: true, - HashComments: true, - // The following are not support by MySQL/MariaDB. - MatchBeginAtomic: false, - MatchDollarQuote: false, - }, - }).Scan(input) -} - -func acquire(ctx context.Context, conn schema.ExecQuerier, name string, timeout time.Duration) error { - rows, err := conn.QueryContext(ctx, "SELECT GET_LOCK(?, ?)", name, int(timeout.Seconds())) - if err != nil { - return err - } - switch acquired, err := sqlx.ScanNullBool(rows); { - case err != nil: - return err - case !acquired.Valid: - // NULL is returned in case of an unexpected internal error. - return fmt.Errorf("sql/mysql: unexpected internal error on Lock(%q, %s)", name, timeout) - case !acquired.Bool: - return schema.ErrLocked - } - return nil -} - -// unescape strings with backslashes returned -// for SQL expressions from information schema. -func unescape(s string) string { - var b strings.Builder - for i := 0; i < len(s); i++ { - switch c := s[i]; { - case c != '\\' || i == len(s)-1: - b.WriteByte(c) - case s[i+1] == '\'', s[i+1] == '\\': - b.WriteByte(s[i+1]) - i++ - } - } - return b.String() -} - -type parser struct{} - -// ParseURL implements the sqlclient.URLParser interface. -func (parser) ParseURL(u *url.URL) *sqlclient.URL { - v := u.Query() - v.Set("parseTime", "true") - u.RawQuery = v.Encode() - cu := &sqlclient.URL{URL: u, DSN: dsn(u), Schema: strings.TrimPrefix(u.Path, "/")} - if strings.HasSuffix(u.Scheme, "+unix") { - cu.Schema = v.Get("database") - } - return cu -} - -// ChangeSchema implements the sqlclient.SchemaChanger interface. -func (parser) ChangeSchema(u *url.URL, s string) *url.URL { - nu := *u - nu.Path = "/" + s - return &nu -} - -// dsn returns the MySQL standard DSN for opening -// the sql.DB from the user provided URL. -func dsn(u *url.URL) string { - var ( - b strings.Builder - values = u.Query() - ) - b.WriteString(u.User.Username()) - if p, ok := u.User.Password(); ok { - b.WriteByte(':') - b.WriteString(p) - } - if b.Len() > 0 { - b.WriteByte('@') - } - switch { - case strings.HasSuffix(u.Scheme, "+unix"): - b.WriteString("unix(") - // The path is always absolute, and - // therefore the host should be empty. - b.WriteString(u.Path) - b.WriteString(")/") - if name := values.Get("database"); name != "" { - b.WriteString(name) - values.Del("database") - } - default: - if u.Host != "" { - b.WriteString("tcp(") - b.WriteString(u.Host) - b.WriteByte(')') - } - if u.Path != "" { - b.WriteString(u.Path) - } else { - b.WriteByte('/') - } - } - if p := values.Encode(); p != "" { - b.WriteByte('?') - b.WriteString(p) - } - return b.String() -} - -// MySQL standard column types as defined in its codebase. Name and order -// is organized differently than MySQL. -// -// https://github.com/mysql/mysql-server/blob/8.0/include/field_types.h -// https://github.com/mysql/mysql-server/blob/8.0/sql/dd/types/column.h -// https://github.com/mysql/mysql-server/blob/8.0/sql/sql_show.cc -// https://github.com/mysql/mysql-server/blob/8.0/sql/gis/geometries.cc -// https://dev.mysql.com/doc/refman/8.0/en/other-vendor-data-types.html -const ( - TypeBool = "bool" - TypeBoolean = "boolean" - - TypeBit = "bit" // MYSQL_TYPE_BIT - TypeInt = "int" // MYSQL_TYPE_LONG - TypeTinyInt = "tinyint" // MYSQL_TYPE_TINY - TypeSmallInt = "smallint" // MYSQL_TYPE_SHORT - TypeMediumInt = "mediumint" // MYSQL_TYPE_INT24 - TypeBigInt = "bigint" // MYSQL_TYPE_LONGLONG - - TypeDecimal = "decimal" // MYSQL_TYPE_DECIMAL - TypeNumeric = "numeric" // MYSQL_TYPE_DECIMAL (numeric_type rule in sql_yacc.yy) - TypeFloat = "float" // MYSQL_TYPE_FLOAT - TypeDouble = "double" // MYSQL_TYPE_DOUBLE - TypeReal = "real" // MYSQL_TYPE_FLOAT or MYSQL_TYPE_DOUBLE (real_type in sql_yacc.yy) - - TypeTimestamp = "timestamp" // MYSQL_TYPE_TIMESTAMP - TypeDate = "date" // MYSQL_TYPE_DATE - TypeTime = "time" // MYSQL_TYPE_TIME - TypeDateTime = "datetime" // MYSQL_TYPE_DATETIME - TypeYear = "year" // MYSQL_TYPE_YEAR - - TypeVarchar = "varchar" // MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR - TypeChar = "char" // MYSQL_TYPE_STRING - TypeVarBinary = "varbinary" // MYSQL_TYPE_VAR_STRING + NULL CHARACTER_SET. - TypeBinary = "binary" // MYSQL_TYPE_STRING + NULL CHARACTER_SET. - TypeBlob = "blob" // MYSQL_TYPE_BLOB - TypeTinyBlob = "tinyblob" // MYSQL_TYPE_TINYBLOB - TypeMediumBlob = "mediumblob" // MYSQL_TYPE_MEDIUM_BLOB - TypeLongBlob = "longblob" // MYSQL_TYPE_LONG_BLOB - TypeText = "text" // MYSQL_TYPE_BLOB + CHARACTER_SET utf8mb4 - TypeTinyText = "tinytext" // MYSQL_TYPE_TINYBLOB + CHARACTER_SET utf8mb4 - TypeMediumText = "mediumtext" // MYSQL_TYPE_MEDIUM_BLOB + CHARACTER_SET utf8mb4 - TypeLongText = "longtext" // MYSQL_TYPE_LONG_BLOB with + CHARACTER_SET utf8mb4 - - TypeEnum = "enum" // MYSQL_TYPE_ENUM - TypeSet = "set" // MYSQL_TYPE_SET - TypeJSON = "json" // MYSQL_TYPE_JSON - - TypeGeometry = "geometry" // MYSQL_TYPE_GEOMETRY - TypePoint = "point" // Geometry_type::kPoint - TypeMultiPoint = "multipoint" // Geometry_type::kMultipoint - TypeLineString = "linestring" // Geometry_type::kLinestring - TypeMultiLineString = "multilinestring" // Geometry_type::kMultilinestring - TypePolygon = "polygon" // Geometry_type::kPolygon - TypeMultiPolygon = "multipolygon" // Geometry_type::kMultipolygon - TypeGeoCollection = "geomcollection" // Geometry_type::kGeometrycollection - TypeGeometryCollection = "geometrycollection" // Geometry_type::kGeometrycollection - - TypeUUID = "uuid" // MariaDB supported uuid type from 10.7.0+ - - TypeInet4 = "inet4" // MariaDB type for storage of IPv4 addresses, from 10.10.0+. - TypeInet6 = "inet6" // MariaDB type for storage of IPv6 addresses, from 10.10.0+. -) - -// Additional common constants in MySQL. -const ( - IndexTypeBTree = "BTREE" - IndexTypeHash = "HASH" - IndexTypeFullText = "FULLTEXT" - IndexTypeSpatial = "SPATIAL" - - IndexParserNGram = "ngram" - IndexParserMeCab = "mecab" - - EngineInnoDB = "InnoDB" - EngineMyISAM = "MyISAM" - EngineMemory = "Memory" - EngineCSV = "CSV" - EngineNDB = "NDB" // NDBCLUSTER - - currentTS = "current_timestamp" - defaultGen = "default_generated" - autoIncrement = "auto_increment" - - virtual = "VIRTUAL" - stored = "STORED" - persistent = "PERSISTENT" -) diff --git a/sql/mysql/driver_oss.go b/sql/mysql/driver_oss.go index f41df078afd..92f4bfcf097 100644 --- a/sql/mysql/driver_oss.go +++ b/sql/mysql/driver_oss.go @@ -8,137 +8,496 @@ package mysql import ( "context" + "database/sql" + "fmt" + "net/url" + "strings" + "time" - "ariga.io/atlas/schemahcl" - "ariga.io/atlas/sql/internal/specutil" "ariga.io/atlas/sql/internal/sqlx" + "ariga.io/atlas/sql/migrate" + "ariga.io/atlas/sql/mysql/internal/mysqlversion" "ariga.io/atlas/sql/schema" + "ariga.io/atlas/sql/sqlclient" "ariga.io/atlas/sql/sqlspec" ) -var ( - specOptions, mariaSpecOptions []schemahcl.Option - specFuncs = &specutil.SchemaFuncs{ - Table: tableSpec, - View: viewSpec, +type ( + // Driver represents a MySQL driver for introspecting database schemas, + // generating diff between schema elements and apply migrations changes. + Driver struct { + *conn + schema.Differ + schema.Inspector + migrate.PlanApplier } - scanFuncs = &specutil.ScanFuncs{ - Table: convertTable, - View: convertView, + + // database connection and its information. + conn struct { + schema.ExecQuerier + // The schema was set in the path (schema connection). + schema string + // System variables that are set on `Open`. + mysqlversion.V + collate string + charset string + lcnames int } ) -func triggersSpec([]*schema.Trigger, *specutil.Doc) ([]*sqlspec.Trigger, error) { - return nil, nil // unimplemented. -} +var _ interface { + migrate.Snapshoter + migrate.StmtScanner + migrate.CleanChecker + schema.TypeParseFormatter +} = (*Driver)(nil) + +// DriverName and DriverMaria holds the names used for registration. +const ( + DriverName = "mysql" + DriverMaria = "mariadb" +) -func (*inspect) tablesQuery(context.Context) string { - return tablesQuery +func init() { + sqlclient.Register( + DriverName, + opener(DriverName), + sqlclient.RegisterDriverOpener(Open), + sqlclient.RegisterCodec(codec, codec), + sqlclient.RegisterFlavours("mysql+unix"), + sqlclient.RegisterURLParser(parser{}), + ) + sqlclient.Register( + DriverMaria, + opener(DriverMaria), + sqlclient.RegisterDriverOpener(Open), + sqlclient.RegisterCodec(mariaCodec, mariaCodec), + sqlclient.RegisterFlavours("mariadb+unix", "maria", "maria+unix"), + sqlclient.RegisterURLParser(parser{}), + ) } -func (*inspect) tablesQueryArgs(context.Context) string { - return tablesQueryArgs +// Open opens a new MySQL driver. +func Open(db schema.ExecQuerier) (migrate.Driver, error) { + c := &conn{ExecQuerier: db} + rows, err := db.QueryContext(context.Background(), variablesQuery) + if err != nil { + return nil, fmt.Errorf("mysql: query system variables: %w", err) + } + if err := sqlx.ScanOne(rows, &c.V, &c.collate, &c.charset, &c.lcnames); err != nil { + return nil, fmt.Errorf("mysql: scan system variables: %w", err) + } + if c.TiDB() { + return &Driver{ + conn: c, + Differ: &sqlx.Diff{DiffDriver: &tdiff{diff{conn: c}}}, + Inspector: &tinspect{inspect{c}}, + PlanApplier: &tplanApply{planApply{c}}, + }, nil + } + return &Driver{ + conn: c, + Differ: &sqlx.Diff{DiffDriver: &diff{conn: c}}, + Inspector: &inspect{c}, + PlanApplier: &planApply{c}, + }, nil } -// newTable creates a new table with the given name and type. -func (*inspect) newTable(name, _ string) *schema.Table { - return schema.NewTable(name) +// opener for the given driver name. +func opener(name string) sqlclient.OpenerFunc { + return func(_ context.Context, u *url.URL) (*sqlclient.Client, error) { + ur := parser{}.ParseURL(u) + db, err := sql.Open(DriverName, ur.DSN) + if err != nil { + return nil, err + } + drv, err := Open(db) + if err != nil { + if cerr := db.Close(); cerr != nil { + err = fmt.Errorf("%w: %v", err, cerr) + } + return nil, err + } + drv.(*Driver).schema = ur.Schema + return &sqlclient.Client{ + Name: name, + DB: db, + URL: ur, + Driver: drv, + }, nil + } } -func (s *state) tableAttr(*sqlx.Builder, schema.Change, schema.Attr) { - // unimplemented. +// NormalizeRealm returns the normal representation of the given database. +func (d *Driver) NormalizeRealm(ctx context.Context, r *schema.Realm) (*schema.Realm, error) { + return (&sqlx.DevDriver{Driver: d}).NormalizeRealm(ctx, r) } -func convertTableAttrs(*sqlspec.Table, *schema.Table) error { - return nil // unimplemented. +// NormalizeSchema returns the normal representation of the given database. +func (d *Driver) NormalizeSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) { + return (&sqlx.DevDriver{Driver: d}).NormalizeSchema(ctx, s) } -func tableAttrsSpec(*schema.Table, *sqlspec.Table) { - // unimplemented. +// Lock implements the schema.Locker interface. +func (d *Driver) Lock(ctx context.Context, name string, timeout time.Duration) (schema.UnlockFunc, error) { + conn, err := sqlx.SingleConn(ctx, d.ExecQuerier) + if err != nil { + return nil, err + } + if err := acquire(ctx, conn, name, timeout); err != nil { + conn.Close() + return nil, err + } + return func() error { + defer conn.Close() + rows, err := conn.QueryContext(ctx, "SELECT RELEASE_LOCK(?)", name) + if err != nil { + return err + } + switch released, err := sqlx.ScanNullBool(rows); { + case err != nil: + return err + case !released.Valid || !released.Bool: + return fmt.Errorf("sql/mysql: failed releasing a named lock %q", name) + } + return nil + }, nil } -func viewSpec(*schema.View) (*sqlspec.View, error) { - return nil, nil // unimplemented. +// Snapshot implements migrate.Snapshoter. +func (d *Driver) Snapshot(ctx context.Context) (migrate.RestoreFunc, error) { + // If the connection is bound to a schema, we can restore the state if the schema has no tables. + s, err := d.InspectSchema(ctx, "", nil) + if err != nil && !schema.IsNotExistError(err) { + return nil, err + } + // If a schema was found, it has to have no tables attached to be considered clean. + if s != nil { + if len(s.Tables) > 0 { + return nil, &migrate.NotCleanError{ + State: schema.NewRealm(s), + Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name), + } + } + return d.SchemaRestoreFunc(s), nil + } + // Otherwise, the database can not have any schema. + realm, err := d.InspectRealm(ctx, nil) + if err != nil { + return nil, err + } + if len(realm.Schemas) > 0 { + return nil, &migrate.NotCleanError{State: realm, Reason: fmt.Sprintf("found schema %q", realm.Schemas[0].Name)} + } + return d.RealmRestoreFunc(realm), nil } -func convertView(*sqlspec.View, *schema.Schema) (*schema.View, error) { - return nil, nil // unimplemented. +// SchemaRestoreFunc returns a function that restores the given schema to its desired state. +func (d *Driver) SchemaRestoreFunc(desired *schema.Schema) migrate.RestoreFunc { + return func(ctx context.Context) error { + current, err := d.InspectSchema(ctx, desired.Name, nil) + if err != nil { + return err + } + changes, err := d.SchemaDiff(current, desired) + if err != nil { + return err + } + return d.ApplyChanges(ctx, changes) + } } -func (*inspect) inspectViews(context.Context, *schema.Realm, *schema.InspectOptions) error { - return nil // unimplemented. +// RealmRestoreFunc returns a function that restores the given realm to its desired state. +func (d *Driver) RealmRestoreFunc(desired *schema.Realm) migrate.RestoreFunc { + return func(ctx context.Context) error { + current, err := d.InspectRealm(ctx, nil) + if err != nil { + return err + } + changes, err := d.RealmDiff(current, desired) + if err != nil { + return err + } + return d.ApplyChanges(ctx, changes) + } } -func (*inspect) inspectFuncs(context.Context, *schema.Realm, *schema.InspectOptions) error { - return nil // unimplemented. +// CheckClean implements migrate.CleanChecker. +func (d *Driver) CheckClean(ctx context.Context, revT *migrate.TableIdent) error { + if revT == nil { // accept nil values + revT = &migrate.TableIdent{} + } + s, err := d.InspectSchema(ctx, "", nil) + if err != nil && !schema.IsNotExistError(err) { + return err + } + if s != nil { + if len(s.Tables) == 0 || (revT.Schema == "" || s.Name == revT.Schema) && len(s.Tables) == 1 && s.Tables[0].Name == revT.Name { + return nil + } + return &migrate.NotCleanError{ + State: schema.NewRealm(s), + Reason: fmt.Sprintf("found table %q in schema %q", s.Tables[0].Name, s.Name), + } + } + r, err := d.InspectRealm(ctx, nil) + if err != nil { + return err + } + switch n := len(r.Schemas); { + case n > 1: + return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found multiple schemas: %d", len(r.Schemas))} + case n == 1 && r.Schemas[0].Name != revT.Schema: + return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found schema %q", r.Schemas[0].Name)} + case n == 1 && len(r.Schemas[0].Tables) > 1: + return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found multiple tables: %d", len(r.Schemas[0].Tables))} + case n == 1 && len(r.Schemas[0].Tables) == 1 && r.Schemas[0].Tables[0].Name != revT.Name: + return &migrate.NotCleanError{State: r, Reason: fmt.Sprintf("found table %q", r.Schemas[0].Tables[0].Name)} + } + return nil } -func (*inspect) inspectTriggers(context.Context, *schema.Realm, *schema.InspectOptions) error { - return nil // unimplemented. +// Version returns the version of the connected database. +func (d *Driver) Version() string { + return string(d.conn.V) } -func (*state) addView(*schema.AddView) error { - return nil // unimplemented. +// FormatType converts schema type to its column form in the database. +func (*Driver) FormatType(t schema.Type) (string, error) { + return FormatType(t) } -func (*state) dropView(*schema.DropView) error { - return nil // unimplemented. +// ParseType returns the schema.Type value represented by the given string. +func (*Driver) ParseType(s string) (schema.Type, error) { + return ParseType(s) } -func (*state) modifyView(*schema.ModifyView) error { - return nil // unimplemented. +// StmtBuilder is a helper method used to build statements with MySQL formatting. +func (*Driver) StmtBuilder(opts migrate.PlanOptions) *sqlx.Builder { + return &sqlx.Builder{ + QuoteOpening: '`', + QuoteClosing: '`', + Schema: opts.SchemaQualifier, + Indent: opts.Indent, + } } -func (*state) renameView(*schema.RenameView) { - // unimplemented. +// ScanStmts implements migrate.StmtScanner. +func (*Driver) ScanStmts(input string) ([]*migrate.Stmt, error) { + return (&migrate.Scanner{ + ScannerOptions: migrate.ScannerOptions{ + MatchBegin: true, + BackslashEscapes: true, + HashComments: true, + // The following are not support by MySQL/MariaDB. + MatchBeginAtomic: false, + MatchDollarQuote: false, + }, + }).Scan(input) } -func (*diff) ViewAttrChanges(_, _ *schema.View) []schema.Change { - return nil // Not implemented. +func acquire(ctx context.Context, conn schema.ExecQuerier, name string, timeout time.Duration) error { + rows, err := conn.QueryContext(ctx, "SELECT GET_LOCK(?, ?)", name, int(timeout.Seconds())) + if err != nil { + return err + } + switch acquired, err := sqlx.ScanNullBool(rows); { + case err != nil: + return err + case !acquired.Valid: + // NULL is returned in case of an unexpected internal error. + return fmt.Errorf("sql/mysql: unexpected internal error on Lock(%q, %s)", name, timeout) + case !acquired.Bool: + return schema.ErrLocked + } + return nil } -func (s *state) addFunc(*schema.AddFunc) error { - return nil // unimplemented. +// unescape strings with backslashes returned +// for SQL expressions from information schema. +func unescape(s string) string { + var b strings.Builder + for i := 0; i < len(s); i++ { + switch c := s[i]; { + case c != '\\' || i == len(s)-1: + b.WriteByte(c) + case s[i+1] == '\'', s[i+1] == '\\': + b.WriteByte(s[i+1]) + i++ + } + } + return b.String() } -func (s *state) dropFunc(*schema.DropFunc) error { - return nil // unimplemented. -} +type parser struct{} -func (s *state) modifyFunc(*schema.ModifyFunc) error { - return nil // unimplemented. +// ParseURL implements the sqlclient.URLParser interface. +func (parser) ParseURL(u *url.URL) *sqlclient.URL { + v := u.Query() + v.Set("parseTime", "true") + u.RawQuery = v.Encode() + cu := &sqlclient.URL{URL: u, DSN: dsn(u), Schema: strings.TrimPrefix(u.Path, "/")} + if strings.HasSuffix(u.Scheme, "+unix") { + cu.Schema = v.Get("database") + } + return cu } -func (s *state) renameFunc(*schema.RenameFunc) error { - return nil // unimplemented. +// ChangeSchema implements the sqlclient.SchemaChanger interface. +func (parser) ChangeSchema(u *url.URL, s string) *url.URL { + nu := *u + nu.Path = "/" + s + return &nu } -func (s *state) addProc(*schema.AddProc) error { - return nil // unimplemented. +// dsn returns the MySQL standard DSN for opening +// the sql.DB from the user provided URL. +func dsn(u *url.URL) string { + var ( + b strings.Builder + values = u.Query() + ) + b.WriteString(u.User.Username()) + if p, ok := u.User.Password(); ok { + b.WriteByte(':') + b.WriteString(p) + } + if b.Len() > 0 { + b.WriteByte('@') + } + switch { + case strings.HasSuffix(u.Scheme, "+unix"): + b.WriteString("unix(") + // The path is always absolute, and + // therefore the host should be empty. + b.WriteString(u.Path) + b.WriteString(")/") + if name := values.Get("database"); name != "" { + b.WriteString(name) + values.Del("database") + } + default: + if u.Host != "" { + b.WriteString("tcp(") + b.WriteString(u.Host) + b.WriteByte(')') + } + if u.Path != "" { + b.WriteString(u.Path) + } else { + b.WriteByte('/') + } + } + if p := values.Encode(); p != "" { + b.WriteByte('?') + b.WriteString(p) + } + return b.String() } -func (s *state) dropProc(*schema.DropProc) error { - return nil // unimplemented. -} +// MySQL standard column types as defined in its codebase. Name and order +// is organized differently than MySQL. +// +// https://github.com/mysql/mysql-server/blob/8.0/include/field_types.h +// https://github.com/mysql/mysql-server/blob/8.0/sql/dd/types/column.h +// https://github.com/mysql/mysql-server/blob/8.0/sql/sql_show.cc +// https://github.com/mysql/mysql-server/blob/8.0/sql/gis/geometries.cc +// https://dev.mysql.com/doc/refman/8.0/en/other-vendor-data-types.html +const ( + TypeBool = "bool" + TypeBoolean = "boolean" + + TypeBit = "bit" // MYSQL_TYPE_BIT + TypeInt = "int" // MYSQL_TYPE_LONG + TypeTinyInt = "tinyint" // MYSQL_TYPE_TINY + TypeSmallInt = "smallint" // MYSQL_TYPE_SHORT + TypeMediumInt = "mediumint" // MYSQL_TYPE_INT24 + TypeBigInt = "bigint" // MYSQL_TYPE_LONGLONG + + TypeDecimal = "decimal" // MYSQL_TYPE_DECIMAL + TypeNumeric = "numeric" // MYSQL_TYPE_DECIMAL (numeric_type rule in sql_yacc.yy) + TypeFloat = "float" // MYSQL_TYPE_FLOAT + TypeDouble = "double" // MYSQL_TYPE_DOUBLE + TypeReal = "real" // MYSQL_TYPE_FLOAT or MYSQL_TYPE_DOUBLE (real_type in sql_yacc.yy) + + TypeTimestamp = "timestamp" // MYSQL_TYPE_TIMESTAMP + TypeDate = "date" // MYSQL_TYPE_DATE + TypeTime = "time" // MYSQL_TYPE_TIME + TypeDateTime = "datetime" // MYSQL_TYPE_DATETIME + TypeYear = "year" // MYSQL_TYPE_YEAR + + TypeVarchar = "varchar" // MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR + TypeChar = "char" // MYSQL_TYPE_STRING + TypeVarBinary = "varbinary" // MYSQL_TYPE_VAR_STRING + NULL CHARACTER_SET. + TypeBinary = "binary" // MYSQL_TYPE_STRING + NULL CHARACTER_SET. + TypeBlob = "blob" // MYSQL_TYPE_BLOB + TypeTinyBlob = "tinyblob" // MYSQL_TYPE_TINYBLOB + TypeMediumBlob = "mediumblob" // MYSQL_TYPE_MEDIUM_BLOB + TypeLongBlob = "longblob" // MYSQL_TYPE_LONG_BLOB + TypeText = "text" // MYSQL_TYPE_BLOB + CHARACTER_SET utf8mb4 + TypeTinyText = "tinytext" // MYSQL_TYPE_TINYBLOB + CHARACTER_SET utf8mb4 + TypeMediumText = "mediumtext" // MYSQL_TYPE_MEDIUM_BLOB + CHARACTER_SET utf8mb4 + TypeLongText = "longtext" // MYSQL_TYPE_LONG_BLOB with + CHARACTER_SET utf8mb4 + + TypeEnum = "enum" // MYSQL_TYPE_ENUM + TypeSet = "set" // MYSQL_TYPE_SET + TypeJSON = "json" // MYSQL_TYPE_JSON + + TypeGeometry = "geometry" // MYSQL_TYPE_GEOMETRY + TypePoint = "point" // Geometry_type::kPoint + TypeMultiPoint = "multipoint" // Geometry_type::kMultipoint + TypeLineString = "linestring" // Geometry_type::kLinestring + TypeMultiLineString = "multilinestring" // Geometry_type::kMultilinestring + TypePolygon = "polygon" // Geometry_type::kPolygon + TypeMultiPolygon = "multipolygon" // Geometry_type::kMultipolygon + TypeGeoCollection = "geomcollection" // Geometry_type::kGeometrycollection + TypeGeometryCollection = "geometrycollection" // Geometry_type::kGeometrycollection + + TypeUUID = "uuid" // MariaDB supported uuid type from 10.7.0+ + + TypeInet4 = "inet4" // MariaDB type for storage of IPv4 addresses, from 10.10.0+. + TypeInet6 = "inet6" // MariaDB type for storage of IPv6 addresses, from 10.10.0+. +) -func (s *state) modifyProc(*schema.ModifyProc) error { - return nil // unimplemented. -} +// Additional common constants in MySQL. +const ( + IndexTypeBTree = "BTREE" + IndexTypeHash = "HASH" + IndexTypeFullText = "FULLTEXT" + IndexTypeSpatial = "SPATIAL" + + IndexParserNGram = "ngram" + IndexParserMeCab = "mecab" + + EngineInnoDB = "InnoDB" + EngineMyISAM = "MyISAM" + EngineMemory = "Memory" + EngineCSV = "CSV" + EngineNDB = "NDB" // NDBCLUSTER + + currentTS = "current_timestamp" + defaultGen = "default_generated" + autoIncrement = "auto_increment" + + virtual = "VIRTUAL" + stored = "STORED" + persistent = "PERSISTENT" +) -func (s *state) renameProc(*schema.RenameProc) error { - return nil // unimplemented. +func (*inspect) tablesQuery(context.Context) string { + return tablesQuery } -func (*state) addTrigger(*schema.AddTrigger) error { - return nil // unimplemented. +func (*inspect) tablesQueryArgs(context.Context) string { + return tablesQueryArgs } -func (*state) dropTrigger(*schema.DropTrigger) error { - return nil // unimplemented. +func viewSpec(*schema.View) (*sqlspec.View, error) { + return nil, nil // unimplemented. } -func (*state) modifyTrigger(*schema.ModifyTrigger) error { - return nil // unimplemented. +func convertView(*sqlspec.View, *schema.Schema) (*schema.View, error) { + return nil, nil // unimplemented. } func verifyChanges(context.Context, []schema.Change) error { diff --git a/sql/mysql/driver_test.go b/sql/mysql/driver_oss_test.go similarity index 99% rename from sql/mysql/driver_test.go rename to sql/mysql/driver_oss_test.go index cc49d1de9d0..746ffb2f724 100644 --- a/sql/mysql/driver_test.go +++ b/sql/mysql/driver_oss_test.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( diff --git a/sql/mysql/inspect.go b/sql/mysql/inspect_oss.go similarity index 97% rename from sql/mysql/inspect.go rename to sql/mysql/inspect_oss.go index 62cfdcdc0d3..f9322f60c56 100644 --- a/sql/mysql/inspect.go +++ b/sql/mysql/inspect_oss.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( @@ -41,21 +43,6 @@ func (i *inspect) InspectRealm(ctx context.Context, opts *schema.InspectRealmOpt } sqlx.LinkSchemaTables(schemas) } - if mode.Is(schema.InspectViews) { - if err := i.inspectViews(ctx, r, nil); err != nil { - return nil, err - } - } - if mode.Is(schema.InspectFuncs) { - if err := i.inspectFuncs(ctx, r, nil); err != nil { - return nil, err - } - } - if mode.Is(schema.InspectTriggers) { - if err := i.inspectTriggers(ctx, r, nil); err != nil { - return nil, err - } - } } return schema.ExcludeRealm(r, opts.Exclude) } @@ -86,21 +73,6 @@ func (i *inspect) InspectSchema(ctx context.Context, name string, opts *schema.I } sqlx.LinkSchemaTables(schemas) } - if mode.Is(schema.InspectViews) { - if err := i.inspectViews(ctx, r, opts); err != nil { - return nil, err - } - } - if mode.Is(schema.InspectFuncs) { - if err := i.inspectFuncs(ctx, r, opts); err != nil { - return nil, err - } - } - if mode.Is(schema.InspectTriggers) { - if err := i.inspectTriggers(ctx, r, opts); err != nil { - return nil, err - } - } return schema.ExcludeSchema(r.Schemas[0], opts.Exclude) } @@ -212,7 +184,7 @@ func (i *inspect) tables(ctx context.Context, realm *schema.Realm, opts *schema. if !ok { return fmt.Errorf("schema %q was not found in realm", tSchema.String) } - t := i.newTable(name.String, ttyp.String) + t := schema.NewTable(name.String) s.AddTables(t) if sqlx.ValidString(charset) { t.Attrs = append(t.Attrs, &schema.Charset{ diff --git a/sql/mysql/inspect_test.go b/sql/mysql/inspect_oss_test.go similarity index 99% rename from sql/mysql/inspect_test.go rename to sql/mysql/inspect_oss_test.go index b202a82dec8..dad935de1c8 100644 --- a/sql/mysql/inspect_test.go +++ b/sql/mysql/inspect_oss_test.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( diff --git a/sql/mysql/migrate.go b/sql/mysql/migrate_oss.go similarity index 97% rename from sql/mysql/migrate.go rename to sql/mysql/migrate_oss.go index 96da373483d..f7d6180911b 100644 --- a/sql/mysql/migrate.go +++ b/sql/mysql/migrate_oss.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( @@ -98,32 +100,6 @@ func (s *state) plan(changes []schema.Change) error { err = s.modifyTable(c) case *schema.RenameTable: s.renameTable(c) - case *schema.AddFunc: - err = s.addFunc(c) - case *schema.AddProc: - err = s.addProc(c) - case *schema.ModifyFunc: - err = s.modifyFunc(c) - case *schema.ModifyProc: - err = s.modifyProc(c) - case *schema.DropFunc: - err = s.dropFunc(c) - case *schema.DropProc: - err = s.dropProc(c) - case *schema.AddView: - err = s.addView(c) - case *schema.DropView: - err = s.dropView(c) - case *schema.ModifyView: - err = s.modifyView(c) - case *schema.RenameView: - s.renameView(c) - case *schema.AddTrigger: - err = s.addTrigger(c) - case *schema.ModifyTrigger: - err = s.modifyTrigger(c) - case *schema.DropTrigger: - err = s.dropTrigger(c) default: err = fmt.Errorf("unsupported change %T", c) } @@ -711,9 +687,6 @@ func (s *state) tableAttrs(b *sqlx.Builder, c schema.Change, attrs ...schema.Att b.P("COLLATE", a.V) case *schema.Comment: b.P("COMMENT", quote(a.Text)) - default: - // Driver/build specific handling. - s.tableAttr(b, c, a) } } } diff --git a/sql/mysql/migrate_test.go b/sql/mysql/migrate_oss_test.go similarity index 99% rename from sql/mysql/migrate_test.go rename to sql/mysql/migrate_oss_test.go index 39130459da7..bf19d26c22c 100644 --- a/sql/mysql/migrate_test.go +++ b/sql/mysql/migrate_oss_test.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( diff --git a/sql/mysql/sqlspec.go b/sql/mysql/sqlspec_oss.go similarity index 97% rename from sql/mysql/sqlspec.go rename to sql/mysql/sqlspec_oss.go index 1de95e0fb67..98babadc503 100644 --- a/sql/mysql/sqlspec.go +++ b/sql/mysql/sqlspec_oss.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import ( @@ -88,6 +90,10 @@ func (c *Codec) MarshalSpec(v any) ([]byte, error) { }) } +func triggersSpec([]*schema.Trigger, *specutil.Doc) ([]*sqlspec.Trigger, error) { + return nil, nil // unimplemented. +} + var ( registrySpecs = TypeRegistry.Specs() sharedSpecOptions = []schemahcl.Option{ @@ -127,7 +133,16 @@ var ( // EvalMariaHCL implements the schemahcl.Evaluator interface for MariaDB flavor. EvalMariaHCL = schemahcl.EvalFunc(mariaCodec.Eval) // EvalMariaHCLBytes is a helper that evaluates a MariaDB HCL document from a byte slice. - EvalMariaHCLBytes = specutil.HCLBytesFunc(EvalMariaHCL) + EvalMariaHCLBytes = specutil.HCLBytesFunc(EvalMariaHCL) + specOptions, mariaSpecOptions []schemahcl.Option + specFuncs = &specutil.SchemaFuncs{ + Table: tableSpec, + View: viewSpec, + } + scanFuncs = &specutil.ScanFuncs{ + Table: convertTable, + View: convertView, + } ) // convertTable converts a sqlspec.Table to a schema.Table. Table conversion is done without converting @@ -157,9 +172,6 @@ func convertTable(spec *sqlspec.Table, parent *schema.Schema) (*schema.Table, er } t.AddAttrs(&Engine{V: v}) } - if err := convertTableAttrs(spec, t); err != nil { - return nil, err - } return t, nil } @@ -320,7 +332,6 @@ func tableSpec(t *schema.Table) (*sqlspec.Table, error) { } ts.Extra.Attrs = append(ts.Extra.Attrs, attr) } - tableAttrsSpec(t, ts) return ts, nil } diff --git a/sql/mysql/sqlspec_test.go b/sql/mysql/sqlspec_oss_test.go similarity index 99% rename from sql/mysql/sqlspec_test.go rename to sql/mysql/sqlspec_oss_test.go index 3084b62c3cd..1f939d4e499 100644 --- a/sql/mysql/sqlspec_test.go +++ b/sql/mysql/sqlspec_oss_test.go @@ -2,6 +2,8 @@ // This source code is licensed under the Apache 2.0 license found // in the LICENSE file in the root directory of this source tree. +//go:build !ent + package mysql import (