diff --git a/.github/workflows/ci-bats-unix.yaml b/.github/workflows/ci-bats-unix.yaml index 0f71d0e9f8a..b2cdc65db2e 100644 --- a/.github/workflows/ci-bats-unix.yaml +++ b/.github/workflows/ci-bats-unix.yaml @@ -106,7 +106,7 @@ jobs: run: | curl -OL https://github.com/apache/parquet-mr/archive/refs/tags/apache-parquet-1.12.3.tar.gz tar zxvf apache-parquet-1.12.3.tar.gz - cd parquet-mr-apache-parquet-1.12.3/parquet-cli + cd parquet-java-apache-parquet-1.12.3/parquet-cli mvn clean install -DskipTests runtime_jar="$(pwd)"/target/parquet-cli-1.12.3-runtime.jar echo "runtime_jar=$runtime_jar" >> $GITHUB_OUTPUT diff --git a/go/libraries/doltcore/diff/async_differ_test.go b/go/libraries/doltcore/diff/async_differ_test.go index e582120b298..87a7d243031 100644 --- a/go/libraries/doltcore/diff/async_differ_test.go +++ b/go/libraries/doltcore/diff/async_differ_test.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - dtu "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/store/chunks" "github.com/dolthub/dolt/go/store/constants" @@ -332,7 +331,15 @@ func getKeylessRow(ctx context.Context, vals []types.Value) ([]types.Value, erro vals = append(prefix, vals...) return []types.Value{ - dtu.MustTuple(rowIdTag, id1), - dtu.MustTuple(vals...), + mustTuple(rowIdTag, id1), + mustTuple(vals...), }, nil } + +func mustTuple(vals ...types.Value) types.Tuple { + tup, err := types.NewTuple(types.Format_Default, vals...) + if err != nil { + panic(err) + } + return tup +} diff --git a/go/libraries/doltcore/diff/table_deltas.go b/go/libraries/doltcore/diff/table_deltas.go index 20e6dcaef15..5590c7efe24 100644 --- a/go/libraries/doltcore/diff/table_deltas.go +++ b/go/libraries/doltcore/diff/table_deltas.go @@ -24,7 +24,6 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" "github.com/dolthub/dolt/go/libraries/doltcore/schema" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" "github.com/dolthub/dolt/go/libraries/utils/set" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" @@ -614,36 +613,6 @@ func (td TableDelta) GetRowData(ctx context.Context) (from, to durable.Index, er return from, to, nil } -// GetDataDiffStatement returns any data diff in SQL statements for given table including INSERT, UPDATE and DELETE row statements. -func GetDataDiffStatement(tableName string, sch schema.Schema, row sql.Row, rowDiffType ChangeType, colDiffTypes []ChangeType) (string, error) { - if len(row) != len(colDiffTypes) { - return "", fmt.Errorf("expected the same size for columns and diff types, got %d and %d", len(row), len(colDiffTypes)) - } - - switch rowDiffType { - case Added: - return sqlfmt.SqlRowAsInsertStmt(row, tableName, sch) - case Removed: - return sqlfmt.SqlRowAsDeleteStmt(row, tableName, sch, 0) - case ModifiedNew: - updatedCols := set.NewEmptyStrSet() - for i, diffType := range colDiffTypes { - if diffType != None { - updatedCols.Add(sch.GetAllCols().GetByIndex(i).Name) - } - } - if updatedCols.Size() == 0 { - return "", nil - } - return sqlfmt.SqlRowAsUpdateStmt(row, tableName, sch, updatedCols) - case ModifiedOld: - // do nothing, we only issue UPDATE for ModifiedNew - return "", nil - default: - return "", fmt.Errorf("unexpected row diff type: %v", rowDiffType) - } -} - // WorkingSetContainsOnlyIgnoredTables returns true if all changes in working set are ignored tables. // Otherwise, if there are any non-ignored changes, returns false. // Note that only unstaged tables are subject to dolt_ignore (this is consistent with what git does.) diff --git a/go/libraries/doltcore/merge/merge_prolly_rows.go b/go/libraries/doltcore/merge/merge_prolly_rows.go index e25db0abf41..125a5d2441a 100644 --- a/go/libraries/doltcore/merge/merge_prolly_rows.go +++ b/go/libraries/doltcore/merge/merge_prolly_rows.go @@ -33,6 +33,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/expranalysis" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/index" "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/pool" @@ -374,7 +375,7 @@ func newCheckValidator(ctx *sql.Context, tm *TableMerger, vm *valueMerger, sch s continue } - expr, err := index.ResolveCheckExpression(ctx, tm.name, sch, check.Expression()) + expr, err := expranalysis.ResolveCheckExpression(ctx, tm.name, sch, check.Expression()) if err != nil { return checkValidator{}, err } @@ -1192,7 +1193,7 @@ func resolveDefaults(ctx *sql.Context, tableName string, mergedSchema schema.Sch } if col.Default != "" || col.Generated != "" || col.OnUpdate != "" { - expr, err := index.ResolveDefaultExpression(ctx, tableName, mergedSchema, col) + expr, err := expranalysis.ResolveDefaultExpression(ctx, tableName, mergedSchema, col) if err != nil { return true, err } diff --git a/go/libraries/doltcore/sqle/dolt_patch_table_function.go b/go/libraries/doltcore/sqle/dolt_patch_table_function.go index e4949eec947..8f56f6d2725 100644 --- a/go/libraries/doltcore/sqle/dolt_patch_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_patch_table_function.go @@ -28,9 +28,7 @@ import ( "github.com/dolthub/vitess/go/mysql" "golang.org/x/exp/slices" - "github.com/dolthub/dolt/go/cmd/dolt/errhand" "github.com/dolthub/dolt/go/libraries/doltcore/diff" - "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/env" "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" @@ -514,7 +512,7 @@ func getPatchNodes(ctx *sql.Context, dbData env.DbData, tableDeltas []diff.Table // Get SCHEMA DIFF var schemaStmts []string if includeSchemaDiff { - schemaStmts, err = GenerateSqlPatchSchemaStatements(ctx, toRefDetails.root, td) + schemaStmts, err = sqlfmt.GenerateSqlPatchSchemaStatements(ctx, toRefDetails.root, td) if err != nil { return nil, err } @@ -535,41 +533,6 @@ func getPatchNodes(ctx *sql.Context, dbData env.DbData, tableDeltas []diff.Table return patches, nil } -// GenerateSqlPatchSchemaStatements examines the table schema changes in the specified TableDelta |td| and returns -// a slice of SQL path statements that represent the equivalent SQL DDL statements for those schema changes. The -// specified RootValue, |toRoot|, must be the RootValue that was used as the "To" root when computing the specified -// TableDelta. -func GenerateSqlPatchSchemaStatements(ctx *sql.Context, toRoot doltdb.RootValue, td diff.TableDelta) ([]string, error) { - toSchemas, err := doltdb.GetAllSchemas(ctx, toRoot) - if err != nil { - return nil, fmt.Errorf("could not read schemas from toRoot, cause: %s", err.Error()) - } - - fromSch, toSch, err := td.GetSchemas(ctx) - if err != nil { - return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error()) - } - - var ddlStatements []string - if td.IsDrop() { - ddlStatements = append(ddlStatements, sqlfmt.DropTableStmt(td.FromName)) - } else if td.IsAdd() { - stmt, err := sqlfmt.GenerateCreateTableStatement(td.ToName, td.ToSch, td.ToFks, td.ToFksParentSch) - if err != nil { - return nil, errhand.VerboseErrorFromError(err) - } - ddlStatements = append(ddlStatements, stmt) - } else { - stmts, err := GetNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) - if err != nil { - return nil, err - } - ddlStatements = append(ddlStatements, stmts...) - } - - return ddlStatements, nil -} - func canGetDataDiff(ctx *sql.Context, td diff.TableDelta) bool { if td.IsDrop() { return false // don't output DELETE FROM statements after DROP TABLE @@ -630,14 +593,14 @@ func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema var stmt string if oldRow.Row != nil { - stmt, err = diff.GetDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs) + stmt, err = sqlfmt.GenerateDataDiffStatement(tn, tsch, oldRow.Row, oldRow.RowDiff, oldRow.ColDiffs) if err != nil { return nil, err } } if newRow.Row != nil { - stmt, err = diff.GetDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs) + stmt, err = sqlfmt.GenerateDataDiffStatement(tn, tsch, newRow.Row, newRow.RowDiff, newRow.ColDiffs) if err != nil { return nil, err } @@ -649,96 +612,6 @@ func getDataSqlPatchResults(ctx *sql.Context, diffQuerySch, targetSch sql.Schema } } -// GetNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements. -func GetNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) { - if td.IsAdd() || td.IsDrop() { - // use add and drop specific methods - return nil, nil - } - - var ddlStatements []string - if td.FromName != td.ToName { - ddlStatements = append(ddlStatements, sqlfmt.RenameTableStmt(td.FromName, td.ToName)) - } - - eq := schema.SchemasAreEqual(fromSch, toSch) - if eq && !td.HasFKChanges() { - return ddlStatements, nil - } - - colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch) - for _, tag := range unionTags { - cd := colDiffs[tag] - switch cd.DiffType { - case diff.SchDiffNone: - case diff.SchDiffAdded: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddColStmt(td.ToName, sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) - case diff.SchDiffRemoved: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropColStmt(td.ToName, cd.Old.Name)) - case diff.SchDiffModified: - // Ignore any primary key set changes here - if cd.Old.IsPartOfPK != cd.New.IsPartOfPK { - continue - } - if cd.Old.Name != cd.New.Name { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) - } - if !cd.Old.TypeInfo.Equals(cd.New.TypeInfo) { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableModifyColStmt(td.ToName, - sqlfmt.GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) - } - } - } - - // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD - if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropPks(td.ToName)) - if toSch.GetPKCols().Size() > 0 { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames())) - } - } - - for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) { - switch idxDiff.DiffType { - case diff.SchDiffNone: - case diff.SchDiffAdded: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) - case diff.SchDiffRemoved: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) - case diff.SchDiffModified: - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropIndexStmt(td.FromName, idxDiff.From)) - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddIndexStmt(td.ToName, idxDiff.To)) - } - } - - for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) { - switch fkDiff.DiffType { - case diff.SchDiffNone: - case diff.SchDiffAdded: - parentSch := toSchemas[fkDiff.To.ReferencedTableName] - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) - case diff.SchDiffRemoved: - from := fkDiff.From - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) - case diff.SchDiffModified: - from := fkDiff.From - ddlStatements = append(ddlStatements, sqlfmt.AlterTableDropForeignKeyStmt(from.TableName, from.Name)) - - parentSch := toSchemas[fkDiff.To.ReferencedTableName] - ddlStatements = append(ddlStatements, sqlfmt.AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) - } - } - - // Handle charset/collation changes - toCollation := toSch.GetCollation() - fromCollation := fromSch.GetCollation() - if toCollation != fromCollation { - ddlStatements = append(ddlStatements, sqlfmt.AlterTableCollateStmt(td.ToName, fromCollation, toCollation)) - } - - return ddlStatements, nil -} - // getDiffQuery returns diff schema for specified columns and array of sql.Expression as projection to be used // on diff table function row iter. This function attempts to imitate running a query // fmt.Sprintf("select %s, %s from dolt_diff('%s', '%s', '%s')", columnsWithDiff, "diff_type", fromRef, toRef, tableName) diff --git a/go/libraries/doltcore/sqle/expranalysis/expranalysis.go b/go/libraries/doltcore/sqle/expranalysis/expranalysis.go new file mode 100644 index 00000000000..05ff6883e39 --- /dev/null +++ b/go/libraries/doltcore/sqle/expranalysis/expranalysis.go @@ -0,0 +1,109 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expranalysis + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/memory" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/planbuilder" + "github.com/dolthub/go-mysql-server/sql/transform" + + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" +) + +// ResolveDefaultExpression returns a sql.Expression for the column default or generated expression for the +// column provided +func ResolveDefaultExpression(ctx *sql.Context, tableName string, sch schema.Schema, col schema.Column) (sql.Expression, error) { + ct, err := parseCreateTable(ctx, tableName, sch) + if err != nil { + return nil, err + } + + colIdx := ct.PkSchema().Schema.IndexOfColName(col.Name) + if colIdx < 0 { + return nil, fmt.Errorf("unable to find column %s in analyzed query", col.Name) + } + + sqlCol := ct.PkSchema().Schema[colIdx] + expr := sqlCol.Default + if expr == nil || expr.Expr == nil { + expr = sqlCol.Generated + } + + if expr == nil || expr.Expr == nil { + return nil, fmt.Errorf("unable to find default or generated expression") + } + + return expr.Expr, nil +} + +// ResolveCheckExpression returns a sql.Expression for the check provided +func ResolveCheckExpression(ctx *sql.Context, tableName string, sch schema.Schema, checkExpr string) (sql.Expression, error) { + ct, err := parseCreateTable(ctx, tableName, sch) + if err != nil { + return nil, err + } + + for _, check := range ct.Checks() { + if stripTableNamesFromExpression(check.Expr).String() == checkExpr { + return check.Expr, nil + } + } + + return nil, fmt.Errorf("unable to find check expression") +} + +func stripTableNamesFromExpression(expr sql.Expression) sql.Expression { + e, _, _ := transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + if col, ok := e.(*expression.GetField); ok { + return col.WithTable(""), transform.NewTree, nil + } + return e, transform.SameTree, nil + }) + return e +} + +func parseCreateTable(ctx *sql.Context, tableName string, sch schema.Schema) (*plan.CreateTable, error) { + createTable, err := sqlfmt.GenerateCreateTableStatement(tableName, sch, nil, nil) + if err != nil { + return nil, err + } + + query := createTable + + mockDatabase := memory.NewDatabase("mydb") + mockProvider := memory.NewDBProvider(mockDatabase) + catalog := analyzer.NewCatalog(mockProvider) + parseCtx := sql.NewEmptyContext() + parseCtx.SetCurrentDatabase("mydb") + + b := planbuilder.New(parseCtx, catalog, sql.NewMysqlParser()) + pseudoAnalyzedQuery, _, _, err := b.Parse(query, false) + if err != nil { + return nil, err + } + + ct, ok := pseudoAnalyzedQuery.(*plan.CreateTable) + if !ok { + return nil, fmt.Errorf("expected a *plan.CreateTable node, but got %T", pseudoAnalyzedQuery) + } + return ct, nil +} diff --git a/go/libraries/doltcore/sqle/index/key_builder.go b/go/libraries/doltcore/sqle/index/key_builder.go index 70b9f149ae3..e42045b7c8c 100644 --- a/go/libraries/doltcore/sqle/index/key_builder.go +++ b/go/libraries/doltcore/sqle/index/key_builder.go @@ -16,18 +16,11 @@ package index import ( "context" - "fmt" - "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/analyzer" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/planbuilder" - "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/dolt/go/libraries/doltcore/schema" - "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/expranalysis" "github.com/dolthub/dolt/go/store/pool" "github.com/dolthub/dolt/go/store/prolly" "github.com/dolthub/dolt/go/store/prolly/tree" @@ -65,7 +58,7 @@ func NewSecondaryKeyBuilder(ctx *sql.Context, tableName string, sch schema.Schem virtualExpressions = make([]sql.Expression, len(def.AllTags())) } - expr, err := ResolveDefaultExpression(ctx, tableName, sch, col) + expr, err := expranalysis.ResolveDefaultExpression(ctx, tableName, sch, col) if err != nil { return SecondaryKeyBuilder{}, err } @@ -91,85 +84,6 @@ func NewSecondaryKeyBuilder(ctx *sql.Context, tableName string, sch schema.Schem return b, nil } -// ResolveDefaultExpression returns a sql.Expression for the column default or generated expression for the -// column provided -func ResolveDefaultExpression(ctx *sql.Context, tableName string, sch schema.Schema, col schema.Column) (sql.Expression, error) { - ct, err := parseCreateTable(ctx, tableName, sch) - if err != nil { - return nil, err - } - - colIdx := ct.PkSchema().Schema.IndexOfColName(col.Name) - if colIdx < 0 { - return nil, fmt.Errorf("unable to find column %s in analyzed query", col.Name) - } - - sqlCol := ct.PkSchema().Schema[colIdx] - expr := sqlCol.Default - if expr == nil || expr.Expr == nil { - expr = sqlCol.Generated - } - - if expr == nil || expr.Expr == nil { - return nil, fmt.Errorf("unable to find default or generated expression") - } - - return expr.Expr, nil -} - -// ResolveCheckExpression returns a sql.Expression for the check provided -func ResolveCheckExpression(ctx *sql.Context, tableName string, sch schema.Schema, checkExpr string) (sql.Expression, error) { - ct, err := parseCreateTable(ctx, tableName, sch) - if err != nil { - return nil, err - } - - for _, check := range ct.Checks() { - if stripTableNamesFromExpression(check.Expr).String() == checkExpr { - return check.Expr, nil - } - } - - return nil, fmt.Errorf("unable to find check expression") -} - -func stripTableNamesFromExpression(expr sql.Expression) sql.Expression { - e, _, _ := transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - if col, ok := e.(*expression.GetField); ok { - return col.WithTable(""), transform.NewTree, nil - } - return e, transform.SameTree, nil - }) - return e -} - -func parseCreateTable(ctx *sql.Context, tableName string, sch schema.Schema) (*plan.CreateTable, error) { - createTable, err := sqlfmt.GenerateCreateTableStatement(tableName, sch, nil, nil) - if err != nil { - return nil, err - } - - query := createTable - - mockDatabase := memory.NewDatabase("mydb") - mockProvider := memory.NewDBProvider(mockDatabase) - catalog := analyzer.NewCatalog(mockProvider) - parseCtx := sql.NewEmptyContext() - parseCtx.SetCurrentDatabase("mydb") - - b := planbuilder.New(parseCtx, catalog, sql.NewMysqlParser()) - pseudoAnalyzedQuery, _, _, err := b.Parse(query, false) - if err != nil { - return nil, err - } - - ct, ok := pseudoAnalyzedQuery.(*plan.CreateTable) - if !ok { - return nil, fmt.Errorf("expected a *plan.CreateTable node, but got %T", pseudoAnalyzedQuery) - } - return ct, nil -} - type SecondaryKeyBuilder struct { // sch holds the schema of the table on which the secondary index is created sch schema.Schema diff --git a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go index 42cac4624f4..d2c599ecb87 100644 --- a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go +++ b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go @@ -20,10 +20,168 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/dolt/go/cmd/dolt/errhand" + "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/utils/set" ) +// GenerateDataDiffStatement returns any data diff in SQL statements for given table including INSERT, UPDATE and DELETE row statements. +func GenerateDataDiffStatement(tableName string, sch schema.Schema, row sql.Row, rowDiffType diff.ChangeType, colDiffTypes []diff.ChangeType) (string, error) { + if len(row) != len(colDiffTypes) { + return "", fmt.Errorf("expected the same size for columns and diff types, got %d and %d", len(row), len(colDiffTypes)) + } + + switch rowDiffType { + case diff.Added: + return SqlRowAsInsertStmt(row, tableName, sch) + case diff.Removed: + return SqlRowAsDeleteStmt(row, tableName, sch, 0) + case diff.ModifiedNew: + updatedCols := set.NewEmptyStrSet() + for i, diffType := range colDiffTypes { + if diffType != diff.None { + updatedCols.Add(sch.GetAllCols().GetByIndex(i).Name) + } + } + if updatedCols.Size() == 0 { + return "", nil + } + return SqlRowAsUpdateStmt(row, tableName, sch, updatedCols) + case diff.ModifiedOld: + // do nothing, we only issue UPDATE for ModifiedNew + return "", nil + default: + return "", fmt.Errorf("unexpected row diff type: %v", rowDiffType) + } +} + +// GenerateSqlPatchSchemaStatements examines the table schema changes in the specified TableDelta |td| and returns +// a slice of SQL path statements that represent the equivalent SQL DDL statements for those schema changes. The +// specified RootValue, |toRoot|, must be the RootValue that was used as the "To" root when computing the specified +// TableDelta. +func GenerateSqlPatchSchemaStatements(ctx *sql.Context, toRoot doltdb.RootValue, td diff.TableDelta) ([]string, error) { + toSchemas, err := doltdb.GetAllSchemas(ctx, toRoot) + if err != nil { + return nil, fmt.Errorf("could not read schemas from toRoot, cause: %s", err.Error()) + } + + fromSch, toSch, err := td.GetSchemas(ctx) + if err != nil { + return nil, fmt.Errorf("cannot retrieve schema for table %s, cause: %s", td.ToName, err.Error()) + } + + var ddlStatements []string + if td.IsDrop() { + ddlStatements = append(ddlStatements, DropTableStmt(td.FromName)) + } else if td.IsAdd() { + stmt, err := GenerateCreateTableStatement(td.ToName, td.ToSch, td.ToFks, td.ToFksParentSch) + if err != nil { + return nil, errhand.VerboseErrorFromError(err) + } + ddlStatements = append(ddlStatements, stmt) + } else { + stmts, err := generateNonCreateNonDropTableSqlSchemaDiff(td, toSchemas, fromSch, toSch) + if err != nil { + return nil, err + } + ddlStatements = append(ddlStatements, stmts...) + } + + return ddlStatements, nil +} + +// generateNonCreateNonDropTableSqlSchemaDiff returns any schema diff in SQL statements that is NEITHER 'CREATE TABLE' NOR 'DROP TABLE' statements. +func generateNonCreateNonDropTableSqlSchemaDiff(td diff.TableDelta, toSchemas map[string]schema.Schema, fromSch, toSch schema.Schema) ([]string, error) { + if td.IsAdd() || td.IsDrop() { + // use add and drop specific methods + return nil, nil + } + + var ddlStatements []string + if td.FromName != td.ToName { + ddlStatements = append(ddlStatements, RenameTableStmt(td.FromName, td.ToName)) + } + + eq := schema.SchemasAreEqual(fromSch, toSch) + if eq && !td.HasFKChanges() { + return ddlStatements, nil + } + + colDiffs, unionTags := diff.DiffSchColumns(fromSch, toSch) + for _, tag := range unionTags { + cd := colDiffs[tag] + switch cd.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + ddlStatements = append(ddlStatements, AlterTableAddColStmt(td.ToName, GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) + case diff.SchDiffRemoved: + ddlStatements = append(ddlStatements, AlterTableDropColStmt(td.ToName, cd.Old.Name)) + case diff.SchDiffModified: + // Ignore any primary key set changes here + if cd.Old.IsPartOfPK != cd.New.IsPartOfPK { + continue + } + if cd.Old.Name != cd.New.Name { + ddlStatements = append(ddlStatements, AlterTableRenameColStmt(td.ToName, cd.Old.Name, cd.New.Name)) + } + if !cd.Old.TypeInfo.Equals(cd.New.TypeInfo) { + ddlStatements = append(ddlStatements, AlterTableModifyColStmt(td.ToName, + GenerateCreateTableColumnDefinition(*cd.New, sql.CollationID(td.ToSch.GetCollation())))) + } + } + } + + // Print changes between a primary key set change. It contains an ALTER TABLE DROP and an ALTER TABLE ADD + if !schema.ColCollsAreEqual(fromSch.GetPKCols(), toSch.GetPKCols()) { + ddlStatements = append(ddlStatements, AlterTableDropPks(td.ToName)) + if toSch.GetPKCols().Size() > 0 { + ddlStatements = append(ddlStatements, AlterTableAddPrimaryKeys(td.ToName, toSch.GetPKCols().GetColumnNames())) + } + } + + for _, idxDiff := range diff.DiffSchIndexes(fromSch, toSch) { + switch idxDiff.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + ddlStatements = append(ddlStatements, AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + case diff.SchDiffRemoved: + ddlStatements = append(ddlStatements, AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + case diff.SchDiffModified: + ddlStatements = append(ddlStatements, AlterTableDropIndexStmt(td.FromName, idxDiff.From)) + ddlStatements = append(ddlStatements, AlterTableAddIndexStmt(td.ToName, idxDiff.To)) + } + } + + for _, fkDiff := range diff.DiffForeignKeys(td.FromFks, td.ToFks) { + switch fkDiff.DiffType { + case diff.SchDiffNone: + case diff.SchDiffAdded: + parentSch := toSchemas[fkDiff.To.ReferencedTableName] + ddlStatements = append(ddlStatements, AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + case diff.SchDiffRemoved: + from := fkDiff.From + ddlStatements = append(ddlStatements, AlterTableDropForeignKeyStmt(from.TableName, from.Name)) + case diff.SchDiffModified: + from := fkDiff.From + ddlStatements = append(ddlStatements, AlterTableDropForeignKeyStmt(from.TableName, from.Name)) + + parentSch := toSchemas[fkDiff.To.ReferencedTableName] + ddlStatements = append(ddlStatements, AlterTableAddForeignKeyStmt(fkDiff.To, toSch, parentSch)) + } + } + + // Handle charset/collation changes + toCollation := toSch.GetCollation() + fromCollation := fromSch.GetCollation() + if toCollation != fromCollation { + ddlStatements = append(ddlStatements, AlterTableCollateStmt(td.ToName, fromCollation, toCollation)) + } + + return ddlStatements, nil +} + // GenerateCreateTableColumnDefinition returns column definition for CREATE TABLE statement with no indentation func GenerateCreateTableColumnDefinition(col schema.Column, tableCollation sql.CollationID) string { colStr := GenerateCreateTableIndentedColumnDefinition(col, tableCollation) diff --git a/go/libraries/doltcore/table/untyped/sqlexport/sql_diff_writer.go b/go/libraries/doltcore/table/untyped/sqlexport/sql_diff_writer.go index ebd33af7b10..760072670f2 100755 --- a/go/libraries/doltcore/table/untyped/sqlexport/sql_diff_writer.go +++ b/go/libraries/doltcore/table/untyped/sqlexport/sql_diff_writer.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/dolt/go/libraries/doltcore/diff" "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt" "github.com/dolthub/dolt/go/libraries/doltcore/table/editor" "github.com/dolthub/dolt/go/libraries/utils/iohelp" ) @@ -49,7 +50,7 @@ func NewSqlDiffWriter(tableName string, schema schema.Schema, wr io.WriteCloser) } func (w SqlDiffWriter) WriteRow(ctx context.Context, row sql.Row, rowDiffType diff.ChangeType, colDiffTypes []diff.ChangeType) error { - stmt, err := diff.GetDataDiffStatement(w.tableName, w.sch, row, rowDiffType, colDiffTypes) + stmt, err := sqlfmt.GenerateDataDiffStatement(w.tableName, w.sch, row, rowDiffType, colDiffTypes) if err != nil { return err }