diff --git a/.github/workflows/flow.yml b/.github/workflows/flow.yml index 4bc3ef44d8..2adf8c619d 100644 --- a/.github/workflows/flow.yml +++ b/.github/workflows/flow.yml @@ -57,7 +57,7 @@ jobs: - name: run tests run: | - gotestsum --format testname -- -p 1 ./... + gotestsum --format testname -- -p 1 ./... -timeout 1200s working-directory: ./flow env: AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} diff --git a/README.md b/README.md index 9a00b594c7..135a4f8507 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@
-img-verification -

PeerDB

-

Modern ETL in minutes, with SQL

- - - + +PeerDB Banner + +#### Modern ETL in minutes, with SQL + +[![Workflow Status](https://github.com/PeerDB-io/peerdb/actions/workflows/ci.yml/badge.svg)](https://github.com/PeerDB-io/peerdb/actions/workflows/ci.yml) [![ELv2 License](https://badgen.net/badge/License/ELv2/green?icon=github)](https://github.com/PeerDB-io/peerdb/blob/main/LICENSE.md) [![Slack Community](https://img.shields.io/badge/slack-peerdb-brightgreen.svg?logo=slack)](https://join.slack.com/t/peerdb-public/shared_invite/zt-1wo9jydev-EXInbMtCtpAKFFWdi7QvLQ) +

## PeerDB diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go index 798d4a6e3a..d4c9f9df1c 100644 --- a/flow/connectors/bigquery/bigquery.go +++ b/flow/connectors/bigquery/bigquery.go @@ -15,6 +15,7 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/google/uuid" log "github.com/sirupsen/logrus" "google.golang.org/api/iterator" @@ -442,8 +443,8 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S // 1. _peerdb_uid - uuid // 2. _peerdb_timestamp - current timestamp // 2. _peerdb_timestamp_nanos - current timestamp in nano seconds - // 3. _peerdb_data - json of `r.Items` - json, err := json.Marshal(r.Items) + // 3. _peerdb_data - itemsJSON of `r.Items` + itemsJSON, err := r.Items.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create items to json: %v", err) } @@ -454,7 +455,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S timestamp: time.Now(), timestampNanos: time.Now().UnixNano(), destinationTableName: r.DestinationTableName, - data: string(json), + data: itemsJSON, recordType: 0, matchData: "", batchID: syncBatchID, @@ -469,12 +470,12 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S // 4. _peerdb_record_type - 1 // 5. _peerdb_match_data - json of `r.OldItems` - newItemsJSON, err := json.Marshal(r.NewItems) + newItemsJSON, err := r.NewItems.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create new items to json: %v", err) } - oldItemsJSON, err := json.Marshal(r.OldItems) + oldItemsJSON, err := r.OldItems.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create old items to json: %v", err) } @@ -485,9 +486,9 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S timestamp: time.Now(), timestampNanos: time.Now().UnixNano(), destinationTableName: r.DestinationTableName, - data: string(newItemsJSON), + data: newItemsJSON, recordType: 1, - matchData: string(oldItemsJSON), + matchData: oldItemsJSON, batchID: syncBatchID, stagingBatchID: stagingBatchID, unchangedToastColumns: utils.KeysToString(r.UnchangedToastColumns), @@ -500,7 +501,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S // 4. _peerdb_match_data - json of `r.Items` // json.Marshal converts bytes in Hex automatically to BASE64 string. 
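The `r.Items.ToJSON()` calls above replace direct `json.Marshal(r.Items)`; the method itself is not shown in this diff. A minimal sketch of what it plausibly does, assuming `model.RecordItems` is a map from column name to `qvalue.QValue` (as the cdc.go changes further down suggest), so that `_peerdb_data` keeps the flat `{"col": value}` shape:

```go
package model

import (
	"encoding/json"

	"github.com/PeerDB-io/peer-flow/model/qvalue"
)

// RecordItems maps column names to typed values; sketch only, the real
// definition lives in the model package.
type RecordItems map[string]qvalue.QValue

// ToJSON serializes just the wrapped values, dropping the Kind tags, so
// the raw table's _peerdb_data column stays a flat JSON object.
func (r RecordItems) ToJSON() (string, error) {
	plain := make(map[string]interface{}, len(r))
	for col, qv := range r {
		plain[col] = qv.Value
	}
	b, err := json.Marshal(plain)
	if err != nil {
		return "", err
	}
	return string(b), nil
}
```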
- itemsJSON, err := json.Marshal(r.Items) + itemsJSON, err := r.Items.ToJSON() if err != nil { return nil, fmt.Errorf("failed to create items to json: %v", err) } @@ -511,9 +512,9 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S timestamp: time.Now(), timestampNanos: time.Now().UnixNano(), destinationTableName: r.DestinationTableName, - data: string(itemsJSON), + data: itemsJSON, recordType: 2, - matchData: string(itemsJSON), + matchData: itemsJSON, batchID: syncBatchID, stagingBatchID: stagingBatchID, unchangedToastColumns: utils.KeysToString(r.UnchangedToastColumns), @@ -823,7 +824,7 @@ func (c *BigQueryConnector) SetupNormalizedTable( for colName, genericColType := range sourceSchema.Columns { columns[idx] = &bigquery.FieldSchema{ Name: colName, - Type: getBigQueryColumnTypeForGenericColType(genericColType), + Type: qValueKindToBigQueryType(genericColType), } idx++ } @@ -923,45 +924,6 @@ func (c *BigQueryConnector) truncateTable(tableIdentifier string) error { return nil } -func getBigQueryColumnTypeForGenericColType(colType string) bigquery.FieldType { - switch colType { - // boolean - case model.ColumnTypeBoolean: - return bigquery.BooleanFieldType - // integer types - case model.ColumnTypeInt16, model.ColumnTypeInt32, model.ColumnTypeInt64: - return bigquery.IntegerFieldType - // decimal types - case model.ColumnTypeFloat16, model.ColumnTypeFloat32, model.ColumnTypeFloat64: - return bigquery.FloatFieldType - case model.ColumnTypeNumeric: - return bigquery.NumericFieldType - // string related - case model.ColumnTypeString: - return bigquery.StringFieldType - // json also is stored as string for now - case model.ColumnTypeJSON: - return bigquery.StringFieldType - // time related - case model.ColumnTypeTimestamp, model.ColumnTypeTimeStampWithTimeZone: - return bigquery.TimestampFieldType - case model.ColumnTypeTime: - return bigquery.TimeFieldType - case model.ColumnTypeTimeWithTimeZone: - return bigquery.StringFieldType - case model.ColumnTypeDate: - return bigquery.TimestampFieldType - case model.ColumnTypeInterval: - return bigquery.IntervalFieldType - // bytes - case model.ColumnHexBytes, model.ColumnHexBit: - return bigquery.BytesFieldType - // rest will be strings - default: - return bigquery.StringFieldType - } -} - type MergeStmtGenerator struct { // dataset of all the tables Dataset string @@ -1003,40 +965,42 @@ func (m *MergeStmtGenerator) generateFlattenedCTE() string { // statement. flattenedProjs := make([]string, 0) for colName, colType := range m.NormalizedTableSchema.Columns { - bqType := getBigQueryColumnTypeForGenericColType(colType) + bqType := qValueKindToBigQueryType(colType) // CAST doesn't work for FLOAT, so rewrite it to FLOAT64. 
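To make the cast generation that follows concrete: for a hypothetical normalized schema with an `int64` column `id`, a `bytes` column `payload`, and a `json` column `meta`, the switch below would assemble flattened projections roughly like this (illustrative only, column names are made up):

```go
// Hypothetical output of generateFlattenedCTE's castStmt branches.
const exampleFlattenedProjections = `
CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.id') AS INTEGER) AS id,
FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.payload')) AS payload,
CAST(JSON_EXTRACT(_peerdb_data, '$.meta') AS STRING) AS meta`
```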
if bqType == bigquery.FloatFieldType { bqType = "FLOAT64" } var castStmt string - switch colType { - case model.ColumnTypeJSON: + switch qvalue.QValueKind(colType) { + case qvalue.QValueKindJSON: //if the type is JSON, then just extract JSON castStmt = fmt.Sprintf("CAST(JSON_EXTRACT(_peerdb_data, '$.%s') AS %s) AS %s", colName, bqType, colName) // expecting data in BASE64 format - case model.ColumnHexBytes: + case qvalue.QValueKindBytes: castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s')) AS %s", colName, colName) // MAKE_INTERVAL(years INT64, months INT64, days INT64, hours INT64, minutes INT64, seconds INT64) // Expecting interval to be in the format of {"Microseconds":2000000,"Days":0,"Months":0,"Valid":true} // json.Marshal in SyncRecords for Postgres already does this - once new data-stores are added, // this needs to be handled again - case model.ColumnTypeInterval: - castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+ - "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+ - "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s", - colName, colName, colName, colName) - case model.ColumnHexBit: + // TODO add interval types again + // case model.ColumnTypeInterval: + // castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+ + // "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+ + // "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s", + // colName, colName, colName, colName) + case qvalue.QValueKindBit: // sample raw data for BIT {"a":{"Bytes":"oA==","Len":3,"Valid":true},"id":1} // need to check correctness TODO castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Bytes')) AS %s", colName, colName) - case model.ColumnTypeTime: - castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+ - " AS int64))) AS %s", - colName, colName) + // TODO add proper granularity for time types, then restore this + // case model.ColumnTypeTime: + // castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+ + // " AS int64))) AS %s", + // colName, colName) default: castStmt = fmt.Sprintf("CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s') AS %s) AS %s", colName, bqType, colName) diff --git a/flow/connectors/bigquery/qrecord_value_saver.go b/flow/connectors/bigquery/qrecord_value_saver.go index 7bb82d8f94..672f7ee0b3 100644 --- a/flow/connectors/bigquery/qrecord_value_saver.go +++ b/flow/connectors/bigquery/qrecord_value_saver.go @@ -63,10 +63,17 @@ func (q QRecordValueSaver) Save() (map[string]bigquery.Value, string, error) { } bqValues[k] = val + case qvalue.QValueKindInt16: + val, ok := v.Value.(int16) + if !ok { + return nil, "", fmt.Errorf("failed to convert %v to int16", v.Value) + } + bqValues[k] = val + case qvalue.QValueKindInt32: val, ok := v.Value.(int32) if !ok { - return nil, "", fmt.Errorf("failed to convert %v to int64", v.Value) + return nil, "", fmt.Errorf("failed to convert %v to int32", v.Value) } bqValues[k] = val @@ -91,12 +98,12 @@ func (q QRecordValueSaver) Save() (map[string]bigquery.Value, string, error) { } bqValues[k] = val - case qvalue.QValueKindETime: - val, ok := v.Value.(*qvalue.ExtendedTime) - if !ok { - return nil, "", fmt.Errorf("failed to convert %v to ExtendedTime", v.Value) + case 
qvalue.QValueKindTimestamp, qvalue.QValueKindDate, qvalue.QValueKindTime: + var err error + bqValues[k], err = v.GoTimeConvert() + if err != nil { + return nil, "", fmt.Errorf("failed to parse %v into time.Time", v) } - bqValues[k] = val.Time case qvalue.QValueKindNumeric: val, ok := v.Value.(*big.Rat) diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go index 09703625fc..58d19fbc8b 100644 --- a/flow/connectors/bigquery/qrep_avro_sync.go +++ b/flow/connectors/bigquery/qrep_avro_sync.go @@ -216,13 +216,13 @@ func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) { }, nil case bigquery.DateFieldType: return map[string]string{ - "type": "int", - "logicalType": "date", + "type": "long", + "logicalType": "timestamp-micros", }, nil case bigquery.TimeFieldType: return map[string]string{ "type": "long", - "logicalType": "time-micros", + "logicalType": "timestamp-micros", }, nil case bigquery.DateTimeFieldType: return map[string]interface{}{ diff --git a/flow/connectors/bigquery/qrep_sync_method.go b/flow/connectors/bigquery/qrep_sync_method.go index 8dd998c973..92a9d4142d 100644 --- a/flow/connectors/bigquery/qrep_sync_method.go +++ b/flow/connectors/bigquery/qrep_sync_method.go @@ -94,7 +94,11 @@ func (s *QRepStagingTableSync) SyncQRepRecords( // col names for the destination table joined by comma colNames := []string{} for _, col := range dstTableMetadata.Schema { - colNames = append(colNames, col.Name) + if strings.ToLower(col.Name) == "from" { + colNames = append(colNames, "`from`") + } else { + colNames = append(colNames, col.Name) + } } colNamesStr := strings.Join(colNames, ", ") diff --git a/flow/connectors/bigquery/qvalue_convert.go b/flow/connectors/bigquery/qvalue_convert.go new file mode 100644 index 0000000000..fabdfb97f2 --- /dev/null +++ b/flow/connectors/bigquery/qvalue_convert.go @@ -0,0 +1,76 @@ +package connbigquery + +import ( + "fmt" + + "cloud.google.com/go/bigquery" + "github.com/PeerDB-io/peer-flow/model/qvalue" +) + +func qValueKindToBigQueryType(colType string) bigquery.FieldType { + switch qvalue.QValueKind(colType) { + // boolean + case qvalue.QValueKindBoolean: + return bigquery.BooleanFieldType + // integer types + case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64: + return bigquery.IntegerFieldType + // decimal types + case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64: + return bigquery.FloatFieldType + case qvalue.QValueKindNumeric: + return bigquery.NumericFieldType + // string related + case qvalue.QValueKindString: + return bigquery.StringFieldType + // json also is stored as string for now + case qvalue.QValueKindJSON: + return bigquery.StringFieldType + // time related + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ: + return bigquery.TimestampFieldType + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - DATE support is incomplete + case qvalue.QValueKindDate: + return bigquery.DateFieldType + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - TIME/TIMETZ support is incomplete + case qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: + return bigquery.TimeFieldType + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle INTERVAL types again + // bytes + case qvalue.QValueKindBit, qvalue.QValueKindBytes: + return bigquery.BytesFieldType + // rest will be strings + default: + return bigquery.StringFieldType + } +} + +// BigQueryTypeToQValueKind converts a BigQuery FieldType to a QValueKind.
+func BigQueryTypeToQValueKind(fieldType bigquery.FieldType) (qvalue.QValueKind, error) { + switch fieldType { + case bigquery.StringFieldType: + return qvalue.QValueKindString, nil + case bigquery.BytesFieldType: + return qvalue.QValueKindBytes, nil + case bigquery.IntegerFieldType: + return qvalue.QValueKindInt64, nil + case bigquery.FloatFieldType: + return qvalue.QValueKindFloat64, nil + case bigquery.BooleanFieldType: + return qvalue.QValueKindBoolean, nil + case bigquery.TimestampFieldType: + return qvalue.QValueKindTimestamp, nil + case bigquery.DateFieldType: + return qvalue.QValueKindDate, nil + case bigquery.TimeFieldType: + return qvalue.QValueKindTime, nil + case bigquery.RecordFieldType: + return qvalue.QValueKindStruct, nil + case bigquery.NumericFieldType: + return qvalue.QValueKindNumeric, nil + case bigquery.GeographyFieldType: + return qvalue.QValueKindString, nil + default: + return "", fmt.Errorf("unsupported bigquery field type: %v", fieldType) + } +} diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index 411e9ee41b..a71e229887 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -7,6 +7,7 @@ import ( "time" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" @@ -386,34 +387,34 @@ It takes a tuple and a relation message as input and returns func (p *PostgresCDCSource) convertTupleToMap( tuple *pglogrepl.TupleData, rel *pglogrepl.RelationMessage, -) (map[string]interface{}, map[string]bool, error) { +) (model.RecordItems, map[string]bool, error) { // if the tuple is nil, return an empty map if tuple == nil { - return make(map[string]interface{}), make(map[string]bool), nil + return make(model.RecordItems), make(map[string]bool), nil } // create empty map of string to interface{} - items := make(map[string]interface{}) + items := make(model.RecordItems) unchangedToastColumns := make(map[string]bool) for idx, col := range tuple.Columns { colName := rel.Columns[idx].Name switch col.DataType { case 'n': // null - items[colName] = nil + items[colName] = qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil} case 't': // text /* bytea also appears here as a hex */ - data, err := p.decodeTextColumnData(col.Data, rel.Columns[idx].DataType) + data, err := p.decodeColumnData(col.Data, rel.Columns[idx].DataType, pgtype.TextFormatCode) if err != nil { return nil, nil, fmt.Errorf("error decoding text column data: %w", err) } - items[colName] = data + items[colName] = *data case 'b': // binary - data, err := p.decodeBinaryColumnData(col.Data, rel.Columns[idx].DataType) + data, err := p.decodeColumnData(col.Data, rel.Columns[idx].DataType, pgtype.BinaryFormatCode) if err != nil { return nil, nil, fmt.Errorf("error decoding binary column data: %w", err) } - items[colName] = data + items[colName] = *data case 'u': // unchanged toast unchangedToastColumns[colName] = true default: @@ -423,21 +424,25 @@ func (p *PostgresCDCSource) convertTupleToMap( return items, unchangedToastColumns, nil } -func (p *PostgresCDCSource) decodeTextColumnData(data []byte, dataType uint32) (interface{}, error) { +func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, formatCode int16) (*qvalue.QValue, error) { + var parsedData any + var err error if dt, ok := p.typeMap.TypeForOID(dataType); ok { if dt.Name == "uuid" { // below is required to decode uuid to string - return 
dt.Codec.DecodeDatabaseSQLValue(p.typeMap, dataType, pgtype.TextFormatCode, data) - } - return dt.Codec.DecodeValue(p.typeMap, dataType, pgtype.TextFormatCode, data) - } - return string(data), nil -} + parsedData, err = dt.Codec.DecodeDatabaseSQLValue(p.typeMap, dataType, pgtype.TextFormatCode, data) -// decodeBinaryColumnData decodes the binary data for a column -func (p *PostgresCDCSource) decodeBinaryColumnData(data []byte, dataType uint32) (interface{}, error) { - if dt, ok := p.typeMap.TypeForOID(dataType); ok { - return dt.Codec.DecodeValue(p.typeMap, dataType, pgtype.BinaryFormatCode, data) + } else { + parsedData, err = dt.Codec.DecodeValue(p.typeMap, dataType, formatCode, data) + } + if err != nil { + return nil, err + } + retVal, err := parseFieldFromPostgresOID(dataType, parsedData) + if err != nil { + return nil, err + } + return retVal, nil } - return string(data), nil + return &qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index d9d228f8b8..909423a3ae 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -7,6 +7,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" @@ -271,17 +272,9 @@ func (c *PostgresConnector) GetTableSchema(req *protos.GetTableSchemaInput) (*pr return nil, err } - relID, err := c.getRelIDForTable(schemaTable) - if err != nil { - return nil, err - } - // Get the column names and types rows, err := c.pool.Query(c.ctx, - `SELECT a.attname, t.typname FROM pg_attribute a - JOIN pg_type t ON t.oid = a.atttypid - WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attrelid = $1`, - relID) + fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, req.TableIdentifier)) if err != nil { return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err) } @@ -298,20 +291,13 @@ func (c *PostgresConnector) GetTableSchema(req *protos.GetTableSchemaInput) (*pr PrimaryKeyColumn: pkey, } - for rows.Next() { - var colName string - var colType string - err = rows.Scan(&colName, &colType) - if err != nil { - return nil, fmt.Errorf("error scanning table schema: %w", err) - } - - colType, err = convertPostgresColumnTypeToGeneric(colType) - if err != nil { - return nil, fmt.Errorf("error converting postgres column type: %w", err) + for _, fieldDescription := range rows.FieldDescriptions() { + genericColType := getQValueKindForPostgresOID(fieldDescription.DataTypeOID) + if genericColType == qvalue.QValueKindInvalid { + return nil, fmt.Errorf("error converting Postgres OID to QValueKind") } - res.Columns[colName] = colType + res.Columns[fieldDescription.Name] = string(genericColType) } if err = rows.Err(); err != nil { @@ -450,77 +436,6 @@ func (c *PostgresConnector) getPrimaryKeyColumn(schemaTable *SchemaTable) (strin return pkCol, nil } -func convertPostgresColumnTypeToGeneric(colType string) (string, error) { - switch colType { - case "int2": - return model.ColumnTypeInt16, nil - case "int4": - return model.ColumnTypeInt32, nil - case "int8": - return model.ColumnTypeInt64, nil - case "float4": - return model.ColumnTypeFloat32, nil - case "float8": - return model.ColumnTypeFloat64, nil - case "bool": - return model.ColumnTypeBoolean, nil - case "text": - return model.ColumnTypeString, nil - case "date": - return model.ColumnTypeDate, 
nil - case "timestamp": - return model.ColumnTypeTimestamp, nil - case "timestamptz": - return model.ColumnTypeTimeStampWithTimeZone, nil - case "varchar": - return model.ColumnTypeString, nil - case "char": - return model.ColumnTypeString, nil - case "bpchar": - return model.ColumnTypeString, nil - case "numeric": - return model.ColumnTypeNumeric, nil - case "uuid": - return model.ColumnTypeString, nil - case "json": - return model.ColumnTypeJSON, nil - case "jsonb": - return model.ColumnTypeJSON, nil - case "xml": - return model.ColumnTypeString, nil - case "tsvector": - return model.ColumnTypeString, nil - case "tsquery": - return model.ColumnTypeString, nil - case "bytea": - return model.ColumnHexBytes, nil - case "bit": - return model.ColumnHexBit, nil - case "varbit": - return model.ColumnHexBit, nil - case "cidr": - return model.ColumnTypeString, nil - case "inet": - return model.ColumnTypeString, nil - case "interval": - return model.ColumnTypeInterval, nil - case "macaddr": - return model.ColumnTypeString, nil - case "money": - return model.ColumnTypeString, nil - case "oid": - return model.ColumnTypeInt64, nil - case "time": - return model.ColumnTypeTime, nil - case "timetz": - return model.ColumnTypeTimeWithTimeZone, nil - case "txid_snapshot": - return model.ColumnTypeString, nil - default: - return "", fmt.Errorf("unsupported column type: %s", colType) - } -} - func (c *PostgresConnector) tableExists(schemaTable *SchemaTable) (bool, error) { var exists bool err := c.pool.QueryRow(c.ctx, diff --git a/flow/connectors/postgres/postgres_cdc_test.go b/flow/connectors/postgres/postgres_cdc_test.go index e45d0f117f..af32f112fc 100644 --- a/flow/connectors/postgres/postgres_cdc_test.go +++ b/flow/connectors/postgres/postgres_cdc_test.go @@ -9,6 +9,7 @@ import ( "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/stretchr/testify/suite" ) @@ -133,11 +134,11 @@ func (suite *PostgresCDCTestSuite) validateInsertedToastRecords(records []model. 
suite.Equal(dstTableName, insertRecord.DestinationTableName) suite.Equal(5, len(insertRecord.Items)) - suite.Equal(int32(idx+1), insertRecord.Items["id"].(int32)) - suite.Equal(32768, len(insertRecord.Items["n_t"].(string))) - suite.Equal(32768, len(insertRecord.Items["lz4_t"].(string))) - suite.Equal(32768, len(insertRecord.Items["n_b"].([]byte))) - suite.Equal(32768, len(insertRecord.Items["lz4_b"].([]byte))) + suite.Equal(int32(idx+1), insertRecord.Items["id"].Value.(int32)) + suite.Equal(32768, len(insertRecord.Items["n_t"].Value.(string))) + suite.Equal(32768, len(insertRecord.Items["lz4_t"].Value.(string))) + suite.Equal(32768, len(insertRecord.Items["n_b"].Value.([]byte))) + suite.Equal(32768, len(insertRecord.Items["lz4_b"].Value.([]byte))) } } @@ -175,8 +176,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R suite.Equal(srcTableName, updateRecord.SourceTableName) suite.Equal(dstTableName, updateRecord.DestinationTableName) suite.Equal(2, len(updateRecord.NewItems)) - suite.Equal(int32(1), updateRecord.NewItems["id"].(int32)) - suite.Equal(65536, len(updateRecord.NewItems["n_t"].(string))) + suite.Equal(int32(1), updateRecord.NewItems["id"].Value.(int32)) + suite.Equal(65536, len(updateRecord.NewItems["n_t"].Value.(string))) suite.Equal(3, len(updateRecord.UnchangedToastColumns)) suite.True(updateRecord.UnchangedToastColumns["lz4_t"]) suite.True(updateRecord.UnchangedToastColumns["n_b"]) @@ -187,8 +188,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R suite.Equal(srcTableName, updateRecord.SourceTableName) suite.Equal(dstTableName, updateRecord.DestinationTableName) suite.Equal(2, len(updateRecord.NewItems)) - suite.Equal(int32(2), updateRecord.NewItems["id"].(int32)) - suite.Equal(65536, len(updateRecord.NewItems["lz4_b"].([]byte))) + suite.Equal(int32(2), updateRecord.NewItems["id"].Value.(int32)) + suite.Equal(65536, len(updateRecord.NewItems["lz4_b"].Value.([]byte))) suite.Equal(3, len(updateRecord.UnchangedToastColumns)) suite.True(updateRecord.UnchangedToastColumns["lz4_t"]) suite.True(updateRecord.UnchangedToastColumns["n_b"]) @@ -199,8 +200,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R suite.Equal(srcTableName, updateRecord.SourceTableName) suite.Equal(dstTableName, updateRecord.DestinationTableName) suite.Equal(2, len(updateRecord.NewItems)) - suite.Equal(int32(3), updateRecord.NewItems["id"].(int32)) - suite.Equal(65536, len(updateRecord.NewItems["n_b"].([]byte))) + suite.Equal(int32(3), updateRecord.NewItems["id"].Value.(int32)) + suite.Equal(65536, len(updateRecord.NewItems["n_b"].Value.([]byte))) suite.Equal(3, len(updateRecord.UnchangedToastColumns)) suite.True(updateRecord.UnchangedToastColumns["lz4_t"]) suite.True(updateRecord.UnchangedToastColumns["n_t"]) @@ -211,7 +212,7 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R suite.Equal(srcTableName, deleteRecord.SourceTableName) suite.Equal(dstTableName, deleteRecord.DestinationTableName) suite.Equal(5, len(deleteRecord.Items)) - suite.Equal(int32(3), deleteRecord.Items["id"].(int32)) + suite.Equal(int32(3), deleteRecord.Items["id"].Value.(int32)) suite.Nil(deleteRecord.Items["n_t"]) suite.Nil(deleteRecord.Items["lz4_t"]) suite.Nil(deleteRecord.Items["n_b"]) @@ -397,8 +398,8 @@ func (suite *PostgresCDCTestSuite) TestErrorForTableNotExist() { tableNameSchemaMapping[nonExistentFlowDstTableName] = &protos.TableSchema{ TableIdentifier: nonExistentFlowSrcTableName, Columns: 
map[string]string{ - "id": model.ColumnTypeInt32, - "name": model.ColumnTypeString, + "id": string(qvalue.QValueKindInt32), + "name": string(qvalue.QValueKindString), }, PrimaryKeyColumn: "id", } @@ -494,8 +495,8 @@ func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() { suite.Equal(&protos.TableSchema{ TableIdentifier: simpleHappyFlowSrcTableName, Columns: map[string]string{ - "id": model.ColumnTypeInt32, - "name": model.ColumnTypeString, + "id": string(qvalue.QValueKindInt32), + "name": string(qvalue.QValueKindString), }, PrimaryKeyColumn: "id", }, tableNameSchema) @@ -603,36 +604,36 @@ func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() { suite.Equal(&protos.TableSchema{ TableIdentifier: allTypesHappyFlowSrcTableName, Columns: map[string]string{ - "id": model.ColumnTypeInt64, - "c1": model.ColumnTypeInt64, + "id": string(qvalue.QValueKindInt64), + "c1": string(qvalue.QValueKindInt64), "c2": model.ColumnHexBit, "c3": model.ColumnHexBit, - "c4": model.ColumnTypeBoolean, - "c6": model.ColumnHexBytes, - "c7": model.ColumnTypeString, - "c8": model.ColumnTypeString, - "c9": model.ColumnTypeString, - "c11": model.ColumnTypeDate, - "c12": model.ColumnTypeFloat64, - "c13": model.ColumnTypeFloat64, + "c4": string(qvalue.QValueKindBoolean), + "c6": string(qvalue.QValueKindBytes), + "c7": string(qvalue.QValueKindString), + "c8": string(qvalue.QValueKindString), + "c9": string(qvalue.QValueKindString), + "c11": string(qvalue.QValueKindDate), + "c12": string(qvalue.QValueKindFloat64), + "c13": string(qvalue.QValueKindFloat64), "c14": model.ColumnTypeString, - "c15": model.ColumnTypeInt32, + "c15": string(qvalue.QValueKindInt32), "c16": model.ColumnTypeInterval, - "c17": model.ColumnTypeJSON, - "c18": model.ColumnTypeJSON, + "c17": string(qvalue.QValueKindJSON), + "c18": string(qvalue.QValueKindJSON), "c21": model.ColumnTypeString, "c22": model.ColumnTypeString, - "c23": model.ColumnTypeNumeric, - "c24": model.ColumnTypeInt64, - "c28": model.ColumnTypeFloat32, - "c29": model.ColumnTypeInt16, - "c30": model.ColumnTypeInt16, - "c31": model.ColumnTypeInt32, + "c23": string(qvalue.QValueKindNumeric), + "c24": string(qvalue.QValueKindInt64), + "c28": string(qvalue.QValueKindFloat32), + "c29": string(qvalue.QValueKindInt16), + "c30": string(qvalue.QValueKindInt16), + "c31": string(qvalue.QValueKindInt32), "c32": model.ColumnTypeString, - "c33": model.ColumnTypeTimestamp, - "c34": model.ColumnTypeTimeStampWithTimeZone, - "c35": model.ColumnTypeTime, - "c36": model.ColumnTypeTimeWithTimeZone, + "c33": string(qvalue.QValueKindTimestamp), + "c34": string(qvalue.QValueKindTimestampTZ), + "c35": string(qvalue.QValueKindTime), + "c36": string(qvalue.QValueKindTimeTZ), "c37": model.ColumnTypeString, "c38": model.ColumnTypeString, "c39": model.ColumnTypeString, @@ -714,11 +715,11 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() { suite.Equal(&protos.TableSchema{ TableIdentifier: toastHappyFlowSrcTableName, Columns: map[string]string{ - "id": model.ColumnTypeInt32, - "n_t": model.ColumnTypeString, - "lz4_t": model.ColumnTypeString, - "n_b": model.ColumnHexBytes, - "lz4_b": model.ColumnHexBytes, + "id": string(qvalue.QValueKindInt32), + "n_t": string(qvalue.QValueKindString), + "lz4_t": string(qvalue.QValueKindString), + "n_b": string(qvalue.QValueKindBytes), + "lz4_b": string(qvalue.QValueKindBytes), }, PrimaryKeyColumn: "id", }, tableNameSchema) diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index aa74bd9e21..10f1d4249b 100644 --- 
a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -28,6 +28,8 @@ func (c *PostgresConnector) GetQRepPartitions( switch v := minValue.(type) { case int32, int64: maxValue := maxValue.(int64) + 1 + fmt.Println("minValue", minValue) + fmt.Println("maxValue", maxValue) partitions, err = c.getIntPartitions(v.(int64), maxValue, config.BatchSizeInt) case time.Time: maxValue := maxValue.(time.Time).Add(time.Microsecond) @@ -286,7 +288,8 @@ func (c *PostgresConnector) getIntPartitions( for start <= end { partitionEnd := start + batchSize - if partitionEnd > end { + // safeguard against integer overflow + if partitionEnd > end || partitionEnd < start { partitionEnd = end } diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go index 5dd4cc1ab4..8bf3c2373f 100644 --- a/flow/connectors/postgres/qrep_query_executor.go +++ b/flow/connectors/postgres/qrep_query_executor.go @@ -2,18 +2,11 @@ package connpostgres import ( "context" - "database/sql" - "errors" "fmt" - "math" - "math/big" - "time" "github.com/PeerDB-io/peer-flow/model" - "github.com/PeerDB-io/peer-flow/model/qvalue" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" ) @@ -38,50 +31,12 @@ func (qe *QRepQueryExecutor) ExecuteQuery(query string, args ...interface{}) (pg return rows, nil } -func fieldDescriptionToQValueKind(fd pgconn.FieldDescription) qvalue.QValueKind { - switch fd.DataTypeOID { - case pgtype.BoolOID: - return qvalue.QValueKindBoolean - case pgtype.Int2OID: - return qvalue.QValueKindInt16 - case pgtype.Int4OID: - return qvalue.QValueKindInt32 - case pgtype.Int8OID: - return qvalue.QValueKindInt64 - case pgtype.Float4OID: - return qvalue.QValueKindFloat32 - case pgtype.Float8OID: - return qvalue.QValueKindFloat64 - case pgtype.TextOID, pgtype.VarcharOID: - return qvalue.QValueKindString - case pgtype.ByteaOID: - return qvalue.QValueKindBytes - case pgtype.JSONOID, pgtype.JSONBOID: - return qvalue.QValueKindJSON - case pgtype.UUIDOID: - return qvalue.QValueKindUUID - case pgtype.TimestampOID, pgtype.TimestamptzOID, pgtype.DateOID, pgtype.TimeOID: - return qvalue.QValueKindETime - case pgtype.NumericOID: - return qvalue.QValueKindNumeric - default: - typeName, ok := pgtype.NewMap().TypeForOID(fd.DataTypeOID) - if !ok { - log.Warnf("failed to get type name for oid: %v", fd.DataTypeOID) - return qvalue.QValueKindInvalid - } else { - log.Warnf("unsupported field type: %v - type name - %s", fd.DataTypeOID, typeName.Name) - return qvalue.QValueKindInvalid - } - } -} - // FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema. func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema { qfields := make([]*model.QField, len(fds)) for i, fd := range fds { cname := fd.Name - ctype := fieldDescriptionToQValueKind(fd) + ctype := getQValueKindForPostgresOID(fd.DataTypeOID) // there isn't a way to know if a column is nullable or not // TODO fix this. 
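The new `partitionEnd < start` comparison added to getIntPartitions above is the actual overflow safeguard: when `start` is near `math.MaxInt64`, adding `batchSize` wraps to a negative number, so `partitionEnd > end` alone would never fire. A self-contained illustration:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	end := int64(math.MaxInt64)
	start := end - 5
	batchSize := int64(10)

	partitionEnd := start + batchSize // wraps past MaxInt64 to a negative value
	if partitionEnd > end || partitionEnd < start {
		partitionEnd = end // the second comparison catches the wrap-around
	}
	fmt.Println(partitionEnd == end) // true
}
```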
cnullable := true @@ -146,216 +101,22 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery( return batch, nil } -func mapRowToQRecord(row pgx.Row, fds []pgconn.FieldDescription) (*model.QRecord, error) { +func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecord, error) { // make vals an empty array of QValue of size len(fds) record := model.NewQRecord(len(fds)) - scanArgs := make([]interface{}, len(fds)) - for i := range scanArgs { - switch fds[i].DataTypeOID { - case pgtype.BoolOID: - scanArgs[i] = new(pgtype.Bool) - case pgtype.TimestampOID: - scanArgs[i] = new(pgtype.Timestamp) - case pgtype.TimestamptzOID: - scanArgs[i] = new(pgtype.Timestamptz) - case pgtype.Int4OID: - scanArgs[i] = new(pgtype.Int4) - case pgtype.Int8OID: - scanArgs[i] = new(pgtype.Int8) - case pgtype.Float4OID: - scanArgs[i] = new(pgtype.Float4) - case pgtype.Float8OID: - scanArgs[i] = new(pgtype.Float8) - case pgtype.TextOID: - scanArgs[i] = new(pgtype.Text) - case pgtype.VarcharOID: - scanArgs[i] = new(pgtype.Text) - case pgtype.NumericOID: - scanArgs[i] = new(pgtype.Numeric) - case pgtype.UUIDOID: - scanArgs[i] = new(pgtype.UUID) - case pgtype.ByteaOID: - scanArgs[i] = new(sql.RawBytes) - case pgtype.DateOID: - scanArgs[i] = new(pgtype.Date) - default: - scanArgs[i] = new(pgtype.Text) - } - } - - err := row.Scan(scanArgs...) + values, err := row.Values() if err != nil { return nil, fmt.Errorf("failed to scan row: %w", err) } for i, fd := range fds { - tmp, err := parseField(fd.DataTypeOID, scanArgs[i]) + tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i]) if err != nil { return nil, fmt.Errorf("failed to parse field: %w", err) } - record.Set(i, tmp) + record.Set(i, *tmp) } return record, nil } - -func parseField(oid uint32, value interface{}) (qvalue.QValue, error) { - var val qvalue.QValue - - switch oid { - case pgtype.TimestampOID: - timestamp := value.(*pgtype.Timestamp) - var et *qvalue.ExtendedTime - if timestamp.Valid { - var err error - et, err = qvalue.NewExtendedTime(timestamp.Time, qvalue.DateTimeKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } - } - val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et} - case pgtype.TimestamptzOID: - timestamp := value.(*pgtype.Timestamptz) - var et *qvalue.ExtendedTime - if timestamp.Valid { - var err error - et, err = qvalue.NewExtendedTime(timestamp.Time, qvalue.DateTimeKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } - } - val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et} - case pgtype.DateOID: - date := value.(*pgtype.Date) - var et *qvalue.ExtendedTime - if date.Valid { - var err error - et, err = qvalue.NewExtendedTime(date.Time, qvalue.DateKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } - } - val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et} - case pgtype.TimeOID: - timeVal := value.(*pgtype.Text) - var et *qvalue.ExtendedTime - if timeVal.Valid { - t, err := time.Parse("15:04:05.999999", timeVal.String) - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to parse time: %w", err) - } - et, err = qvalue.NewExtendedTime(t, qvalue.TimeKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } - } - val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et} - case pgtype.BoolOID: - boolVal := value.(*pgtype.Bool) - if boolVal.Valid { 
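The mapRowToQRecord rewrite above drops the per-OID `scanArgs` buffers in favor of pgx's generic decoding: `Rows.Values()` returns one already-decoded Go value per column, which `parseFieldFromPostgresOID` then wraps into QValues. A sketch of that pattern (pool and table names are hypothetical, not from the diff):

```go
package connpostgres

import (
	"context"

	"github.com/jackc/pgx/v5/pgxpool"
)

// scanWithValues is a sketch, not part of the diff: it mirrors how the
// new mapRowToQRecord consumes rows via Values() instead of Scan().
func scanWithValues(ctx context.Context, pool *pgxpool.Pool) error {
	rows, err := pool.Query(ctx, "SELECT id, name FROM my_table") // hypothetical table
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		values, err := rows.Values() // one decoded value per column
		if err != nil {
			return err
		}
		for i, fd := range rows.FieldDescriptions() {
			if _, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i]); err != nil {
				return err
			}
		}
	}
	return rows.Err()
}
```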
- val = qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: boolVal.Bool} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: nil} - } - case pgtype.JSONOID, pgtype.JSONBOID: - // TODO: improve JSON support - strVal := value.(*pgtype.Text) - if strVal != nil { - val = qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: strVal.String} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: nil} - } - case pgtype.Int2OID: - intVal := value.(*pgtype.Int2) - if intVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: intVal.Int16} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: nil} - } - case pgtype.Int4OID: - intVal := value.(*pgtype.Int4) - if intVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: intVal.Int32} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: nil} - } - case pgtype.Int8OID: - intVal := value.(*pgtype.Int8) - if intVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: intVal.Int64} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: nil} - } - case pgtype.Float4OID: - floatVal := value.(*pgtype.Float4) - if floatVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: floatVal.Float32} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: nil} - } - case pgtype.Float8OID: - floatVal := value.(*pgtype.Float8) - if floatVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal.Float64} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil} - } - case pgtype.TextOID, pgtype.VarcharOID: - textVal := value.(*pgtype.Text) - if textVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal.String} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: nil} - } - case pgtype.UUIDOID: - uuidVal := value.(*pgtype.UUID) - if uuidVal.Valid { - val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: uuidVal.Bytes} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: nil} - } - case pgtype.ByteaOID: - rawBytes := value.(*sql.RawBytes) - val = qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: []byte(*rawBytes)} - case pgtype.NumericOID: - numVal := value.(*pgtype.Numeric) - rat, err := numericToRat(numVal) - if err != nil { - log.Warnf("failed to convert numeric [%v] to rat: %v", value, err) - val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil} - } else { - val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat} - } - default: - typ, _ := pgtype.NewMap().TypeForOID(oid) - fmt.Printf("QValueKindInvalid => oid: %v, typename: %v\n", oid, typ) - val = qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil} - } - - return val, nil -} - -func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) { - if numVal.Valid { - if numVal.NaN { - return nil, errors.New("numeric value is NaN") - } - - switch numVal.InfinityModifier { - case pgtype.NegativeInfinity, pgtype.Infinity: - return nil, errors.New("numeric value is infinity") - } - - rat := new(big.Rat) - - rat.SetInt(numVal.Int) - divisor := new(big.Rat).SetFloat64(math.Pow10(int(-numVal.Exp))) - rat.Quo(rat, divisor) - - return rat, nil - } - - // handle invalid numeric - return nil, errors.New("invalid numeric") -} diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go new file mode 100644 index 0000000000..d9b821f2a2 --- /dev/null +++ b/flow/connectors/postgres/qvalue_convert.go @@ -0,0 +1,195 @@ 
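`numericToRat`, which this diff moves into the new qvalue_convert.go below, converts pgtype's mantissa/exponent representation into a `big.Rat`: a numeric is stored as `Int * 10^Exp`, so for a negative exponent the code divides by `10^(-Exp)`. The same arithmetic, worked through for the SQL value 123.45:

```go
package main

import (
	"fmt"
	"math"
	"math/big"
)

func main() {
	// pgtype.Numeric represents 123.45 as Int=12345 with Exp=-2.
	mantissa := big.NewInt(12345)
	exp := int32(-2)

	rat := new(big.Rat).SetInt(mantissa)
	divisor := new(big.Rat).SetFloat64(math.Pow10(int(-exp))) // 10^2 = 100
	rat.Quo(rat, divisor)

	fmt.Println(rat.FloatString(2)) // 123.45
}
```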
+package connpostgres + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "math/big" + "strings" + "time" + + "github.com/PeerDB-io/peer-flow/model/qvalue" + "github.com/jackc/pgx/v5/pgtype" + log "github.com/sirupsen/logrus" +) + +func getQValueKindForPostgresOID(oid uint32) qvalue.QValueKind { + switch oid { + case pgtype.BoolOID: + return qvalue.QValueKindBoolean + case pgtype.Int2OID: + return qvalue.QValueKindInt16 + case pgtype.Int4OID: + return qvalue.QValueKindInt32 + case pgtype.Int8OID: + return qvalue.QValueKindInt64 + case pgtype.Float4OID: + return qvalue.QValueKindFloat32 + case pgtype.Float8OID: + return qvalue.QValueKindFloat64 + case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID: + return qvalue.QValueKindString + case pgtype.ByteaOID: + return qvalue.QValueKindBytes + case pgtype.JSONOID, pgtype.JSONBOID: + return qvalue.QValueKindJSON + case pgtype.UUIDOID: + return qvalue.QValueKindUUID + case pgtype.TimeOID: + return qvalue.QValueKindTime + case pgtype.DateOID: + return qvalue.QValueKindDate + case pgtype.TimestampOID: + return qvalue.QValueKindTimestamp + case pgtype.TimestamptzOID: + return qvalue.QValueKindTimestampTZ + case pgtype.NumericOID: + return qvalue.QValueKindNumeric + default: + typeName, ok := pgtype.NewMap().TypeForOID(oid) + if !ok { + // workaround for TIMETZ not being defined by this pgtype + if oid == 1266 { + return qvalue.QValueKindTimeTZ + } + log.Warnf("failed to get type name for oid: %v", oid) + return qvalue.QValueKindInvalid + } else { + log.Warnf("unsupported field type: %v - type name - %s", oid, typeName.Name) + return qvalue.QValueKindInvalid + } + } +} + +func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (*qvalue.QValue, error) { + var val *qvalue.QValue = nil + + switch qvalueKind { + case qvalue.QValueKindTimestamp: + timestamp := value.(time.Time) + val = &qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: timestamp} + case qvalue.QValueKindTimestampTZ: + timestamp := value.(time.Time) + val = &qvalue.QValue{Kind: qvalue.QValueKindTimestampTZ, Value: timestamp} + case qvalue.QValueKindDate: + date := value.(time.Time) + val = &qvalue.QValue{Kind: qvalue.QValueKindDate, Value: date} + case qvalue.QValueKindTime: + timeVal := value.(pgtype.Time) + if timeVal.Valid { + var timeValStr any + timeValStr, err := timeVal.Value() + if err != nil { + return nil, fmt.Errorf("failed to parse time: %w", err) + } + // edge case, only Postgres supports this extreme value for time + timeValStr = strings.Replace(timeValStr.(string), "24:00:00.000000", "23:59:59.999999", 1) + t, err := time.Parse("15:04:05.999999", timeValStr.(string)) + if err != nil { + return nil, fmt.Errorf("failed to parse time: %w", err) + } + t = t.AddDate(1970, 0, 0) + val = &qvalue.QValue{Kind: qvalue.QValueKindTime, Value: t} + } + case qvalue.QValueKindTimeTZ: + timeVal := value.(string) + // edge case, Postgres supports this extreme value for time + timeVal = strings.Replace(timeVal, "24:00:00.000000", "23:59:59.999999", 1) + t, err := time.Parse("15:04:05.999999-0700", timeVal) + if err != nil { + return nil, fmt.Errorf("failed to parse time: %w", err) + } + t = t.AddDate(1970, 0, 0) + val = &qvalue.QValue{Kind: qvalue.QValueKindTimeTZ, Value: t} + + case qvalue.QValueKindBoolean: + boolVal := value.(bool) + val = &qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: boolVal} + case qvalue.QValueKindJSON: + // TODO: improve JSON support + jsonVal := value.(map[string]interface{}) + jsonValStr, err := json.Marshal(jsonVal) + if
err != nil { + return nil, fmt.Errorf("failed to parse json: %w", err) + } + val = &qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: string(jsonValStr)} + case qvalue.QValueKindInt16: + intVal := value.(int16) + val = &qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: intVal} + case qvalue.QValueKindInt32: + intVal := value.(int32) + val = &qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: intVal} + case qvalue.QValueKindInt64: + intVal := value.(int64) + val = &qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: intVal} + case qvalue.QValueKindFloat32: + floatVal := value.(float32) + val = &qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: floatVal} + case qvalue.QValueKindFloat64: + floatVal := value.(float64) + val = &qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal} + case qvalue.QValueKindString: + textVal := value.(string) + val = &qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal} + case qvalue.QValueKindUUID: + switch value.(type) { + case string: + val = &qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value} + case [16]byte: + val = &qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value} + default: + return nil, fmt.Errorf("failed to parse UUID: %v", value) + } + case qvalue.QValueKindBytes: + rawBytes := value.([]byte) + val = &qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: rawBytes} + // TODO: check for handling of QValueKindBit + case qvalue.QValueKindNumeric: + numVal := value.(pgtype.Numeric) + if numVal.Valid { + rat, err := numericToRat(&numVal) + if err != nil { + return nil, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err) + } + val = &qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat} + } + default: + log.Errorf("unhandled QValueKind => %v\n", qvalueKind) + return nil, fmt.Errorf("unhandled QValueKind => %v", qvalueKind) + } + + // parsing into pgtype failed. + if val == nil { + return nil, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind) + } + return val, nil +} + +func parseFieldFromPostgresOID(oid uint32, value interface{}) (*qvalue.QValue, error) { + return parseFieldFromQValueKind(getQValueKindForPostgresOID(oid), value) +} + +func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) { + if numVal.Valid { + if numVal.NaN { + return nil, errors.New("numeric value is NaN") + } + + switch numVal.InfinityModifier { + case pgtype.NegativeInfinity, pgtype.Infinity: + return nil, errors.New("numeric value is infinity") + } + + rat := new(big.Rat) + + rat.SetInt(numVal.Int) + divisor := new(big.Rat).SetFloat64(math.Pow10(int(-numVal.Exp))) + rat.Quo(rat, divisor) + + return rat, nil + } + + // handle invalid numeric + return nil, errors.New("invalid numeric") +} diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go index c5f187e2e7..41b76513c6 100644 --- a/flow/connectors/snowflake/avro_file_writer_test.go +++ b/flow/connectors/snowflake/avro_file_writer_test.go @@ -19,7 +19,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue. switch kind { case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64: value = int64(placeHolder) - case qvalue.QValueKindFloat16, qvalue.QValueKindFloat32: + case qvalue.QValueKindFloat32: value = float32(placeHolder) case qvalue.QValueKindFloat64: value = float64(placeHolder) @@ -27,10 +27,9 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue. 
value = placeHolder%2 == 0 case qvalue.QValueKindString: value = fmt.Sprintf("string%d", placeHolder) - case qvalue.QValueKindETime: - et, err := qvalue.NewExtendedTime(time.Now(), qvalue.TimeKindType, "") - require.NoError(t, err) - value = et + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindTime, + qvalue.QValueKindTimeTZ, qvalue.QValueKindDate: + value = time.Now() case qvalue.QValueKindNumeric: // create a new big.Rat for numeric data value = big.NewRat(int64(placeHolder), 1) @@ -58,7 +57,6 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue. func generateRecords(t *testing.T, nullable bool, numRows uint32) *model.QRecordBatch { allQValueKinds := []qvalue.QValueKind{ - qvalue.QValueKindFloat16, qvalue.QValueKindFloat32, qvalue.QValueKindFloat64, qvalue.QValueKindInt16, @@ -68,7 +66,11 @@ func generateRecords(t *testing.T, nullable bool, numRows uint32) *model.QRecord // qvalue.QValueKindArray, // qvalue.QValueKindStruct, qvalue.QValueKindString, - qvalue.QValueKindETime, + qvalue.QValueKindTimestamp, + qvalue.QValueKindTimestampTZ, + qvalue.QValueKindTime, + qvalue.QValueKindTimeTZ, + qvalue.QValueKindDate, qvalue.QValueKindNumeric, qvalue.QValueKindBytes, qvalue.QValueKindUUID, diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go index 910ca285e6..ed64a9adb6 100644 --- a/flow/connectors/snowflake/client.go +++ b/flow/connectors/snowflake/client.go @@ -213,15 +213,12 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { Value: ratVal, }, nil } - case qvalue.QValueKindETime: - if v, ok := val.(*time.Time); ok && v != nil { - etimeVal, err := qvalue.NewExtendedTime(*v, qvalue.DateTimeKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate, + qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: + if t, ok := val.(*time.Time); ok && t != nil { return qvalue.QValue{ - Kind: qvalue.QValueKindETime, - Value: etimeVal, + Kind: kind, + Value: *t, }, nil } case qvalue.QValueKindBytes: @@ -234,35 +231,8 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) { return qvalue.QValue{}, fmt.Errorf("[snowflakeclient] unsupported type %T for kind %s", val, kind) } -// databaseTypeNameToQValueKind converts a database type name to a QValueKind. 
-func databaseTypeNameToQValueKind(name string) (qvalue.QValueKind, error) { - switch name { - case "INT": - return qvalue.QValueKindInt32, nil - case "BIGINT": - return qvalue.QValueKindInt64, nil - case "FLOAT": - return qvalue.QValueKindFloat32, nil - case "DOUBLE", "REAL": - return qvalue.QValueKindFloat64, nil - case "VARCHAR", "CHAR", "TEXT": - return qvalue.QValueKindString, nil - case "BOOLEAN": - return qvalue.QValueKindBoolean, nil - case "DATETIME", "TIMESTAMP", "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "TIMESTAMP_TZ": - return qvalue.QValueKindETime, nil - case "BLOB", "BYTEA", "BINARY": - return qvalue.QValueKindBytes, nil - case "FIXED", "NUMBER": - return qvalue.QValueKindNumeric, nil - default: - // If type is unsupported, return an error - return "", fmt.Errorf("unsupported database type name: %s", name) - } -} - func columnTypeToQField(ct *sql.ColumnType) (*model.QField, error) { - qvKind, err := databaseTypeNameToQValueKind(ct.DatabaseTypeName()) + qvKind, err := snowflakeTypeToQValueKind(ct.DatabaseTypeName()) if err != nil { return nil, err } @@ -311,7 +281,8 @@ func (s *SnowflakeClient) ExecuteAndProcessQuery(query string) (*model.QRecordBa values := make([]interface{}, len(columns)) for i := range values { switch qfields[i].Type { - case qvalue.QValueKindETime: + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindTime, + qvalue.QValueKindTimeTZ, qvalue.QValueKindDate: values[i] = new(time.Time) case qvalue.QValueKindInt16: values[i] = new(int16) @@ -375,10 +346,7 @@ func (s *SnowflakeClient) ExecuteAndProcessQuery(query string) (*model.QRecordBa func (s *SnowflakeClient) CreateTable(schema *model.QRecordSchema, schemaName string, tableName string) error { var fields []string for _, field := range schema.Fields { - snowflakeType, err := qValueKindToSnowflakeColTypeString(field.Type) - if err != nil { - return err - } + snowflakeType := qValueKindToSnowflakeType(string(field.Type)) fields = append(fields, fmt.Sprintf(`"%s" %s`, field.Name, snowflakeType)) } @@ -392,24 +360,3 @@ func (s *SnowflakeClient) CreateTable(schema *model.QRecordSchema, schemaName st return nil } - -func qValueKindToSnowflakeColTypeString(val qvalue.QValueKind) (string, error) { - switch val { - case qvalue.QValueKindInt32, qvalue.QValueKindInt64: - return "INT", nil - case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64: - return "FLOAT", nil - case qvalue.QValueKindString: - return "STRING", nil - case qvalue.QValueKindBoolean: - return "BOOLEAN", nil - case qvalue.QValueKindETime: - return "TIMESTAMP_LTZ", nil - case qvalue.QValueKindBytes: - return "BINARY", nil - case qvalue.QValueKindNumeric: - return "NUMERIC(38,32)", nil - default: - return "", fmt.Errorf("unsupported QValueKind: %v", val) - } -} diff --git a/flow/connectors/snowflake/qvalue_convert.go b/flow/connectors/snowflake/qvalue_convert.go new file mode 100644 index 0000000000..558a9b8373 --- /dev/null +++ b/flow/connectors/snowflake/qvalue_convert.go @@ -0,0 +1,84 @@ +package connsnowflake + +import ( + "fmt" + + "github.com/PeerDB-io/peer-flow/model/qvalue" +) + +func qValueKindToSnowflakeType(colType string) string { + switch qvalue.QValueKind(colType) { + case qvalue.QValueKindBoolean: + return "BOOLEAN" + // integer types + case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64: + return "INTEGER" + // decimal types + // The names FLOAT, FLOAT4, and FLOAT8 are for compatibility with other systems + // Snowflake treats all three as 64-bit floating-point numbers. 
+ case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64: + return "FLOAT" + case qvalue.QValueKindNumeric: + return "NUMBER(38, 9)" + // string related STRING , TEXT , NVARCHAR , + // NVARCHAR2 , CHAR VARYING , NCHAR VARYING + //Synonymous with VARCHAR. + case qvalue.QValueKindString: + return "STRING" + // json also is stored as string for now + case qvalue.QValueKindJSON: + return "STRING" + // time related + case qvalue.QValueKindTimestamp: + return "TIMESTAMP_NTZ" + case qvalue.QValueKindTimestampTZ: + return "TIMESTAMP_TZ" + case qvalue.QValueKindTime: + return "TIME" + case qvalue.QValueKindDate: + return "DATE" + // handle INTERVAL types again + // case model.ColumnTypeTimeWithTimeZone, model.ColumnTypeInterval: + // return "STRING" + // bytes + case qvalue.QValueKindBit, qvalue.QValueKindBytes: + return "BINARY" + // rest will be strings + default: + return "STRING" + } +} + +// snowflakeTypeToQValueKind converts a database type name to a QValueKind. +func snowflakeTypeToQValueKind(name string) (qvalue.QValueKind, error) { + switch name { + case "INT": + return qvalue.QValueKindInt32, nil + case "BIGINT": + return qvalue.QValueKindInt64, nil + case "FLOAT": + return qvalue.QValueKindFloat32, nil + case "DOUBLE", "REAL": + return qvalue.QValueKindFloat64, nil + case "VARCHAR", "CHAR", "TEXT": + return qvalue.QValueKindString, nil + case "BOOLEAN": + return qvalue.QValueKindBoolean, nil + // assuming TIMESTAMP is an alias to TIMESTAMP_NTZ, which is the default. + case "DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ": + return qvalue.QValueKindTimestamp, nil + case "TIMESTAMP_TZ": + return qvalue.QValueKindTimestampTZ, nil + case "TIME": + return qvalue.QValueKindTime, nil + case "DATE": + return qvalue.QValueKindDate, nil + case "BLOB", "BYTEA", "BINARY": + return qvalue.QValueKindBytes, nil + case "FIXED", "NUMBER", "DECIMAL", "NUMERIC": + return qvalue.QValueKindNumeric, nil + default: + // If type is unsupported, return an error + return "", fmt.Errorf("unsupported database type name: %s", name) + } +} diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go index 49300adf8c..0ce38070e1 100644 --- a/flow/connectors/snowflake/snowflake.go +++ b/flow/connectors/snowflake/snowflake.go @@ -13,6 +13,7 @@ import ( "github.com/PeerDB-io/peer-flow/connectors/utils" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" + "github.com/PeerDB-io/peer-flow/model/qvalue" util "github.com/PeerDB-io/peer-flow/utils" "github.com/google/uuid" log "github.com/sirupsen/logrus" @@ -92,7 +93,6 @@ type snowflakeRawRecord struct { recordType int matchData string batchID int64 - items map[string]interface{} unchangedToastColumns string } @@ -380,7 +380,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. switch typedRecord := record.(type) { case *model.InsertRecord: // json.Marshal converts bytes in Hex automatically to BASE64 string. - itemsJSON, err := json.Marshal(typedRecord.Items) + itemsJSON, err := typedRecord.Items.ToJSON() if err != nil { return nil, fmt.Errorf("failed to serialize insert record items to JSON: %w", err) } @@ -390,19 +390,18 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. 
uid: uuid.New().String(), timestamp: time.Now().UnixNano(), destinationTableName: typedRecord.DestinationTableName, - data: string(itemsJSON), + data: itemsJSON, recordType: 0, matchData: "", batchID: syncBatchID, - items: typedRecord.Items, unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns), }) case *model.UpdateRecord: - newItemsJSON, err := json.Marshal(typedRecord.NewItems) + newItemsJSON, err := typedRecord.NewItems.ToJSON() if err != nil { return nil, fmt.Errorf("failed to serialize update record new items to JSON: %w", err) } - oldItemsJSON, err := json.Marshal(typedRecord.OldItems) + oldItemsJSON, err := typedRecord.OldItems.ToJSON() if err != nil { return nil, fmt.Errorf("failed to serialize update record old items to JSON: %w", err) } @@ -412,15 +411,14 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. uid: uuid.New().String(), timestamp: time.Now().UnixNano(), destinationTableName: typedRecord.DestinationTableName, - data: string(newItemsJSON), + data: newItemsJSON, recordType: 1, - matchData: string(oldItemsJSON), + matchData: oldItemsJSON, batchID: syncBatchID, - items: typedRecord.NewItems, unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns), }) case *model.DeleteRecord: - itemsJSON, err := json.Marshal(typedRecord.Items) + itemsJSON, err := typedRecord.Items.ToJSON() if err != nil { return nil, fmt.Errorf("failed to serialize delete record items to JSON: %w", err) } @@ -430,11 +428,10 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model. uid: uuid.New().String(), timestamp: time.Now().UnixNano(), destinationTableName: typedRecord.DestinationTableName, - data: string(itemsJSON), + data: itemsJSON, recordType: 2, - matchData: string(itemsJSON), + matchData: itemsJSON, batchID: syncBatchID, - items: typedRecord.Items, unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns), }) default: @@ -669,57 +666,15 @@ func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableId return result, nil } -func getSnowflakeTypeForGenericColumnType(colType string) string { - switch colType { - case model.ColumnTypeBoolean: - return "BOOLEAN" - // integer types - case model.ColumnTypeInt16, model.ColumnTypeInt32, model.ColumnTypeInt64: - return "INTEGER" - // decimal types - // The names FLOAT, FLOAT4, and FLOAT8 are for compatibility with other systems - // Snowflake treats all three as 64-bit floating-point numbers. - case model.ColumnTypeFloat16, model.ColumnTypeFloat32, model.ColumnTypeFloat64: - return "FLOAT" - case model.ColumnTypeNumeric: - return "NUMBER" - // string related STRING , TEXT , NVARCHAR , - // NVARCHAR2 , CHAR VARYING , NCHAR VARYING - //Synonymous with VARCHAR. 
- case model.ColumnTypeString: - return "STRING" - // json also is stored as string for now - case model.ColumnTypeJSON: - return "STRING" - // time related - case model.ColumnTypeTimestamp: - return "TIMESTAMP_NTZ" - case model.ColumnTypeTimeStampWithTimeZone: - return "TIMESTAMP_TZ" - case model.ColumnTypeTime: - return "TIME" - case model.ColumnTypeDate: - return "TIMESTAMP_NTZ" - case model.ColumnTypeTimeWithTimeZone, model.ColumnTypeInterval: - return "STRING" - // bytes - case model.ColumnHexBytes, model.ColumnHexBit: - return "BINARY" - // rest will be strings - default: - return "STRING" - } -} - func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string, sourceTableSchema *protos.TableSchema) string { createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns)) for columnName, genericColumnType := range sourceTableSchema.Columns { if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) { createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("%s %s PRIMARY KEY,", - columnName, getSnowflakeTypeForGenericColumnType(genericColumnType))) + columnName, qValueKindToSnowflakeType(genericColumnType))) } else { createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("%s %s,", columnName, - getSnowflakeTypeForGenericColumnType(genericColumnType))) + qValueKindToSnowflakeType(genericColumnType))) } } return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier, @@ -768,19 +723,20 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(destinationTableId flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns)) for columnName, genericColumnType := range normalizedTableSchema.Columns { - sfType := getSnowflakeTypeForGenericColumnType(genericColumnType) - switch genericColumnType { - case model.ColumnHexBytes: + sfType := qValueKindToSnowflakeType(genericColumnType) + switch qvalue.QValueKind(genericColumnType) { + case qvalue.QValueKindBytes: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:%s) "+ "AS %s,", toVariantColumnName, columnName, columnName)) - case model.ColumnHexBit: + case qvalue.QValueKindBit: // "c2": {"Bytes": "gA==", "Len": 1,"Valid": true} flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:%s:Bytes) "+ "AS %s,", toVariantColumnName, columnName, columnName)) - case model.ColumnTypeTime: - flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ - "Microseconds*1000) "+ - "AS %s,", toVariantColumnName, columnName, columnName)) + // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types + // case model.ColumnTypeTime: + // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+ + // "Microseconds*1000) "+ + // "AS %s,", toVariantColumnName, columnName, columnName)) default: flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:%s AS %s) AS %s,", toVariantColumnName, diff --git a/flow/e2e/bigquery_helper.go b/flow/e2e/bigquery_helper.go index d6cc6e2333..b15bf04e38 100644 --- a/flow/e2e/bigquery_helper.go +++ b/flow/e2e/bigquery_helper.go @@ -10,6 +10,7 @@ import ( "time" "cloud.google.com/go/bigquery" + "cloud.google.com/go/civil" peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery" "github.com/PeerDB-io/peer-flow/generated/protos" "github.com/PeerDB-io/peer-flow/model" @@ -209,15 +210,10 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { 
return qvalue.QValue{Kind: qvalue.QValueKindString, Value: v}, nil case bool: return qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: v}, nil + case civil.Date: + return qvalue.QValue{Kind: qvalue.QValueKindDate, Value: bqValue.(civil.Date).In(time.UTC)}, nil case time.Time: - val, err := qvalue.NewExtendedTime(v, qvalue.DateTimeKindType, "") - if err != nil { - return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err) - } - return qvalue.QValue{ - Kind: qvalue.QValueKindETime, - Value: val, - }, nil + return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil case *big.Rat: return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil case []uint8: @@ -228,38 +224,8 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) { } } -// bqFieldTypeToQValueKind converts a bigquery FieldType to a QValueKind. -func bqFieldTypeToQValueKind(fieldType bigquery.FieldType) (qvalue.QValueKind, error) { - switch fieldType { - case bigquery.StringFieldType: - return qvalue.QValueKindString, nil - case bigquery.BytesFieldType: - return qvalue.QValueKindBytes, nil - case bigquery.IntegerFieldType: - return qvalue.QValueKindInt64, nil - case bigquery.FloatFieldType: - return qvalue.QValueKindFloat64, nil - case bigquery.BooleanFieldType: - return qvalue.QValueKindBoolean, nil - case bigquery.TimestampFieldType: - return qvalue.QValueKindETime, nil - case bigquery.RecordFieldType: - return qvalue.QValueKindStruct, nil - case bigquery.DateFieldType: - return qvalue.QValueKindETime, nil - case bigquery.TimeFieldType: - return qvalue.QValueKindETime, nil - case bigquery.NumericFieldType: - return qvalue.QValueKindNumeric, nil - case bigquery.GeographyFieldType: - return qvalue.QValueKindString, nil - default: - return "", fmt.Errorf("unsupported bigquery field type: %v", fieldType) - } -} - func bqFieldSchemaToQField(fieldSchema *bigquery.FieldSchema) (*model.QField, error) { - qValueKind, err := bqFieldTypeToQValueKind(fieldSchema.Type) + qValueKind, err := peer_bq.BigQueryTypeToQValueKind(fieldSchema.Type) if err != nil { return nil, err } @@ -395,7 +361,7 @@ func qValueKindToBqColTypeString(val qvalue.QValueKind) (string, error) { return "STRING", nil case qvalue.QValueKindBoolean: return "BOOL", nil - case qvalue.QValueKindETime: + case qvalue.QValueKindTimestamp: return "TIMESTAMP", nil case qvalue.QValueKindBytes: return "BYTES", nil diff --git a/flow/e2e/qrep_flow_test.go b/flow/e2e/qrep_flow_test.go index 6b0f7e6ddc..0a210d1656 100644 --- a/flow/e2e/qrep_flow_test.go +++ b/flow/e2e/qrep_flow_test.go @@ -94,10 +94,10 @@ func getOwnersSchema() *model.QRecordSchema { Fields: []*model.QField{ {Name: "id", Type: qvalue.QValueKindString, Nullable: true}, {Name: "card_id", Type: qvalue.QValueKindString, Nullable: true}, - {Name: "from", Type: qvalue.QValueKindETime, Nullable: true}, + {Name: "from", Type: qvalue.QValueKindTimestamp, Nullable: true}, {Name: "price", Type: qvalue.QValueKindNumeric, Nullable: true}, - {Name: "created_at", Type: qvalue.QValueKindETime, Nullable: true}, - {Name: "updated_at", Type: qvalue.QValueKindETime, Nullable: true}, + {Name: "created_at", Type: qvalue.QValueKindTimestamp, Nullable: true}, + {Name: "updated_at", Type: qvalue.QValueKindTimestamp, Nullable: true}, {Name: "transaction_hash", Type: qvalue.QValueKindBytes, Nullable: true}, {Name: "ownerable_type", Type: qvalue.QValueKindString, Nullable: true}, {Name: "ownerable_id", Type: qvalue.QValueKindString, Nullable: true}, @@ -116,9 +116,9 @@ func getOwnersSchema() 
*model.QRecordSchema { {Name: "asset_id", Type: qvalue.QValueKindNumeric, Nullable: true}, {Name: "status", Type: qvalue.QValueKindInt64, Nullable: true}, {Name: "transaction_id", Type: qvalue.QValueKindString, Nullable: true}, - {Name: "settled_at", Type: qvalue.QValueKindETime, Nullable: true}, + {Name: "settled_at", Type: qvalue.QValueKindTimestamp, Nullable: true}, {Name: "reference_id", Type: qvalue.QValueKindString, Nullable: true}, - {Name: "settle_at", Type: qvalue.QValueKindETime, Nullable: true}, + {Name: "settle_at", Type: qvalue.QValueKindTimestamp, Nullable: true}, {Name: "settlement_delay_reason", Type: qvalue.QValueKindInt64, Nullable: true}, }, } diff --git a/flow/go.mod b/flow/go.mod index 8bacdde3ce..e815d6132b 100644 --- a/flow/go.mod +++ b/flow/go.mod @@ -13,7 +13,6 @@ require ( github.com/jackc/pgx/v5 v5.3.1 github.com/jmoiron/sqlx v1.3.5 github.com/joho/godotenv v1.5.1 - github.com/linkedin/goavro v2.1.0+incompatible github.com/linkedin/goavro/v2 v2.12.0 github.com/sirupsen/logrus v1.9.3 github.com/snowflakedb/gosnowflake v1.6.21 diff --git a/flow/model/column.go b/flow/model/column.go deleted file mode 100644 index deff60597f..0000000000 --- a/flow/model/column.go +++ /dev/null @@ -1,23 +0,0 @@ -package model - -// ColumnType is an enum for the column type, which are generic across all connectors. -const ( - ColumnTypeInt16 = "int16" - ColumnTypeInt32 = "int32" - ColumnTypeInt64 = "int64" - ColumnHexBytes = "bytes" - ColumnHexBit = "bit" - ColumnTypeBoolean = "bool" - ColumnTypeFloat16 = "float16" - ColumnTypeFloat32 = "float32" - ColumnTypeFloat64 = "float64" - ColumnTypeString = "string" - ColumnTypeNumeric = "numeric" - ColumnTypeJSON = "json" - ColumnTypeInterval = "interval" - ColumnTypeTimestamp = "timestamp" - ColumnTypeTimeStampWithTimeZone = "timestamptz" - ColumnTypeTime = "time" - ColumnTypeTimeWithTimeZone = "timetz" - ColumnTypeDate = "date" -) diff --git a/flow/model/model.go b/flow/model/model.go index 56d2d09aea..e8e985826a 100644 --- a/flow/model/model.go +++ b/flow/model/model.go @@ -1,9 +1,11 @@ package model import ( + "encoding/json" "time" "github.com/PeerDB-io/peer-flow/generated/protos" + "github.com/PeerDB-io/peer-flow/model/qvalue" ) type PullRecordsRequest struct { @@ -29,7 +31,31 @@ type Record interface { // get table name GetTableName() string // get columns and values for the record - GetItems() map[string]interface{} + GetItems() RecordItems +} + +type RecordItems map[string]qvalue.QValue + +func (r RecordItems) ToJSON() (string, error) { + jsonStruct := make(map[string]interface{}) + for k, v := range r { + var err error + switch v.Kind { + case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate, + qvalue.QValueKindTime, qvalue.QValueKindTimeTZ: + jsonStruct[k], err = v.GoTimeConvert() + if err != nil { + return "", err + } + default: + jsonStruct[k] = v.Value + } + } + jsonBytes, err := json.Marshal(jsonStruct) + if err != nil { + return "", err + } + return string(jsonBytes), nil } type InsertRecord struct { @@ -42,7 +68,7 @@ type InsertRecord struct { // CommitID is the ID of the commit corresponding to this record. CommitID int64 // Items is a map of column name to value. 
-	Items map[string]interface{}
+	Items RecordItems
 	// unchanged toast columns
 	UnchangedToastColumns map[string]bool
 }
@@ -56,7 +82,7 @@ func (r *InsertRecord) GetTableName() string {
 	return r.DestinationTableName
 }

-func (r *InsertRecord) GetItems() map[string]interface{} {
+func (r *InsertRecord) GetItems() RecordItems {
 	return r.Items
 }
@@ -68,9 +94,9 @@ type UpdateRecord struct {
 	// Name of the destination table
 	DestinationTableName string
 	// OldItems is a map of column name to value.
-	OldItems map[string]interface{}
+	OldItems RecordItems
 	// NewItems is a map of column name to value.
-	NewItems map[string]interface{}
+	NewItems RecordItems
 	// unchanged toast columns
 	UnchangedToastColumns map[string]bool
 }
@@ -85,7 +111,7 @@ func (r *UpdateRecord) GetTableName() string {
 	return r.DestinationTableName
 }

-func (r *UpdateRecord) GetItems() map[string]interface{} {
+func (r *UpdateRecord) GetItems() RecordItems {
 	return r.NewItems
 }
@@ -97,7 +123,7 @@ type DeleteRecord struct {
 	// CheckPointID is the ID of the record.
 	CheckPointID int64
 	// Items is a map of column name to value.
-	Items map[string]interface{}
+	Items RecordItems
 	// unchanged toast columns
 	UnchangedToastColumns map[string]bool
 }
@@ -111,7 +137,7 @@ func (r *DeleteRecord) GetTableName() string {
 	return r.SourceTableName
 }

-func (r *DeleteRecord) GetItems() map[string]interface{} {
+func (r *DeleteRecord) GetItems() RecordItems {
 	return r.Items
 }
diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go
index 3675b92fcb..385a16f980 100644
--- a/flow/model/qrecord_batch.go
+++ b/flow/model/qrecord_batch.go
@@ -3,6 +3,7 @@ package model
 import (
 	"fmt"
 	"math/big"
+	"time"

 	"github.com/PeerDB-io/peer-flow/model/qvalue"
 	"github.com/google/uuid"
@@ -73,7 +74,7 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
 	values := make([]interface{}, numEntries)
 	for i, qValue := range record.Entries {
 		switch qValue.Kind {
-		case qvalue.QValueKindFloat16, qvalue.QValueKindFloat32:
+		case qvalue.QValueKindFloat32:
 			v, ok := qValue.Value.(float32)
 			if !ok {
 				src.err = fmt.Errorf("invalid float32 value")
@@ -111,15 +112,23 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
 		case qvalue.QValueKindString:
 			values[i] = qValue.Value.(string)

-		case qvalue.QValueKindETime:
-			et, ok := qValue.Value.(*qvalue.ExtendedTime)
+		case qvalue.QValueKindTimestamp:
+			t, ok := qValue.Value.(time.Time)
 			if !ok {
-				src.err = fmt.Errorf("invalid ExtendedTime value")
+				src.err = fmt.Errorf("invalid Timestamp value")
 				return nil, src.err
 			}
-			timestamp := pgtype.Timestamp{Time: et.Time, Valid: true}
+			timestamp := pgtype.Timestamp{Time: t, Valid: true}
 			values[i] = timestamp

+		case qvalue.QValueKindTimestampTZ:
+			t, ok := qValue.Value.(time.Time)
+			if !ok {
+				src.err = fmt.Errorf("invalid TimestampTZ value")
+				return nil, src.err
+			}
+			timestampTZ := pgtype.Timestamptz{Time: t, Valid: true}
+			values[i] = timestampTZ
 		case qvalue.QValueKindUUID:
 			if qValue.Value == nil {
 				values[i] = nil
diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go
index a391393819..8dcaca542f 100644
--- a/flow/model/qvalue/avro_converter.go
+++ b/flow/model/qvalue/avro_converter.go
@@ -4,9 +4,10 @@ import (
 	"errors"
 	"fmt"
 	"math/big"
+	"time"

 	"github.com/google/uuid"
-	"github.com/linkedin/goavro"
+	"github.com/linkedin/goavro/v2"
 )

 // QValueKindAvroSchema defines a structure for representing Avro schemas.
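A short usage sketch for the new RecordItems type and its ToJSON method from flow/model/model.go; the column names and sample values are invented, the API is the one introduced above:

    package main

    import (
        "fmt"
        "time"

        "github.com/PeerDB-io/peer-flow/model"
        "github.com/PeerDB-io/peer-flow/model/qvalue"
    )

    func main() {
        items := model.RecordItems{
            "id":         qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: int64(42)},
            "created_at": qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: time.Date(2023, 7, 1, 12, 30, 0, 0, time.UTC)},
        }

        // Time-like kinds are routed through GoTimeConvert, so the timestamp
        // is rendered with the "2006-01-02 15:04:05.999999" layout rather
        // than as an RFC 3339 string or a Unix epoch.
        j, err := items.ToJSON()
        if err != nil {
            panic(err)
        }
        fmt.Println(j) // {"created_at":"2023-07-01 12:30:00","id":42}
    }
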
@@ -41,7 +42,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr return &QValueKindAvroSchema{ AvroLogicalSchema: "long", }, nil - case QValueKindFloat16, QValueKindFloat32: + case QValueKindFloat32: return &QValueKindAvroSchema{ AvroLogicalSchema: "float", }, nil @@ -66,11 +67,10 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr "scale": 9, }, }, nil - case QValueKindETime: + case QValueKindTime, QValueKindTimeTZ, QValueKindDate, QValueKindTimestamp, QValueKindTimestampTZ: return &QValueKindAvroSchema{ AvroLogicalSchema: map[string]string{ - "type": "long", - "logicalType": "timestamp-micros", + "type": "string", }, }, nil case QValueKindJSON, QValueKindArray, QValueKindStruct, QValueKindBit: @@ -98,21 +98,33 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { switch c.Value.Kind { case QValueKindInvalid: return nil, fmt.Errorf("invalid QValueKind: %v", c.Value) - case QValueKindETime: - t, err := c.processExtendedTime() + case QValueKindTime, QValueKindTimeTZ, QValueKindDate, QValueKindTimestamp, QValueKindTimestampTZ: + t, err := c.processGoTime() if err != nil || t == nil { return t, err } + if c.TargetDWH == QDWHTypeSnowflake { + if c.Nullable { + return c.processNullableUnion("string", t.(string)) + } else { + return t.(string), nil + } + } if c.Nullable { - return goavro.Union("long.timestamp-micros", t), nil + return goavro.Union("long.timestamp-micros", t.(int64)), nil } else { - return t, nil + return t.(int64), nil } case QValueKindString: return c.processNullableUnion("string", c.Value.Value) - case QValueKindFloat16, QValueKindFloat32: + case QValueKindFloat32: return c.processNullableUnion("float", c.Value.Value) case QValueKindFloat64: + if c.TargetDWH == QDWHTypeSnowflake || c.TargetDWH == QDWHTypeBigQuery { + if f32Val, ok := c.Value.Value.(float32); ok { + return c.processNullableUnion("double", float64(f32Val)) + } + } return c.processNullableUnion("double", c.Value.Value) case QValueKindInt16, QValueKindInt32, QValueKindInt64: return c.processNullableUnion("long", c.Value.Value) @@ -139,32 +151,23 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) { } } -func (c *QValueAvroConverter) processExtendedTime() (interface{}, error) { +func (c *QValueAvroConverter) processGoTime() (interface{}, error) { if c.Value.Value == nil && c.Nullable { return nil, nil } - et, ok := c.Value.Value.(*ExtendedTime) + t, ok := c.Value.Value.(time.Time) if !ok { - return nil, fmt.Errorf("invalid ExtendedTime value") + return nil, fmt.Errorf("invalid Time value") } - if et == nil { - return nil, nil - } - - switch et.NestedKind.Type { - case DateTimeKindType, DateKindType, TimeKindType: - ret := et.Time.UnixMicro() - // Snowflake has issues with avro timestamp types - // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file - if c.TargetDWH == QDWHTypeSnowflake { - ret = ret / 1000000 - } - return ret, nil - default: - return nil, fmt.Errorf("unsupported ExtendedTimeKindType: %s", et.NestedKind.Type) + ret := t.UnixMicro() + // Snowflake has issues with avro timestamp types, returning as string form of the int64 + // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file + if c.TargetDWH == QDWHTypeSnowflake { + return fmt.Sprint(ret), nil } + return ret, nil } func (c *QValueAvroConverter) processNullableUnion( diff --git a/flow/model/qvalue/etime.go b/flow/model/qvalue/etime.go deleted 
file mode 100644 index df5f8bbeb2..0000000000 --- a/flow/model/qvalue/etime.go +++ /dev/null @@ -1,72 +0,0 @@ -package qvalue - -import ( - "errors" - "time" -) - -type ExtendedTimeKindType string - -const ( - DateTimeKindType ExtendedTimeKindType = "datetime" - DateKindType ExtendedTimeKindType = "date" - TimeKindType ExtendedTimeKindType = "time" -) - -type ExtendedTime struct { - time.Time - NestedKind NestedKind -} - -type NestedKind struct { - Type ExtendedTimeKindType - Format string -} - -var ( - DateTime = NestedKind{ - Type: DateTimeKindType, - Format: time.RFC3339Nano, - } - - Date = NestedKind{ - Type: DateKindType, - Format: "2006-01-02", - } - - Time = NestedKind{ - Type: TimeKindType, - Format: "15:04:05.999999", - } -) - -func NewExtendedTime( - t time.Time, - kindType ExtendedTimeKindType, - originalFormat string, -) (*ExtendedTime, error) { - var nk NestedKind - - switch kindType { - case DateTimeKindType: - nk = DateTime - case DateKindType: - nk = Date - case TimeKindType: - nk = Time - default: - return nil, errors.New("invalid ExtendedTimeKindType") - } - - if originalFormat != "" { - nk = NestedKind{ - Type: nk.Type, - Format: originalFormat, - } - } - - return &ExtendedTime{ - Time: t, - NestedKind: nk, - }, nil -} diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index 032bb63153..be1545d30b 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -3,21 +3,24 @@ package qvalue type QValueKind string const ( - QValueKindInvalid QValueKind = "invalid" - QValueKindFloat16 QValueKind = "float16" - QValueKindFloat32 QValueKind = "float32" - QValueKindFloat64 QValueKind = "float64" - QValueKindInt16 QValueKind = "int16" - QValueKindInt32 QValueKind = "int32" - QValueKindInt64 QValueKind = "int64" - QValueKindBoolean QValueKind = "bool" - QValueKindArray QValueKind = "array" - QValueKindStruct QValueKind = "struct" - QValueKindString QValueKind = "string" - QValueKindETime QValueKind = "extended_time" - QValueKindNumeric QValueKind = "numeric" - QValueKindBytes QValueKind = "bytes" - QValueKindUUID QValueKind = "uuid" - QValueKindJSON QValueKind = "json" - QValueKindBit QValueKind = "bit" + QValueKindInvalid QValueKind = "invalid" + QValueKindFloat32 QValueKind = "float32" + QValueKindFloat64 QValueKind = "float64" + QValueKindInt16 QValueKind = "int16" + QValueKindInt32 QValueKind = "int32" + QValueKindInt64 QValueKind = "int64" + QValueKindBoolean QValueKind = "bool" + QValueKindArray QValueKind = "array" + QValueKindStruct QValueKind = "struct" + QValueKindString QValueKind = "string" + QValueKindTimestamp QValueKind = "timestamp" + QValueKindTimestampTZ QValueKind = "timestamptz" + QValueKindDate QValueKind = "date" + QValueKindTime QValueKind = "time" + QValueKindTimeTZ QValueKind = "timetz" + QValueKindNumeric QValueKind = "numeric" + QValueKindBytes QValueKind = "bytes" + QValueKindUUID QValueKind = "uuid" + QValueKindJSON QValueKind = "json" + QValueKindBit QValueKind = "bit" ) diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index 5be793ae9e..cf0b6f11bc 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -3,9 +3,11 @@ package qvalue import ( "bytes" "encoding/json" + "fmt" "math/big" "reflect" "strconv" + "time" "github.com/google/uuid" ) @@ -19,8 +21,6 @@ func (q *QValue) Equals(other *QValue) bool { switch q.Kind { case QValueKindInvalid: return false // both are invalid we always return false - case QValueKindFloat16: - return compareFloat32(q.Value, other.Value) case 
QValueKindFloat32:
 		return compareFloat32(q.Value, other.Value)
 	case QValueKindFloat64:
@@ -39,8 +39,10 @@ func (q *QValue) Equals(other *QValue) bool {
 		return compareStruct(q.Value, other.Value)
 	case QValueKindString:
 		return compareString(q.Value, other.Value)
-	case QValueKindETime:
-		return compareETime(q.Value, other.Value)
+	// all internally represented as a Golang time.Time
+	case QValueKindTime, QValueKindTimeTZ, QValueKindDate,
+		QValueKindTimestamp, QValueKindTimestampTZ:
+		return compareGoTime(q.Value, other.Value)
 	case QValueKindNumeric:
 		return compareNumeric(q.Value, other.Value)
 	case QValueKindBytes:
@@ -56,6 +58,22 @@ func (q *QValue) Equals(other *QValue) bool {
 	return false
 }

+func (q *QValue) GoTimeConvert() (string, error) {
+	if q.Kind == QValueKindTime {
+		return q.Value.(time.Time).Format("15:04:05.999999"), nil
+	} else if q.Kind == QValueKindTimeTZ {
+		return q.Value.(time.Time).Format("15:04:05.999999-0700"), nil
+	} else if q.Kind == QValueKindDate {
+		return q.Value.(time.Time).Format("2006-01-02"), nil
+	} else if q.Kind == QValueKindTimestamp {
+		return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999"), nil
+	} else if q.Kind == QValueKindTimestampTZ {
+		return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999-0700"), nil
+	} else {
+		return "", fmt.Errorf("unsupported QValueKind: %s", q.Kind)
+	}
+}
+
 func compareInt16(value1, value2 interface{}) bool {
 	int1, ok1 := getInt16(value1)
 	int2, ok2 := getInt16(value2)
@@ -86,9 +104,9 @@ func compareFloat64(value1, value2 interface{}) bool {
 	return ok1 && ok2 && float1 == float2
 }

-func compareETime(value1, value2 interface{}) bool {
-	et1, ok1 := value1.(*ExtendedTime)
-	et2, ok2 := value2.(*ExtendedTime)
+func compareGoTime(value1, value2 interface{}) bool {
+	et1, ok1 := value1.(time.Time)
+	et2, ok2 := value2.(time.Time)

 	if !ok1 || !ok2 {
 		return false
@@ -96,8 +114,8 @@
 	}
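Since GoTimeConvert above is plain Go reference-time formatting, here is what each kind renders for one invented sample instant; the layout strings are copied from the patch, everything else is illustrative (the trailing .999999 means zero fractional seconds are trimmed):

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        t := time.Date(2023, 7, 1, 12, 30, 0, 250000000, time.FixedZone("IST", 5*3600+30*60))

        fmt.Println(t.Format("15:04:05.999999"))                 // 12:30:00.25 (QValueKindTime)
        fmt.Println(t.Format("15:04:05.999999-0700"))            // 12:30:00.25+0530 (QValueKindTimeTZ)
        fmt.Println(t.Format("2006-01-02"))                      // 2023-07-01 (QValueKindDate)
        fmt.Println(t.Format("2006-01-02 15:04:05.999999"))      // 2023-07-01 12:30:00.25 (QValueKindTimestamp)
        fmt.Println(t.Format("2006-01-02 15:04:05.999999-0700")) // 2023-07-01 12:30:00.25+0530 (QValueKindTimestampTZ)
    }
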
 	// TODO: this is a hack, we should be comparing the actual time values
 	// currently this is only used for testing so that is OK.
-	t1 := et1.Time.UnixMilli() / 1000
-	t2 := et2.Time.UnixMilli() / 1000
+	t1 := et1.UnixMicro()
+	t2 := et2.UnixMicro()
 	return t1 == t2
 }
diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go
index b5db2b8968..0131bde450 100644
--- a/flow/workflows/setup_flow.go
+++ b/flow/workflows/setup_flow.go
@@ -203,11 +203,12 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables(
 		if err := fSrcTableSchema.Get(ctx, &srcTableSchema); err != nil {
 			return nil, fmt.Errorf("failed to fetch schema for source table %s: %w", srcTableName, err)
 		}
-		s.logger.Info("fetched schema for table %s for peer flow %s ", srcTableSchema, s.PeerFlowName)
+		s.logger.Info(fmt.Sprintf("fetched schema for table %s for peer flow %s", srcTableName, s.PeerFlowName))

 		tableNameSchemaMapping[flowConnectionConfigs.TableNameMapping[srcTableName]] = srcTableSchema
-		s.logger.Info("setting up normalized table for table %s for peer flow - ", srcTableSchema, s.PeerFlowName)
+		s.logger.Info(fmt.Sprintf("setting up normalized table for table %s for peer flow - %s",
+			srcTableName, s.PeerFlowName))

 		// now setup the normalized tables on the destination peer
 		setupConfig := &protos.SetupNormalizedTableInput{
diff --git a/images/banner.jpg b/images/banner.jpg
new file mode 100755
index 0000000000..6037c9e7bb
Binary files /dev/null and b/images/banner.jpg differ
diff --git a/images/logo-light-transparent_copy_2.png b/images/logo-light-transparent_copy_2.png
deleted file mode 100644
index 81b2021dca..0000000000
Binary files a/images/logo-light-transparent_copy_2.png and /dev/null differ
diff --git a/nexus/Cargo.lock b/nexus/Cargo.lock
index 4f41a9e804..7e78f44185 100644
--- a/nexus/Cargo.lock
+++ b/nexus/Cargo.lock
@@ -34,6 +34,7 @@ dependencies = [
 "async-trait",
 "catalog",
 "flow-rs",
+ "lazy_static",
 "pem 1.1.1",
 "pt",
 "serde_json",
diff --git a/nexus/analyzer/Cargo.toml b/nexus/analyzer/Cargo.toml
index 626753becc..1680d9dcb9 100644
--- a/nexus/analyzer/Cargo.toml
+++ b/nexus/analyzer/Cargo.toml
@@ -10,6 +10,7 @@ anyhow = "1.0"
 async-trait = "0.1"
 catalog = { path = "../catalog" }
 flow-rs = { path = "../flow-rs" }
+lazy_static = "1.4"
 pem = "1.1.0"
 pt = { path = "../pt" }
 sqlparser = { path = "../sqlparser-rs", features = ["visitor"] }
diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs
index 66b09cec71..3cb13e3f04 100644
--- a/nexus/analyzer/src/lib.rs
+++ b/nexus/analyzer/src/lib.rs
@@ -11,12 +11,11 @@ use flow_rs::{FlowJob, FlowJobTableMapping, QRepFlowJob};
 use pt::peers::{
     peer::Config, BigqueryConfig, DbType, MongoConfig, Peer, PostgresConfig, SnowflakeConfig,
 };
-use serde_json::Number;
+use qrep::process_options;
+use sqlparser::ast::CreateMirror::{Select, CDC};
 use sqlparser::ast::{visit_relations, visit_statements, FetchDirection, SqlOption, Statement};
-use sqlparser::ast::{
-    CreateMirror::{Select, CDC},
-    Value,
-};
+
+mod qrep;

 pub trait StatementAnalyzer {
     type Output;
@@ -107,76 +106,6 @@ pub enum PeerDDL {
     },
 }

-impl PeerDDLAnalyzer {
-    fn parse_string_for_options(
-        raw_options: &HashMap<&str, &Value>,
-        processed_options: &mut HashMap<String, serde_json::Value>,
-        key: &str,
-        is_required: bool,
-        accepted_values: Option<&[&str]>,
-    ) -> anyhow::Result<()> {
-        if raw_options.get(key).is_none() {
-            if is_required {
-                anyhow::bail!("{} is required", key);
-            } else {
-                Ok(())
-            }
-        } else {
-            let raw_value = *raw_options.get(key).unwrap();
-            match raw_value {
-                sqlparser::ast::Value::SingleQuotedString(str) => {
-                    if accepted_values.is_some() {
-                        let accepted_values = accepted_values.unwrap();
-                        if !accepted_values.contains(&str.as_str()) {
-                            anyhow::bail!("{} must be one of {:?}", key, accepted_values);
-                        }
-                    }
-                    processed_options
-                        .insert(key.to_string(), serde_json::Value::String(str.clone()));
-                    Ok(())
-                }
-                _ => {
-                    anyhow::bail!("invalid value for {}", key);
-                }
-            }
-        }
-    }
-
-    fn parse_number_for_options(
-        raw_options: &HashMap<&str, &Value>,
-        processed_options: &mut HashMap<String, serde_json::Value>,
-        key: &str,
-        min_value: u32,
-        default_value: u32,
-    ) -> anyhow::Result<()> {
-        if raw_options.get(key).is_none() {
-            processed_options.insert(
-                key.to_string(),
-                serde_json::Value::Number(Number::from(default_value)),
-            );
-            Ok(())
-        } else {
-            let raw_value = *raw_options.get(key).unwrap();
-            match raw_value {
-                sqlparser::ast::Value::Number(str, _) => {
-                    let value = str.parse::<u32>()?;
-                    if value < min_value {
-                        anyhow::bail!("{} must be greater than {}", key, min_value - 1);
-                    }
-                    processed_options.insert(
-                        key.to_string(),
-                        serde_json::Value::Number(Number::from(value)),
-                    );
-                    Ok(())
-                }
-                _ => {
-                    anyhow::bail!("invalid value for {}", key);
-                }
-            }
-        }
-    }
-}
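Everything deleted above is replaced by the declarative option table in nexus/analyzer/src/qrep.rs, which follows below. As a rough illustration for readers more at home in the Go half of the codebase, the same table-driven validation pattern might look like this; all names here are invented and none of this is PeerDB code:

    package main

    import "fmt"

    // optionSpec mirrors the shape of the Rust QRepOptionType enum:
    // each option declares its own validation rules instead of being
    // parsed by bespoke code per option.
    type optionSpec struct {
        name       string
        required   bool
        defaultVal any
        accepted   []string // nil means any string is fine
    }

    var qrepOptions = []optionSpec{
        {name: "destination_table_name", required: true},
        {name: "mode", defaultVal: "append", accepted: []string{"append", "upsert"}},
        {name: "sync_data_format", defaultVal: "default", accepted: []string{"default", "avro"}},
    }

    func processOptions(raw map[string]string) (map[string]any, error) {
        out := make(map[string]any)
        for _, spec := range qrepOptions {
            v, ok := raw[spec.name]
            if !ok {
                if spec.required {
                    return nil, fmt.Errorf("%s is required", spec.name)
                }
                if spec.defaultVal != nil {
                    out[spec.name] = spec.defaultVal
                }
                continue
            }
            if spec.accepted != nil && !contains(spec.accepted, v) {
                return nil, fmt.Errorf("%s must be one of %v", spec.name, spec.accepted)
            }
            out[spec.name] = v
        }
        return out, nil
    }

    func contains(xs []string, s string) bool {
        for _, x := range xs {
            if x == s {
                return true
            }
        }
        return false
    }
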
 impl StatementAnalyzer for PeerDDLAnalyzer {
     type Output = Option<PeerDDL>;
@@ -234,119 +163,7 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
                     raw_options.insert(&option.name.value as &str, &option.value);
                 }

-                let mut processed_options = HashMap::new();
-
-                // processing options that are REQUIRED and take a string value.
-                for key in [
-                    "destination_table_name",
-                    "watermark_column",
-                    "watermark_table_name",
-                ] {
-                    PeerDDLAnalyzer::parse_string_for_options(
-                        &raw_options,
-                        &mut processed_options,
-                        key,
-                        true,
-                        None,
-                    )?;
-                }
-                PeerDDLAnalyzer::parse_string_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "mode",
-                    true,
-                    Some(&["append", "upsert"]),
-                )?;
-                // processing options that are OPTIONAL and take a string value.
-                PeerDDLAnalyzer::parse_string_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "unique_key_columns",
-                    false,
-                    None,
-                )?;
-                PeerDDLAnalyzer::parse_string_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "sync_data_format",
-                    false,
-                    Some(&["default", "avro"]),
-                )?;
-                PeerDDLAnalyzer::parse_string_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "staging_path",
-                    false,
-                    None,
-                )?;
-                // processing options that are OPTIONAL and take a number value which a minimum and default value.
-                PeerDDLAnalyzer::parse_number_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "parallelism",
-                    1,
-                    2,
-                )?;
-                PeerDDLAnalyzer::parse_number_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "refresh_interval",
-                    10,
-                    10,
-                )?;
-                PeerDDLAnalyzer::parse_number_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "batch_size_int",
-                    1,
-                    10000,
-                )?;
-                PeerDDLAnalyzer::parse_number_for_options(
-                    &raw_options,
-                    &mut processed_options,
-                    "batch_duration_timestamp",
-                    1,
-                    60,
-                )?;
-
-                if !processed_options.contains_key("sync_data_format") {
-                    processed_options.insert(
-                        "sync_data_format".to_string(),
-                        serde_json::Value::String("default".to_string()),
-                    );
-                }
-
-                // unique_key_columns should only be specified if mode is upsert
-                if processed_options.contains_key("unique_key_columns")
-                    ^ (processed_options.get("mode").unwrap() == "upsert")
-                {
-                    if processed_options.get("mode").unwrap() == "upsert" {
-                        anyhow::bail!(
-                            "unique_key_columns should be specified if mode is upsert"
-                        );
-                    } else {
-                        anyhow::bail!(
-                            "mode should be upsert if unique_key_columns is specified"
-                        );
-                    }
-                }
-
-                if processed_options.contains_key("unique_key_columns") {
-                    processed_options.insert(
-                        "unique_key_columns".to_string(),
-                        serde_json::Value::Array(
-                            processed_options
-                                .get("unique_key_columns")
-                                .unwrap()
-                                .as_str()
-                                .unwrap()
-                                .split(',')
-                                .map(|s| serde_json::Value::String(s.to_string()))
-                                .collect(),
-                        ),
-                    );
-                }
-
+                let processed_options = process_options(raw_options)?;
                 let qrep_flow_job = QRepFlowJob {
                     name: select.mirror_name.to_string().to_lowercase(),
                     source_peer: select.source_peer.to_string().to_lowercase(),
diff --git a/nexus/analyzer/src/qrep.rs b/nexus/analyzer/src/qrep.rs
new file mode 100644
index 0000000000..46653cb9a0
--- /dev/null
+++ b/nexus/analyzer/src/qrep.rs
@@ -0,0 +1,167 @@
+use std::collections::HashMap;
+
+use serde_json::Value;
+use sqlparser::ast::Value as SqlValue;
+
+enum QRepOptionType {
+    String {
+        name: &'static str,
+        default_val: Option<&'static str>,
+        required: bool,
+        accepted_values: Option<Vec<&'static str>>,
+    },
+    Int {
+        name: &'static str,
+        min_value: Option<u32>,
+        default_value: u32,
+        required: bool,
+    },
+    StringArray {
+        name: &'static str,
+    },
+}
+
+lazy_static::lazy_static! {
+    static ref QREP_OPTIONS: Vec<QRepOptionType> = {
+        vec![
+            QRepOptionType::String {
+                name: "destination_table_name",
+                default_val: None,
+                required: true,
+                accepted_values: None,
+            },
+            QRepOptionType::String {
+                name: "watermark_column",
+                default_val: None,
+                required: true,
+                accepted_values: None,
+            },
+            QRepOptionType::String {
+                name: "watermark_table_name",
+                default_val: None,
+                required: true,
+                accepted_values: None,
+            },
+            QRepOptionType::String {
+                name: "mode",
+                default_val: Some("append"),
+                required: false,
+                accepted_values: Some(vec!["upsert", "append"]),
+            },
+            QRepOptionType::StringArray {
+                name: "unique_key_columns",
+            },
+            QRepOptionType::String {
+                name: "sync_data_format",
+                default_val: Some("default"),
+                required: false,
+                accepted_values: Some(vec!["default", "avro"]),
+            },
+            QRepOptionType::String {
+                name: "staging_path",
+                default_val: None,
+                required: false,
+                accepted_values: None,
+            },
+            QRepOptionType::Int {
+                name: "parallelism",
+                min_value: Some(1),
+                default_value: 2,
+                required: false,
+            },
+            QRepOptionType::Int {
+                name: "refresh_interval",
+                min_value: Some(10),
+                default_value: 10,
+                required: false,
+            },
+            QRepOptionType::Int {
+                name: "batch_size_int",
+                min_value: Some(1),
+                default_value: 1000,
+                required: false,
+            },
+            QRepOptionType::Int {
+                name: "batch_duration_timestamp",
+                min_value: Some(1),
+                default_value: 60,
+                required: false,
+            },
+        ]
+    };
+}
+
+pub fn process_options(
+    raw_opts: HashMap<&str, &SqlValue>,
+) -> anyhow::Result<HashMap<String, Value>> {
+    let mut opts: HashMap<String, Value> = HashMap::new();
+
+    for opt_type in &*QREP_OPTIONS {
+        match opt_type {
+            QRepOptionType::String {
+                name,
+                default_val,
+                required,
+                accepted_values,
+            } => {
+                if let Some(raw_value) = raw_opts.get(*name) {
+                    if let SqlValue::SingleQuotedString(str) = raw_value {
+                        if let Some(values) = accepted_values {
+                            if !values.contains(&str.as_str()) {
+                                anyhow::bail!("{} must be one of {:?}", name, values);
+                            }
+                        }
+                        opts.insert((*name).to_string(), Value::String(str.clone()));
+                    } else {
+                        anyhow::bail!("Invalid value for {}", name);
+                    }
+                } else if *required {
+                    anyhow::bail!("{} is required", name);
+                } else if let Some(default) = default_val {
+                    opts.insert((*name).to_string(), Value::String(default.to_string()));
+                }
+            }
+            QRepOptionType::Int {
+                name,
+                min_value,
+                default_value,
+                required,
+            } => {
+                if let Some(raw_value) = raw_opts.get(*name) {
+                    if let SqlValue::Number(num_str, _) = raw_value {
+                        let num = num_str.parse::<u32>()?;
+                        if let Some(min) = min_value {
+                            if num < *min {
+                                anyhow::bail!("{} must be greater than {}", name, min);
+                            }
+                        }
+                        opts.insert((*name).to_string(), Value::Number(num.into()));
+                    } else {
+                        anyhow::bail!("Invalid value for {}", name);
+                    }
+                } else if *required {
+                    anyhow::bail!("{} is required", name);
+                } else {
+                    let v = *default_value;
+                    opts.insert((*name).to_string(), Value::Number(v.into()));
+                }
+            }
+            QRepOptionType::StringArray { name } => {
+                // read it as a string and split on comma
+                if let Some(raw_value) = raw_opts.get(*name) {
+                    if let SqlValue::SingleQuotedString(str) = raw_value {
+                        let values: Vec<Value> = str
+                            .split(',')
+                            .map(|s| Value::String(s.trim().to_string()))
+                            .collect();
+                        opts.insert((*name).to_string(), Value::Array(values));
+                    } else {
+                        anyhow::bail!("Invalid value for {}", name);
+                    }
+                }
+            }
+        }
+    }

+    Ok(opts)
+}
diff --git a/nexus/flow-rs/src/lib.rs b/nexus/flow-rs/src/lib.rs
index 4906d05beb..70bddd9cd9 100644
--- a/nexus/flow-rs/src/lib.rs
+++ b/nexus/flow-rs/src/lib.rs
@@ -31,7 +31,7 @@ pub struct QRepFlowJob {
     pub target_peer: String,
     pub query_string: String,
     pub flow_options: HashMap<String, Value>,
-    pub description: String
+    pub description: String,
 }

 impl FlowHandler {
diff --git a/nexus/peer-bigquery/src/ast.rs b/nexus/peer-bigquery/src/ast.rs
index 6a72f95d75..d6b47fe631 100644
--- a/nexus/peer-bigquery/src/ast.rs
+++ b/nexus/peer-bigquery/src/ast.rs
@@ -3,7 +3,7 @@ use std::ops::ControlFlow;

 use sqlparser::ast::Value::Number;
 use sqlparser::ast::{
-    visit_expressions_mut, visit_function_arg, visit_function_arg_mut, visit_relations_mut,
+    visit_expressions_mut, visit_function_arg_mut, visit_relations_mut,
     visit_setexpr_mut, Array, BinaryOperator, DataType, DateTimeField, Expr, Function,
     FunctionArg, FunctionArgExpr, Ident, ObjectName, Query, SetExpr, SetOperator,
     SetQuantifier, TimezoneInfo,
 };
diff --git a/nexus/peer-bigquery/src/lib.rs b/nexus/peer-bigquery/src/lib.rs
index 4534ec5564..4ed659f750 100644
--- a/nexus/peer-bigquery/src/lib.rs
+++ b/nexus/peer-bigquery/src/lib.rs
@@ -11,7 +11,11 @@ use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, SchemaRef};
 use pgerror::PgError;
 use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
 use pt::peers::BigqueryConfig;
-use sqlparser::ast::{CloseCursor, Expr, FetchDirection, Statement, Value};
+use sqlparser::{
+    ast::{CloseCursor, Expr, FetchDirection, Statement, Value},
+    dialect::GenericDialect,
+    parser,
+};
 use stream::{BqRecordStream, BqSchema};

 mod ast;
@@ -41,7 +45,7 @@ pub async fn bq_client_from_config(config: BigqueryConfig) -> anyhow::Result
+    async fn is_connection_valid(&self) -> anyhow::Result<bool> {
+        let sql = "SELECT 1;";
+        let test_stmt = parser::Parser::parse_sql(&GenericDialect {}, sql).unwrap();
+        let _ = self.execute(&test_stmt[0]).await?;
+        Ok(true)
+    }
 }
diff --git a/nexus/peer-cursor/src/lib.rs b/nexus/peer-cursor/src/lib.rs
index 37065cf163..7d2525a7df 100644
--- a/nexus/peer-cursor/src/lib.rs
+++ b/nexus/peer-cursor/src/lib.rs
@@ -51,4 +51,6 @@ pub trait QueryExecutor: Send + Sync {
     async fn execute(&self, stmt: &Statement) -> PgWireResult<QueryOutput>;

     async fn describe(&self, stmt: &Statement) -> PgWireResult<Option<SchemaRef>>;
+
+    async fn is_connection_valid(&self) -> anyhow::Result<bool>;
 }
diff --git a/nexus/peer-cursor/src/util.rs b/nexus/peer-cursor/src/util.rs
index 2bf7d0b7ea..6b134d3ef8 100644
--- a/nexus/peer-cursor/src/util.rs
+++ b/nexus/peer-cursor/src/util.rs
@@ -6,9 +6,9 @@ use pgwire::{
     api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response},
     error::{PgWireError, PgWireResult},
 };
-use value::{array::ArrayValue, Value};
+use value::Value;

-use crate::{Record, Records, SchemaRef, SendableStream};
+use crate::{Records, SchemaRef, SendableStream};

 fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
     match value {
diff --git a/nexus/peer-postgres/src/ast.rs b/nexus/peer-postgres/src/ast.rs
index e6ec466cbf..3e33a535d8 100644
--- a/nexus/peer-postgres/src/ast.rs
+++ b/nexus/peer-postgres/src/ast.rs
@@ -1,9 +1,6 @@
 use std::ops::ControlFlow;

-use sqlparser::ast::{
-    visit_relations_mut, visit_statements_mut, ObjectName, ObjectType, Query, Statement,
-    TableFactor, TableWithJoins,
-};
+use sqlparser::ast::{visit_relations_mut, visit_statements_mut, ObjectType, Query, Statement};

 #[derive(Default)]
 pub struct PostgresAst {
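The new is_connection_valid hook on the QueryExecutor trait lets the server probe a peer with a throwaway SELECT 1 before its configuration is persisted (see the CREATE PEER handler in nexus/server/src/main.rs further down). A sketch of the same probe pattern in Go with database/sql, purely for illustration and not part of this patch:

    package main

    import (
        "context"
        "database/sql"
        "fmt"
        "time"
    )

    // validateConnection runs a cheap query so that a bad DSN or bad
    // credentials fail at CREATE PEER time instead of at first real query.
    func validateConnection(ctx context.Context, db *sql.DB) error {
        ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
        defer cancel()

        var one int
        if err := db.QueryRowContext(ctx, "SELECT 1").Scan(&one); err != nil {
            return fmt.Errorf("invalid configuration: %w", err)
        }
        return nil
    }
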
diff --git a/nexus/peer-postgres/src/lib.rs b/nexus/peer-postgres/src/lib.rs
index 32e4eb7817..35fb5ad38f 100644
--- a/nexus/peer-postgres/src/lib.rs
+++ b/nexus/peer-postgres/src/lib.rs
@@ -16,7 +16,7 @@ mod stream;

 // PostgresQueryExecutor is a QueryExecutor that uses a Postgres database as its
 // backing store.
 pub struct PostgresQueryExecutor {
-    _config: PostgresConfig,
+    config: PostgresConfig,
     peername: Option<String>,
     client: Box<Client>,
 }
@@ -59,7 +59,7 @@ impl PostgresQueryExecutor {
         })?;

         Ok(Self {
-            _config: config.clone(),
+            config: config.clone(),
             peername,
             client: Box::new(client),
         })
@@ -173,4 +173,9 @@ impl QueryExecutor for PostgresQueryExecutor {
             _ => Ok(None),
         }
     }
+
+    async fn is_connection_valid(&self) -> anyhow::Result<bool> {
+        let _ = PostgresQueryExecutor::new(None, &self.config).await?;
+        Ok(true)
+    }
 }
diff --git a/nexus/peer-snowflake/src/auth.rs b/nexus/peer-snowflake/src/auth.rs
index 812bec8f0e..8001e56821 100644
--- a/nexus/peer-snowflake/src/auth.rs
+++ b/nexus/peer-snowflake/src/auth.rs
@@ -3,6 +3,7 @@ use std::{
     time::{SystemTime, UNIX_EPOCH},
 };

+use anyhow::Context;
 use base64::encode as base64_encode;
 use jsonwebtoken::{encode as jwt_encode, Algorithm, EncodingKey, Header};
 use pkcs1::EncodeRsaPrivateKey;
@@ -44,13 +45,14 @@ impl SnowflakeAuth {
         private_key: String,
         refresh_threshold: u64,
         expiry_threshold: u64,
-    ) -> Self {
+    ) -> anyhow::Result<Self> {
+        let pkey = DecodePrivateKey::from_pkcs8_pem(&private_key).context("Invalid private key")?;
         let mut snowflake_auth: SnowflakeAuth = SnowflakeAuth {
             // moved normalized_account_id above account_id to satisfy the borrow checker.
             normalized_account_id: SnowflakeAuth::normalize_account_identifier(&account_id),
             account_id,
             username,
-            private_key: DecodePrivateKey::from_pkcs8_pem(&private_key).unwrap(),
+            private_key: pkey,
             public_key_fp: None,
             refresh_threshold,
             expiry_threshold,
@@ -59,9 +61,10 @@
         };
         snowflake_auth.public_key_fp = Some(SnowflakeAuth::gen_public_key_fp(
             &snowflake_auth.private_key,
-        ));
+        )?);
-        snowflake_auth.refresh_jwt();
-        snowflake_auth
+        snowflake_auth.refresh_jwt()?;
+
+        Ok(snowflake_auth)
     }

     // Normalize the account identifier to a form that is embedded into the JWT.
@@ -85,26 +88,21 @@ impl SnowflakeAuth {
     }

     #[tracing::instrument(name = "peer_sflake::gen_public_key_fp", skip_all)]
-    fn gen_public_key_fp(private_key: &RsaPrivateKey) -> String {
-        let public_key =
-            EncodePublicKey::to_public_key_der(&RsaPublicKey::from(private_key)).unwrap();
-        format!(
+    fn gen_public_key_fp(private_key: &RsaPrivateKey) -> anyhow::Result<String> {
+        let public_key = EncodePublicKey::to_public_key_der(&RsaPublicKey::from(private_key))?;
+        let res = format!(
             "SHA256:{}",
             base64_encode(Sha256::new_with_prefix(public_key.as_der()).finalize())
-        )
+        );
+        Ok(res)
     }

     #[tracing::instrument(name = "peer_sflake::auth_refresh_jwt", skip_all)]
-    fn refresh_jwt(&mut self) {
+    fn refresh_jwt(&mut self) -> anyhow::Result<()> {
         let private_key_jwt: EncodingKey = EncodingKey::from_rsa_der(
-            EncodeRsaPrivateKey::to_pkcs1_der(&self.private_key)
-                .unwrap()
-                .as_der(),
+            EncodeRsaPrivateKey::to_pkcs1_der(&self.private_key)?.as_der(),
         );
-        self.last_refreshed = SystemTime::now()
-            .duration_since(UNIX_EPOCH)
-            .unwrap()
-            .as_secs();
+        self.last_refreshed = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
         info!(
             "Refreshing Snowflake JWT for account: {} and user: {} at time {}",
             self.account_id, self.username, self.last_refreshed
         );
             "{}.{}.{}",
             self.normalized_account_id,
             self.username.to_uppercase(),
-            self.public_key_fp.as_deref().unwrap()
+            self.public_key_fp
+                .as_deref()
+                .context("No public key fingerprint")?
         ),
         sub: format!(
             "{}.{}",
             self.normalized_account_id,
             self.username.to_uppercase()
         ),
@@ -125,21 +125,24 @@
             exp: self.last_refreshed + self.expiry_threshold,
         };
         let header: Header = Header::new(Algorithm::RS256);
-        self.current_jwt = Some(
-            SecretString::from_str(&jwt_encode(&header, &jwt_claims, &private_key_jwt).unwrap())
-                .unwrap(),
-        );
+
+        let encoded_jwt = jwt_encode(&header, &jwt_claims, &private_key_jwt)?;
+        let secret = SecretString::from_str(&encoded_jwt)?;
+
+        self.current_jwt = Some(secret);
+
+        Ok(())
     }

-    pub fn get_jwt(&mut self) -> &Secret<String> {
-        if SystemTime::now()
-            .duration_since(UNIX_EPOCH)
-            .unwrap()
-            .as_secs()
+    pub fn get_jwt(&mut self) -> anyhow::Result<&Secret<String>> {
+        if SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs()
             >= (self.last_refreshed + self.refresh_threshold)
         {
-            self.refresh_jwt();
+            self.refresh_jwt()?;
         }
-        self.current_jwt.as_ref().unwrap()
+
+        self.current_jwt
+            .as_ref()
+            .ok_or_else(|| anyhow::anyhow!("JWT not initialized. Please call refresh_jwt() first."))
     }
 }
diff --git a/nexus/peer-snowflake/src/lib.rs b/nexus/peer-snowflake/src/lib.rs
index 0b7d14ca8f..ad4b262122 100644
--- a/nexus/peer-snowflake/src/lib.rs
+++ b/nexus/peer-snowflake/src/lib.rs
@@ -4,6 +4,8 @@ use cursor::SnowflakeCursorManager;
 use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, SchemaRef};
 use pgerror::PgError;
 use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
+use sqlparser::dialect::GenericDialect;
+use sqlparser::parser;
 use std::cmp::min;
 use std::{collections::HashMap, time::Duration};
 use stream::SnowflakeDataType;
@@ -84,6 +86,7 @@ pub struct ResultSet {
 struct PartitionResult {
     data: Vec<Vec<Option<String>>>,
 }
+
 pub struct SnowflakeQueryExecutor {
     config: SnowflakeConfig,
     partition_number: usize,
@@ -133,7 +136,7 @@
                 config.clone().private_key,
                 DEFAULT_REFRESH_THRESHOLD,
                 DEFAULT_EXPIRY_THRESHOLD,
-            ),
+            )?,
             query_timeout: config.query_timeout,
             reqwest_client,
             cursor_manager,
@@ -166,12 +169,12 @@ impl SnowflakeQueryExecutor {
         );

         let mut auth = self.auth.clone();
-        let jwt = auth.get_jwt();
+        let jwt = auth.get_jwt()?;
         let secret = jwt.expose_secret().clone();
         // TODO: for things other than SELECTs, the robust way to handle retries is by
         // generating a UUID from our end to mark the query as unique and then sending it with the request.
         // If we need to retry, send same UUID with retry=true parameter set and Snowflake should prevent duplicate execution.
-        let query_status = self
+        let query_status_res = self
             .reqwest_client
             .post(self.endpoint_url.to_owned())
             .bearer_auth(secret)
@@ -186,9 +189,14 @@ impl SnowflakeQueryExecutor {
             })
             .send()
             .await
-            .map_err(|_| anyhow::anyhow!("failed in making request for QueryStatus"))?
-            .json::<QueryStatus>()
-            .await?;
+            .map_err(|e| {
+                anyhow::anyhow!("failed in making request for QueryStatus. error: {:?}", e)
+            })?;
+
+        let query_json = query_status_res.json::<serde_json::Value>().await?;
+        let query_status = serde_json::from_value(query_json.clone()).map_err(|e| {
+            anyhow::anyhow!("failed in parsing json {:?}, error: {:?}", query_json, e)
+        })?;
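For context on what refresh_jwt is assembling: Snowflake key-pair authentication expects an RS256 JWT whose issuer embeds the SHA-256 fingerprint of the public key, with iss = account.user.fingerprint and sub = account.user, exactly as in the Rust code above (which also uppercases the username). A rough Go sketch of the same claims using github.com/golang-jwt/jwt/v5; the helper name and the one-hour expiry are assumptions, not PeerDB code:

    package main

    import (
        "crypto/rsa"
        "fmt"
        "time"

        "github.com/golang-jwt/jwt/v5"
    )

    // buildSnowflakeJWT mirrors the claim layout assembled in refresh_jwt.
    // keyFP is the "SHA256:<base64>" fingerprint of the RSA public key.
    func buildSnowflakeJWT(key *rsa.PrivateKey, account, user, keyFP string) (string, error) {
        now := time.Now()
        claims := jwt.MapClaims{
            "iss": fmt.Sprintf("%s.%s.%s", account, user, keyFP),
            "sub": fmt.Sprintf("%s.%s", account, user),
            "iat": now.Unix(),
            "exp": now.Add(time.Hour).Unix(), // the Rust code uses last_refreshed + expiry_threshold
        }
        return jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
    }
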
         // TODO: remove this blind retry logic for anything other than a SELECT.
         let res = self.query_poll(query_status).await?;
@@ -197,6 +205,7 @@ impl SnowflakeQueryExecutor {
             None => self.process_query(query_str).await?,
         })
     }
+
     pub async fn query(&self, query: &Box<Query>) -> PgWireResult<Records> {
         let mut query = query.clone();
@@ -219,7 +228,7 @@
         query_status: &QueryStatus,
     ) -> anyhow::Result<Option<ResultSet>> {
         let mut auth = self.auth.clone();
-        let jwt = auth.get_jwt();
+        let jwt = auth.get_jwt()?;
         let secret = jwt.expose_secret().clone();
         let response = self
             .reqwest_client
@@ -430,8 +439,15 @@ impl QueryExecutor for SnowflakeQueryExecutor {
             _ => PgWireResult::Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                 "ERROR".to_owned(),
                 "fdw_error".to_owned(),
-                "only SELECT statements are supported in bigquery".to_owned(),
+                "only SELECT statements are supported in snowflake".to_owned(),
             )))),
         }
     }
+
+    async fn is_connection_valid(&self) -> anyhow::Result<bool> {
+        let sql = "SELECT 1;";
+        let test_stmt = parser::Parser::parse_sql(&GenericDialect {}, sql)?;
+        let _ = self.execute(&test_stmt[0]).await?;
+        Ok(true)
+    }
 }
diff --git a/nexus/peer-snowflake/src/stream.rs b/nexus/peer-snowflake/src/stream.rs
index 57c5e21e54..0417163957 100644
--- a/nexus/peer-snowflake/src/stream.rs
+++ b/nexus/peer-snowflake/src/stream.rs
@@ -207,7 +207,7 @@ impl SnowflakeRecordStream {
         self.partition_number = self.partition_number + 1;
         self.partition_index = 0;
         let partition_number = self.partition_number;
-        let secret = self.auth.get_jwt().expose_secret().clone();
+        let secret = self.auth.get_jwt()?.expose_secret().clone();
         let statement_handle = self.result_set.statementHandle.clone();
         let url = self.endpoint_url.clone();
         println!("Secret: {:#?}", secret);
diff --git a/nexus/server/src/main.rs b/nexus/server/src/main.rs
index eea316c4cb..1b87c12aa5 100644
--- a/nexus/server/src/main.rs
+++ b/nexus/server/src/main.rs
@@ -158,6 +158,19 @@ impl NexusBackend {
                 peer,
                 if_not_exists: _,
             } => {
+                let peer_executor = self.get_peer_executor(&peer).await.map_err(|err| {
+                    PgWireError::ApiError(Box::new(PgError::Internal {
+                        err_msg: format!("unable to get peer executor: {:?}", err),
+                    }))
+                })?;
+                peer_executor.is_connection_valid().await.map_err(|e| {
+                    self.executors.remove(&peer.name); // Otherwise it will keep returning the earlier configured executor
+                    PgWireError::UserError(Box::new(ErrorInfo::new(
+                        "ERROR".to_owned(),
+                        "internal_error".to_owned(),
+                        format!("[peer]: invalid configuration: {}", e.to_string()),
+                    )))
+                })?;
                 let catalog = self.catalog.lock().await;
                 catalog.create_peer(peer.as_ref()).await.map_err(|e| {
                     PgWireError::UserError(Box::new(ErrorInfo::new(
@@ -219,7 +232,7 @@
                     }))
                 })?;
                 // make a request to the flow service to start the job.
-                let workflow_id = self
+                let _workflow_id = self
                     .flow_handler
                     .start_qrep_flow_job(&qrep_flow_job)
                     .await
@@ -297,7 +310,11 @@
                     QueryAssocation::Peer(peer) => {
                         tracing::info!("handling peer[{}] query: {}", peer.name, stmt);
                         peer_holder = Some(peer.clone());
-                        self.get_peer_executor(&peer).await
+                        self.get_peer_executor(&peer).await.map_err(|err| {
+                            PgWireError::ApiError(Box::new(PgError::Internal {
+                                err_msg: format!("unable to get peer executor: {:?}", err),
+                            }))
+                        })?
                     }
                     QueryAssocation::Catalog => {
                         tracing::info!("handling catalog query: {}", stmt);
@@ -327,7 +344,11 @@
                         let catalog = self.catalog.lock().await;
                         catalog.get_executor()
                     }
-                    Some(peer) => self.get_peer_executor(peer).await,
+                    Some(peer) => self.get_peer_executor(peer).await.map_err(|err| {
+                        PgWireError::ApiError(Box::new(PgError::Internal {
+                            err_msg: format!("unable to get peer executor: {:?}", err),
+                        }))
+                    })?,
                 }
             };
@@ -338,32 +359,25 @@
         }
     }

-    async fn get_peer_executor(&self, peer: &Peer) -> Arc<Box<dyn QueryExecutor>> {
+    async fn get_peer_executor(&self, peer: &Peer) -> anyhow::Result<Arc<Box<dyn QueryExecutor>>> {
         if let Some(executor) = self.executors.get(&peer.name) {
-            return Arc::clone(executor.value());
+            return Ok(Arc::clone(executor.value()));
         }

         let executor = match &peer.config {
             Some(Config::BigqueryConfig(ref c)) => {
                 let peer_name = peer.name.clone();
                 let executor =
-                    BigQueryQueryExecutor::new(peer_name, c, self.peer_connections.clone())
-                        .await
-                        .unwrap();
+                    BigQueryQueryExecutor::new(peer_name, c, self.peer_connections.clone()).await?;
                 Arc::new(Box::new(executor) as Box<dyn QueryExecutor>)
             }
             Some(Config::PostgresConfig(ref c)) => {
                 let peername = Some(peer.name.clone());
-                let executor = peer_postgres::PostgresQueryExecutor::new(peername, c)
-                    .await
-                    .unwrap();
+                let executor = peer_postgres::PostgresQueryExecutor::new(peername, c).await?;
                 Arc::new(Box::new(executor) as Box<dyn QueryExecutor>)
             }
             Some(Config::SnowflakeConfig(ref c)) => {
-                let peername = Some(peer.name.clone());
-                let executor = peer_snowflake::SnowflakeQueryExecutor::new(c)
-                    .await
-                    .unwrap();
+                let executor = peer_snowflake::SnowflakeQueryExecutor::new(c).await?;
                 Arc::new(Box::new(executor) as Box<dyn QueryExecutor>)
             }
             _ => {
@@ -373,7 +387,7 @@
         self.executors
             .insert(peer.name.clone(), Arc::clone(&executor));

-        executor
+        Ok(executor)
     }
 }

@@ -500,10 +514,45 @@ impl ExtendedQueryHandler for NexusBackend {
                         // if the peer is of type bigquery, let us route the query to bq.
                         match &peer.config {
                             Some(Config::BigqueryConfig(_)) => {
-                                let executor = self.get_peer_executor(peer).await;
+                                let executor =
+                                    self.get_peer_executor(peer).await.map_err(|err| {
+                                        PgWireError::ApiError(Box::new(PgError::Internal {
+                                            err_msg: format!(
+                                                "unable to get peer executor: {:?}",
+                                                err
+                                            ),
+                                        }))
+                                    })?;
+                                executor.describe(stmt).await?
+                            }
+                            Some(Config::PostgresConfig(_)) => {
+                                let executor =
+                                    self.get_peer_executor(peer).await.map_err(|err| {
+                                        PgWireError::ApiError(Box::new(PgError::Internal {
+                                            err_msg: format!(
+                                                "unable to get peer executor: {:?}",
+                                                err
+                                            ),
+                                        }))
+                                    })?;
                                 executor.describe(stmt).await?
                             }
-                            _ => {
+                            Some(Config::SnowflakeConfig(_)) => {
+                                let executor =
+                                    self.get_peer_executor(peer).await.map_err(|err| {
+                                        PgWireError::ApiError(Box::new(PgError::Internal {
+                                            err_msg: format!(
+                                                "unable to get peer executor: {:?}",
+                                                err
+                                            ),
+                                        }))
+                                    })?;
+                                executor.describe(stmt).await?
+ } + Some(Config::MongoConfig(_)) => { + panic!("peer type not supported: {:?}", peer) + } + None => { panic!("peer type not supported: {:?}", peer) } } diff --git a/nexus/value/src/array.rs b/nexus/value/src/array.rs index f48694efb0..9b50c8c679 100644 --- a/nexus/value/src/array.rs +++ b/nexus/value/src/array.rs @@ -149,8 +149,8 @@ impl<'a> ToSql for ArrayValue { } ArrayValue::VarChar(arr) => arr.to_sql(ty, out)?, ArrayValue::Text(arr) => arr.to_sql(ty, out)?, - ArrayValue::Binary(arr) => todo!("support encoding array of binary"), - ArrayValue::VarBinary(arr) => todo!("support encoding array of varbinary"), + ArrayValue::Binary(_arr) => todo!("support encoding array of binary"), + ArrayValue::VarBinary(_arr) => todo!("support encoding array of varbinary"), ArrayValue::Date(arr) => arr.to_sql(ty, out)?, ArrayValue::Time(arr) => arr.to_sql(ty, out)?, ArrayValue::TimeWithTimeZone(arr) => arr.to_sql(ty, out)?, @@ -227,8 +227,8 @@ impl ToSqlText for ArrayValue { ArrayValue::Char(arr) => array_to_sql_text!(arr, ty, out), ArrayValue::VarChar(arr) => array_to_sql_text!(arr, ty, out), ArrayValue::Text(arr) => array_to_sql_text!(arr, ty, out), - ArrayValue::Binary(arr) => todo!("implement encoding array of binary"), - ArrayValue::VarBinary(arr) => todo!("implement encoding array of varbinary"), + ArrayValue::Binary(_arr) => todo!("implement encoding array of binary"), + ArrayValue::VarBinary(_arr) => todo!("implement encoding array of varbinary"), ArrayValue::Date(arr) => array_to_sql_text!(arr, ty, out), ArrayValue::Time(arr) => array_to_sql_text!(arr, ty, out), ArrayValue::TimeWithTimeZone(arr) => array_to_sql_text!(arr, ty, out),