diff --git a/inline_verifier.go b/inline_verifier.go
index 80d56bab..d3bfeb41 100644
--- a/inline_verifier.go
+++ b/inline_verifier.go
@@ -3,8 +3,10 @@ package ghostferry
 import (
     "bytes"
     "context"
+    "encoding/json"
     "errors"
     "fmt"
+    "reflect"
     "strconv"
     "strings"
     "sync"
@@ -12,6 +14,7 @@ import (
 
     sql "github.com/Shopify/ghostferry/sqlwrapper"
 
+    "github.com/go-mysql-org/go-mysql/schema"
     "github.com/golang/snappy"
     "github.com/sirupsen/logrus"
 )
@@ -696,7 +699,7 @@ func (v *InlineVerifier) verifyAllEventsInStore() (bool, map[string]map[string][
     v.logger.WithField("batches", len(allBatches)).Debug("verifyAllEventsInStore")
 
     for _, batch := range allBatches {
-        batchMismatches, err := v.verifyBinlogBatch(batch)
+        batchMismatches, err := v.verifyBinlogBatch(batch, false)
         if err != nil {
             return false, nil, err
         }
@@ -724,7 +727,7 @@ func (v *InlineVerifier) verifyAllEventsInStore() (bool, map[string]map[string][
 // Since the mismatches gets re-added to the reverify store, this must return
 // a union of mismatches of fingerprints and mismatches due to decompressed
 // data.
-func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVerifierMismatches, error) {
+func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch, skipJsonColumnCheck bool) ([]InlineVerifierMismatches, error) {
     targetSchema := batch.SchemaName
     if targetSchemaName, exists := v.DatabaseRewrites[targetSchema]; exists {
         targetSchema = targetSchemaName
@@ -783,5 +786,191 @@ func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch) ([]InlineVer
         return nil, targetErr
     }
 
-    return v.compareHashesAndData(sourceFingerprints, targetFingerprints, sourceDecompressedData, targetDecompressedData), nil
+    mismatches := v.compareHashesAndData(sourceFingerprints, targetFingerprints, sourceDecompressedData, targetDecompressedData)
+
+    if len(mismatches) > 0 && !skipJsonColumnCheck {
+        return v.compareJsonColumnValues(batch, mismatches, sourceTableSchema, targetSchema, targetTable)
+    }
+
+    return mismatches, nil
+}
+
+func (v *InlineVerifier) compareJsonColumnValues(batch BinlogVerifyBatch, mismatches []InlineVerifierMismatches, sourceTableSchema *TableSchema, targetSchema string, targetTable string) ([]InlineVerifierMismatches, error) {
+    jsonColumnNames := []string{}
+    for _, column := range sourceTableSchema.Columns {
+        if column.Type == schema.TYPE_JSON {
+            jsonColumnNames = append(jsonColumnNames, column.Name)
+        }
+    }
+
+    if len(jsonColumnNames) == 0 {
+        return mismatches, nil
+    }
+
+    addJsonColumnNamesToIgnoredColumnsForVerification(sourceTableSchema, jsonColumnNames)
+
+    mismatches, err := v.verifyBinlogBatch(batch, true)
+    if err != nil {
+        return nil, err
+    }
+
+    if len(mismatches) > 0 {
+        return mismatches, nil
+    }
+
+    args := make([]interface{}, len(batch.PaginationKeys))
+    for i, paginationKey := range batch.PaginationKeys {
+        args[i] = paginationKey
+    }
+
+    sourceQuery := jsonColumnValueQuery(sourceTableSchema, batch.SchemaName, batch.TableName, jsonColumnNames, len(batch.PaginationKeys))
+    targetQuery := jsonColumnValueQuery(sourceTableSchema, targetSchema, targetTable, jsonColumnNames, len(batch.PaginationKeys))
+
+    sourceStatement, err := v.sourceStmtCache.StmtFor(v.SourceDB, sourceQuery)
+    if err != nil {
+        return nil, err
+    }
+
+    targetStatement, err := v.targetStmtCache.StmtFor(v.TargetDB, targetQuery)
+    if err != nil {
+        return nil, err
+    }
+
+    sourceRows, err := sourceStatement.Query(args...)
+    if err != nil {
+        return nil, err
+    }
+
+    targetRows, err := targetStatement.Query(args...)
+    if err != nil {
+        sourceRows.Close()
+        return nil, err
+    }
+
+    defer sourceRows.Close()
+    defer targetRows.Close()
+
+    mismatchedJsonColumns := []string{}
+    paginationKeysWithMismatchedJson := []uint64{}
+
+    for {
+        hasSourceRows := sourceRows.Next()
+        hasTargetRows := targetRows.Next()
+
+        if !hasSourceRows && !hasTargetRows {
+            break
+        }
+
+        if (hasSourceRows && !hasTargetRows) || (!hasSourceRows && hasTargetRows) {
+            return nil, fmt.Errorf("number of source and target rows are different")
+        }
+
+        sourceRowData, err := ScanByteRow(sourceRows, len(jsonColumnNames)+1)
+        if err != nil {
+            return nil, err
+        }
+
+        targetRowData, err := ScanByteRow(targetRows, len(jsonColumnNames)+1)
+        if err != nil {
+            return nil, err
+        }
+
+        sourcePaginationKey, _ := strconv.ParseUint(string(sourceRowData[0]), 10, 64)
+        targetPaginationKey, _ := strconv.ParseUint(string(targetRowData[0]), 10, 64)
+
+        for j, jsonColumn := range jsonColumnNames {
+            // Decode into fresh maps on every iteration so keys from a previously
+            // compared column cannot leak into this comparison.
+            var sourceJsonColumnValue map[string]interface{}
+            var targetJsonColumnValue map[string]interface{}
+
+            err := json.Unmarshal([]byte(sourceRowData[j+1]), &sourceJsonColumnValue)
+            if err != nil {
+                return nil, fmt.Errorf("unmarshalling source row data: %w", err)
+            }
+
+            err = json.Unmarshal([]byte(targetRowData[j+1]), &targetJsonColumnValue)
+            if err != nil {
+                return nil, fmt.Errorf("unmarshalling target row data: %w", err)
+            }
+
+            if sourcePaginationKey == targetPaginationKey && reflect.DeepEqual(sourceJsonColumnValue, targetJsonColumnValue) {
+                continue
+            }
+
+            if !uint64SliceContains(paginationKeysWithMismatchedJson, sourcePaginationKey) {
+                paginationKeysWithMismatchedJson = append(paginationKeysWithMismatchedJson, sourcePaginationKey)
+            }
+
+            if !stringSliceContains(mismatchedJsonColumns, jsonColumn) {
+                mismatchedJsonColumns = append(mismatchedJsonColumns, jsonColumn)
+            }
+        }
+    }
+
+    if len(mismatchedJsonColumns) == 0 {
+        return mismatches, nil
+    }
+
+    removeJsonColumnsFromIgnoredColumnsForVerification(sourceTableSchema, mismatchedJsonColumns)
+
+    mismatched, err := v.verifyBinlogBatch(batch, true)
+    if err != nil {
+        return nil, err
+    }
+
+    filteredMismatches := []InlineVerifierMismatches{}
+
+    // Keep only the mismatches whose pagination keys still have differing JSON values;
+    // rows whose JSON columns compared as semantically equal are filtered out.
+    for _, mismatch := range mismatched {
+        for _, mismatchedJsonPK := range paginationKeysWithMismatchedJson {
+            if mismatch.Pk == mismatchedJsonPK {
+                filteredMismatches = append(filteredMismatches, mismatch)
+            }
+        }
+    }
+
+    return filteredMismatches, nil
+}
+
+func jsonColumnValueQuery(sourceTableSchema *TableSchema, schemaName string, tableName string, jsonColumnNames []string, paginationKeysCount int) string {
+    paginationColumn := QuoteField(sourceTableSchema.GetPaginationColumn().Name)
+
+    return fmt.Sprintf(
+        "SELECT %s FROM %s WHERE %s IN (%s) ORDER BY %s",
+        fmt.Sprintf("%s, %s", paginationColumn, strings.Join(jsonColumnNames, ",")),
+        QuotedTableNameFromString(schemaName, tableName),
+        paginationColumn,
+        strings.Repeat("?,", paginationKeysCount-1)+"?",
+        paginationColumn,
+    )
+}
+
+func addJsonColumnNamesToIgnoredColumnsForVerification(sourceTableSchema *TableSchema, jsonColumnNames []string) {
+    for _, jsonColumn := range jsonColumnNames {
+        if sourceTableSchema.IgnoredColumnsForVerification == nil {
+            sourceTableSchema.IgnoredColumnsForVerification = make(map[string]struct{})
+        }
+
+        sourceTableSchema.IgnoredColumnsForVerification[jsonColumn] = struct{}{}
+    }
+
+    sourceTableSchema.rowMd5Query = ""
+    sourceTableSchema.RowMd5Query()
+}
+
+func removeJsonColumnsFromIgnoredColumnsForVerification(sourceTableSchema *TableSchema, jsonColumnNames []string) {
+    for _, jsonColumn := range jsonColumnNames {
+        delete(sourceTableSchema.IgnoredColumnsForVerification, jsonColumn)
+    }
+
+    sourceTableSchema.rowMd5Query = ""
+    sourceTableSchema.RowMd5Query()
+}
+
+func uint64SliceContains(s []uint64, item uint64) bool {
+    for _, v := range s {
+        if item == v {
+            return true
+        }
+    }
+
+    return false
+}
+
+func stringSliceContains(s []string, item string) bool {
+    for _, v := range s {
+        if item == v {
+            return true
+        }
+    }
+
+    return false
 }
diff --git a/test/integration/inline_verifier_test.rb b/test/integration/inline_verifier_test.rb
index 26c3a34d..bb7549bd 100644
--- a/test/integration/inline_verifier_test.rb
+++ b/test/integration/inline_verifier_test.rb
@@ -141,6 +141,60 @@ def test_different_data_in_ignored_column_passes_inline_verification
     end
   end
 
+  def test_json_data_with_float_numbers_verification_fail
+    [source_db, target_db].each do |db|
+      db.query("CREATE DATABASE IF NOT EXISTS #{DEFAULT_DB}")
+      db.query("CREATE TABLE IF NOT EXISTS #{DEFAULT_FULL_TABLE_NAME} (id bigint(20) not null auto_increment, data JSON, data2 JSON, primary key(id))")
+    end
+
+    enable_corrupting_insert_trigger(2, '{\"data\": {\"float\": 100}}')
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
+
+    ghostferry.on_status(Ghostferry::Status::BINLOG_STREAMING_STARTED) do
+      source_db.query("INSERT INTO #{DEFAULT_FULL_TABLE_NAME} (data, data2) VALUES ('{\"data\": {\"float\": 32.0}}', '{\"data\": {\"float\": 42.0}}')")
+      source_db.query("INSERT INTO #{DEFAULT_FULL_TABLE_NAME} (data, data2) VALUES ('{\"data\": {\"float\": 25.0}}', '{\"data\": {\"float\": 35.0}}')")
+    end
+
+    verification_ran = false
+    ghostferry.on_status(Ghostferry::Status::VERIFIED) do |*incorrect_tables|
+      verification_ran = true
+      assert_equal ["gftest.test_table_1"], incorrect_tables
+    end
+
+    ghostferry.run
+    assert verification_ran
+
+    assert ghostferry.error_lines.last["msg"].start_with?("cutover verification failed for: gftest.test_table_1 [paginationKeys: 2")
+  end
+
+  def test_json_data_with_float_numbers_verification
+    [source_db, target_db].each do |db|
+      db.query("CREATE DATABASE IF NOT EXISTS #{DEFAULT_DB}")
+      db.query("CREATE TABLE IF NOT EXISTS #{DEFAULT_FULL_TABLE_NAME} (id bigint(20) not null auto_increment, data JSON, data2 JSON, primary key(id))")
+    end
+
+    ghostferry = new_ghostferry(MINIMAL_GHOSTFERRY, config: { verifier_type: "Inline" })
+
+    ghostferry.on_status(Ghostferry::Status::BINLOG_STREAMING_STARTED) do
+      source_db.query("INSERT INTO #{DEFAULT_FULL_TABLE_NAME} (data, data2) VALUES ('{\"data\": {\"float\": 32.0}}', '{\"data\": {\"float\": 42.0}}')")
+      source_db.query("INSERT INTO #{DEFAULT_FULL_TABLE_NAME} (data, data2) VALUES ('{\"data\": {\"float\": 25.0}}', '{\"data\": {\"float\": 35.0}}')")
+    end
+
+    verification_ran = false
+    incorrect_tables = []
+    ghostferry.on_status(Ghostferry::Status::VERIFIED) do |*tables|
+      verification_ran = true
+      incorrect_tables = tables
+    end
+
+    ghostferry.run
+
+    assert_nil ghostferry.error
+    assert verification_ran
+    assert_equal [], incorrect_tables
+  end
+
   def test_catches_binlog_streamer_corruption
     seed_random_data(source_db, number_of_rows: 1)
     seed_random_data(target_db, number_of_rows: 0)
@@ -613,12 +667,12 @@ def set_data_column_collation(db, charset)
     db.query("ALTER TABLE #{DEFAULT_FULL_TABLE_NAME} MODIFY data VARCHAR(255) CHARACTER SET #{charset} COLLATE #{CHARSET_TO_COLLATION[charset]}")
   end
 
-  def enable_corrupting_insert_trigger(corrupting_id)
+  def enable_corrupting_insert_trigger(corrupting_id, new_data = "corrupted")
     query = [
       "CREATE TRIGGER #{INSERT_TRIGGER_NAME} BEFORE INSERT ON #{DEFAULT_TABLE}",
       "FOR EACH ROW BEGIN",
       "IF NEW.id = #{corrupting_id} THEN",
-      "SET NEW.data = 'corrupted';",
+      "SET NEW.data = '#{new_data}';",
       "END IF;",
       "END",
     ].join("\n")
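
Note on the approach (not part of the patch above): the same JSON value can end up with a different textual representation on the source and on the target (the new integration tests use floats such as 32.0 for exactly this reason), so the row fingerprints can disagree even though the data is equivalent. compareJsonColumnValues therefore decodes the JSON columns on both sides and compares the decoded values. A minimal, self-contained Go sketch of that comparison idea, not code from inline_verifier.go:

// Standalone sketch: two byte-wise different JSON encodings of the same value
// compare as equal once decoded, which is why the verifier re-checks JSON
// columns with reflect.DeepEqual instead of trusting the fingerprints alone.
package main

import (
    "encoding/json"
    "fmt"
    "reflect"
)

func main() {
    sourceRaw := []byte(`{"data": {"float": 32.0}}`)
    targetRaw := []byte(`{"data": {"float": 32}}`)

    var sourceVal, targetVal map[string]interface{}
    if err := json.Unmarshal(sourceRaw, &sourceVal); err != nil {
        panic(err)
    }
    if err := json.Unmarshal(targetRaw, &targetVal); err != nil {
        panic(err)
    }

    // Both numbers decode to float64(32), so the documents compare as equal
    // even though their raw bytes differ.
    fmt.Println(reflect.DeepEqual(sourceVal, targetVal)) // true
}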
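Purely as an illustration of the statement shape jsonColumnValueQuery builds (the table, column, and quoting below are hard-coded stand-ins for what QuoteField and QuotedTableNameFromString would produce; they are not emitted by the patch verbatim), assuming a table gftest.test_table_1 with pagination column id, JSON columns data and data2, and a batch of three pagination keys:

// Standalone sketch: mirrors the fmt.Sprintf/strings.Repeat construction used by
// jsonColumnValueQuery to build the IN-clause placeholders and row ordering.
package main

import (
    "fmt"
    "strings"
)

func main() {
    jsonColumns := []string{"data", "data2"}
    paginationKeysCount := 3

    query := fmt.Sprintf(
        "SELECT %s FROM %s WHERE %s IN (%s) ORDER BY %s",
        "`id`, "+strings.Join(jsonColumns, ","),
        "`gftest`.`test_table_1`",
        "`id`",
        strings.Repeat("?,", paginationKeysCount-1)+"?",
        "`id`",
    )

    fmt.Println(query)
    // SELECT `id`, data,data2 FROM `gftest`.`test_table_1` WHERE `id` IN (?,?,?) ORDER BY `id`
}

Ordering both result sets by the pagination column is what lets compareJsonColumnValues walk the source and target rows in lockstep with a single pass.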