Skip to content

Commit

Permalink
Miscellaneous changes in export schema transformations (#2378)
Browse files Browse the repository at this point in the history
* Added more unit test for testing pg parser Deparsing API

* Miscellaneous changes
- fixed the console output indicating whether recommendations were applied or not.
- regression: added --skip-recommendations flag back
  • Loading branch information
sanyamsinghal authored Mar 6, 2025
1 parent 941eec4 commit 7d21b3e
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 30 deletions.
100 changes: 77 additions & 23 deletions yb-voyager/cmd/exportSchema.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ import (

var skipRecommendations utils.BoolStr
var assessmentReportPath string
var assessmentRecommendationsApplied bool
var assessmentRecommendationsApplied = false
var errorApplyingAssessmentRecommendations = false

var exportSchemaCmd = &cobra.Command{
Use: "schema",
Expand Down Expand Up @@ -311,10 +312,11 @@ func applySchemaTransformations() {
// 1. Transform table.sql
{
tableFilePath := utils.GetObjectFilePath(schemaDir, TABLE)
transformations := []func([]*pg_query.RawStmt) ([]*pg_query.RawStmt, error){
applyShardedTableTransformation, // transform #1
applyMergeConstraintsTransformation, // transform #2
var transformations []func([]*pg_query.RawStmt, string) ([]*pg_query.RawStmt, error)
if !skipRecommendations {
transformations = append(transformations, applyShardedTableTransformation) // transform #1
}
transformations = append(transformations, applyMergeConstraintsTransformation) // transform #2

err := transformSchemaFile(tableFilePath, transformations, "table")
if err != nil {
Expand All @@ -325,49 +327,77 @@ func applySchemaTransformations() {
// 2. Transform mview.sql
{
mviewFilePath := utils.GetObjectFilePath(schemaDir, MVIEW)
transformations := []func([]*pg_query.RawStmt) ([]*pg_query.RawStmt, error){
applyShardedTableTransformation, // only transformation for mview
var transformations []func([]*pg_query.RawStmt, string) ([]*pg_query.RawStmt, error)
if !skipRecommendations {
transformations = append(transformations, applyShardedTableTransformation) // only transformation for mview
}

err := transformSchemaFile(mviewFilePath, transformations, "mview")
if err != nil {
log.Warnf("Error transforming %q: %v", mviewFilePath, err)
}
}

// Check the flag to message the user about the recommendations applied and ask to apply manually
if errorApplyingAssessmentRecommendations {
utils.PrintAndLog("\nUnable to apply assessment recommendations(sharded/colocated tables) to the exported schema. Please check the logs for more details.")
utils.PrintAndLog("You can apply the recommendations manually by referring to the assessment report.")
} else if assessmentRecommendationsApplied {
SetAssessmentRecommendationsApplied()
}
	// else case: neither applied nor errored — the schema file was simply not present.

	// There is a corner case: when recommendations are applied on table.sql but not on mview.sql, or vice versa.
	// In this case, there is no definite answer whether assessmentRecommendationsApplied should be true or false; assuming false.
}

// transformSchemaFile applies a sequence of transformations to the given schema file
// and writes the transformed result back. If the file doesn't exist, logs a message and returns nil.
func transformSchemaFile(filePath string, transformations []func(raw []*pg_query.RawStmt) ([]*pg_query.RawStmt, error), objectType string) error {
if !utils.FileOrFolderExists(filePath) {
log.Infof("%q file doesn't exist, skipping transformations for %s object type", filePath, objectType)
func transformSchemaFile(filePath string, transformations []func(raw []*pg_query.RawStmt, filePath string) ([]*pg_query.RawStmt, error), objectType string) error {
if !utils.FileOrFolderExists(filePath) || len(transformations) == 0 {
log.Infof("schema file %q for object type %s doesn't exist or no transformations to apply", filePath, objectType)
return nil
}

rawStmts, err := queryparser.ParseSqlFile(filePath)
var rawStmts []*pg_query.RawStmt
var err error
defer func() {
if err != nil {
errorApplyingAssessmentRecommendations = true
utils.PrintAndLog("Failed to apply any transformation to the exported schema file %q: %v\n", filePath, err)
}
}()

log.Infof("applying transformations to the schema file %q for object type %s", filePath, objectType)
rawStmts, err = queryparser.ParseSqlFile(filePath)
if err != nil {
return fmt.Errorf("failed to parse sql statements from %s object type in schema file %q: %w", objectType, filePath, err)
}

beforeSqlStmts, err := queryparser.DeparseRawStmts(rawStmts)
var beforeSqlStmts []string
beforeSqlStmts, err = queryparser.DeparseRawStmts(rawStmts)
if err != nil {
return fmt.Errorf("failed to deparse raw stmts for %s object type in schema file %q: %w", objectType, filePath, err)
}

transformedStmts := rawStmts
// Apply transformations in order
for _, transformFn := range transformations {
newStmts, err := transformFn(transformedStmts)
if err != nil {
transformFuncName := utils.GetFuncName(transformFn)
log.Infof("applying transformation: %s on %s", filepath.Base(transformFuncName), filePath)

newStmts, err2 := transformFn(transformedStmts, filePath)
if err2 != nil {
// Log and continue using the unmodified statements slice for subsequent transformations in case of error
log.Warnf("failed to apply transformation function %T in schema file %q: %v", transformFn, filePath, err)
log.Warnf("failed to apply transformation %s on the exported schema file %q: %v",
filepath.Base(transformFuncName), filePath, err)
continue
}
transformedStmts = newStmts
}

// Deparse
sqlStmts, err := queryparser.DeparseRawStmts(transformedStmts)
var sqlStmts []string
sqlStmts, err = queryparser.DeparseRawStmts(transformedStmts)
if err != nil {
return fmt.Errorf("failed to deparse transformed raw stmts for %s object type in schema file %q: %w", objectType, filePath, err)
}
Expand All @@ -389,6 +419,7 @@ func transformSchemaFile(filePath string, transformations []func(raw []*pg_query
if err != nil {
return fmt.Errorf("failed to rename %s file to %s: %w", filePath, backupFile, err)
}
utils.PrintAndLog("The original DDLs(without transformation) for %q object type are backed up at %s\n", strings.ToUpper(objectType), backupFile)

// Write updated file
fileContent := strings.Join(sqlStmts, "\n\n")
Expand All @@ -400,16 +431,38 @@ func transformSchemaFile(filePath string, transformations []func(raw []*pg_query
return nil
}

func applyShardedTableTransformation(stmts []*pg_query.RawStmt) ([]*pg_query.RawStmt, error) {
log.Info("applying sharded tables transformation to the exported schema")
func applyShardedTableTransformation(stmts []*pg_query.RawStmt, filePath string) ([]*pg_query.RawStmt, error) {
log.Infof("applying sharded tables transformation to the exported schema file %q", filePath)
if bool(skipRecommendations) || !slices.Contains(assessMigrationSupportedDBTypes, source.DBType) {
log.Info("skipping applying sharded tables transformation due to --skip-recommendations flag or assessment unsupported source db type")
return stmts, nil
}

var transformedRawStmts []*pg_query.RawStmt
var err error
// defer func to inspect err and set global flag for recommendations application
defer func() {
if err != nil {
errorApplyingAssessmentRecommendations = true
assessmentRecommendationsApplied = false
utils.PrintAndLog("Failed to apply assessment recommendations to the exported schema file %q: %v\n", filepath.Base(filePath), err)
} else {
utils.PrintAndLog("Applied assessment recommendations to %s schema\n", filepath.Base(filePath))
assessmentRecommendationsApplied = true
}
}()

assessmentReportPath = lo.Ternary(assessmentReportPath != "", assessmentReportPath,
filepath.Join(exportDir, "assessment", "reports", fmt.Sprintf("%s.json", ASSESSMENT_FILE_NAME)))
assessmentReport, err := ParseJSONToAssessmentReport(assessmentReportPath)

var assessmentReport *AssessmentReport
assessmentReport, err = ParseJSONToAssessmentReport(assessmentReportPath)
if err != nil {
return stmts, fmt.Errorf("failed to parse json report file %q: %w", assessmentReportPath, err)
}

shardedObjects, err := assessmentReport.GetShardedTablesRecommendation()
var shardedObjects []string
shardedObjects, err = assessmentReport.GetShardedTablesRecommendation()
if err != nil {
return stmts, fmt.Errorf("failed to fetch sharded tables recommendation: %w", err)
}
Expand All @@ -433,21 +486,22 @@ func applyShardedTableTransformation(stmts []*pg_query.RawStmt) ([]*pg_query.Raw
}

transformer := sqltransformer.NewTransformer()
transformedRawStmts, err := transformer.ConvertToShardedTables(stmts, isObjectSharded)
transformedRawStmts, err = transformer.ConvertToShardedTables(stmts, isObjectSharded)
if err != nil {
return stmts, fmt.Errorf("failed to convert to sharded tables: %w", err)
}

assessmentRecommendationsApplied = true
return transformedRawStmts, nil
}

func applyMergeConstraintsTransformation(rawStmts []*pg_query.RawStmt) ([]*pg_query.RawStmt, error) {
func applyMergeConstraintsTransformation(rawStmts []*pg_query.RawStmt, filePath string) ([]*pg_query.RawStmt, error) {
if utils.GetEnvAsBool("YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS", false) {
log.Infof("skipping applying merge constraints transformation due to env var YB_VOYAGER_SKIP_MERGE_CONSTRAINTS_TRANSFORMATIONS=true")
return rawStmts, nil
}

log.Info("applying merge constraints transformation to the exported schema")
log.Infof("applying merge constraints transformation to the exported schema file %q", filePath)
transformer := sqltransformer.NewTransformer()
transformedRawStmts, err := transformer.MergeConstraints(rawStmts)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion yb-voyager/src/query/queryparser/query_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func ParsePLPGSQLToJson(query string) (string, error) {
}

func ParseSqlFile(filePath string) ([]*pg_query.RawStmt, error) {
log.Debugf("parsing the file [%s]", filePath)
log.Infof("parsing the file %q", filePath)
bytes, err := os.ReadFile(filePath)
if err != nil {
return nil, fmt.Errorf("reading file failed: %v", err)
Expand Down
2 changes: 2 additions & 0 deletions yb-voyager/src/query/sqltransformer/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func (t *Transformer) MergeConstraints(stmts []*pg_query.RawStmt) ([]*pg_query.R
Otherwise, add it to the result slice
*/
alterTableCmdType := alterTableCmd.GetSubtype()
log.Infof("alterTableCmdType: %v", *alterTableCmdType.Enum())
if *alterTableCmdType.Enum() != pg_query.AlterTableType_AT_AddConstraint {
// If the ALTER TABLE stmt is not an ADD CONSTRAINT stmt, then need to append it to the result slice
result = append(result, stmt)
Expand All @@ -146,6 +147,7 @@ func (t *Transformer) MergeConstraints(stmts []*pg_query.RawStmt) ([]*pg_query.R
if !ok {
return nil, fmt.Errorf("CREATE TABLE stmt not found for table %v", objectName)
}
log.Infof("merging constraint %v into CREATE TABLE for object %v", constrType, objectName)
createStmt.Stmt.GetCreateStmt().TableElts = append(createStmt.Stmt.GetCreateStmt().TableElts, alterTableCmd.GetDef())
}

Expand Down
46 changes: 42 additions & 4 deletions yb-voyager/src/query/sqltransformer/transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ limitations under the License.
package sqltransformer

import (
"os"
"testing"

log "github.com/sirupsen/logrus"
"github.com/yugabyte/yb-voyager/yb-voyager/src/query/queryparser"
testutils "github.com/yugabyte/yb-voyager/yb-voyager/test/utils"
)
Expand All @@ -37,6 +39,14 @@ import (
9. [Extra] Exclude constraint (omission of USING btree by parser)
*/

// TestMain configures package-wide test settings before running the suite.
func TestMain(m *testing.M) {
	// Silence info/debug output from the code under test; warnings and
	// above are still printed.
	log.SetLevel(log.WarnLevel)

	os.Exit(m.Run())
}

func TestMergeConstraints_Basic(t *testing.T) {
sqlFileContent := `
CREATE TABLE test_table1 (
Expand All @@ -51,14 +61,18 @@ func TestMergeConstraints_Basic(t *testing.T) {
);
ALTER TABLE test_table1 ADD CONSTRAINT test_table_pk PRIMARY KEY (id);
ALTER TABLE test_table2 ADD CONSTRAINT test_table_fk FOREIGN KEY (id) REFERENCES test_table1 (id);
ALTER TABLE test_table2 ADD CONSTRAINT test_table_uk UNIQUE (email);
-- Skip NOT VALID merging constraint
ALTER TABLE test_table1 ADD CONSTRAINT check_name CHECK (name <> '') NOT VALID;
ALTER TABLE test_table2 ADD CONSTRAINT test_table2_fk FOREIGN KEY (id) REFERENCES test_table1 (id);
ALTER TABLE test_table2 ADD CONSTRAINT test_table2_uk UNIQUE (email);
`

expectedSqls := []string{
`CREATE TABLE test_table1 (id int, name varchar(255), CONSTRAINT test_table_pk PRIMARY KEY (id));`,
`CREATE TABLE test_table2 (id int, name varchar(255), email varchar(255), CONSTRAINT test_table_uk UNIQUE (email));`,
`ALTER TABLE test_table2 ADD CONSTRAINT test_table_fk FOREIGN KEY (id) REFERENCES test_table1 (id);`,
`ALTER TABLE test_table1 ADD CONSTRAINT check_name CHECK (name <> '') NOT VALID;`,
`CREATE TABLE test_table2 (id int, name varchar(255), email varchar(255), CONSTRAINT test_table2_uk UNIQUE (email));`,
`ALTER TABLE test_table2 ADD CONSTRAINT test_table2_fk FOREIGN KEY (id) REFERENCES test_table1 (id);`,
}

tempFilePath, err := testutils.CreateTempFile("/tmp", sqlFileContent, "sql")
Expand Down Expand Up @@ -460,3 +474,27 @@ func Test_RemovalOfDefaultValuesByParser(t *testing.T) {

testutils.AssertEqualStringSlices(t, expectedSqls, finalSqlStmts)
}

// Test_DeparsingAPI pins down statements where the deparse() API output
// deviates from the expected SQL, plus corner cases worth locking in.
func Test_DeparsingAPI(t *testing.T) {
	inputSQL := `
CREATE TABLE my_table (created_at TIMESTAMPTZ NOT NULL DEFAULT (CURRENT_TIMESTAMP AT TIME ZONE 'UTC'));
`

	wantSqls := []string{
		// expected: CREATE TABLE my_table (created_at timestamptz NOT NULL DEFAULT current_timestamp AT TIME ZONE 'UTC');
		// but the value below is what the parser actually returns, due to a parser bug: https://github.com/pganalyze/pg_query_go/issues/126
		`CREATE TABLE my_table (created_at timestamptz NOT NULL DEFAULT current_timestamp AT TIME ZONE 'UTC');`,
	}

	sqlFile, err := testutils.CreateTempFile("/tmp", inputSQL, "sql")
	testutils.FatalIfError(t, err)

	parsedStmts, err := queryparser.ParseSqlFile(sqlFile)
	testutils.FatalIfError(t, err)

	deparsedStmts, err := queryparser.DeparseRawStmts(parsedStmts)
	testutils.FatalIfError(t, err)

	testutils.AssertEqualStringSlices(t, wantSqls, deparsedStmts)
}
1 change: 1 addition & 0 deletions yb-voyager/src/srcdb/ora2pg_extract_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ func ora2pgExtractSchema(source *Source, exportDir string, schemaDir string) {
}
}
}
fmt.Println()
if source.DBType == "oracle" {
if err := ora2pgAssessmentReport(source, configFilePath, schemaDir); err != nil {
utils.ErrExit("failed to save ora2pg oracle assessment report during export schema: %v", err.Error())
Expand Down
4 changes: 2 additions & 2 deletions yb-voyager/src/srcdb/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ func (pg *PostgreSQL) Connect() error {
err := pg.db.Ping()
if err == nil {
log.Infof("Already connected to the source database")
log.Infof("Already connected to the source database")
return nil
return nil
} else {
log.Infof("Failed to ping the source database: %s", err)
pg.Disconnect()
Expand Down Expand Up @@ -366,6 +365,7 @@ func (pg *PostgreSQL) ExportSchema(exportDir string, schemaDir string) {
log.Info("Export of schema completed.")
utils.WaitChannel <- returnCode
<-utils.WaitChannel
fmt.Println()
}
}

Expand Down
7 changes: 7 additions & 0 deletions yb-voyager/src/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
"os"
"os/exec"
"path/filepath"
"reflect"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -864,3 +866,8 @@ func ObfuscateFormatDetails(format, final, obfuscateWith string) (string, error)
sb.WriteString(final[lastIndex:]) // Append the text after the last group.
return sb.String(), nil
}

// given a function return the function name
func GetFuncName(fn interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
}

0 comments on commit 7d21b3e

Please sign in to comment.