diff --git a/.github/workflows/flow.yml b/.github/workflows/flow.yml
index 4bc3ef44d8..2adf8c619d 100644
--- a/.github/workflows/flow.yml
+++ b/.github/workflows/flow.yml
@@ -57,7 +57,7 @@ jobs:
- name: run tests
run: |
- gotestsum --format testname -- -p 1 ./...
+ gotestsum --format testname -- -p 1 ./... -timeout 1200s
working-directory: ./flow
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
diff --git a/README.md b/README.md
index 9a00b594c7..135a4f8507 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,12 @@
## PeerDB
diff --git a/flow/connectors/bigquery/bigquery.go b/flow/connectors/bigquery/bigquery.go
index 798d4a6e3a..d4c9f9df1c 100644
--- a/flow/connectors/bigquery/bigquery.go
+++ b/flow/connectors/bigquery/bigquery.go
@@ -15,6 +15,7 @@ import (
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/api/iterator"
@@ -442,8 +443,8 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
// 1. _peerdb_uid - uuid
// 2. _peerdb_timestamp - current timestamp
// 2. _peerdb_timestamp_nanos - current timestamp in nano seconds
- // 3. _peerdb_data - json of `r.Items`
- json, err := json.Marshal(r.Items)
+ // 3. _peerdb_data - JSON of `r.Items`
+ itemsJSON, err := r.Items.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to create items to json: %v", err)
}
@@ -454,7 +455,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
timestamp: time.Now(),
timestampNanos: time.Now().UnixNano(),
destinationTableName: r.DestinationTableName,
- data: string(json),
+ data: itemsJSON,
recordType: 0,
matchData: "",
batchID: syncBatchID,
@@ -469,12 +470,12 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
// 4. _peerdb_record_type - 1
// 5. _peerdb_match_data - json of `r.OldItems`
- newItemsJSON, err := json.Marshal(r.NewItems)
+ newItemsJSON, err := r.NewItems.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to create new items to json: %v", err)
}
- oldItemsJSON, err := json.Marshal(r.OldItems)
+ oldItemsJSON, err := r.OldItems.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to create old items to json: %v", err)
}
@@ -485,9 +486,9 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
timestamp: time.Now(),
timestampNanos: time.Now().UnixNano(),
destinationTableName: r.DestinationTableName,
- data: string(newItemsJSON),
+ data: newItemsJSON,
recordType: 1,
- matchData: string(oldItemsJSON),
+ matchData: oldItemsJSON,
batchID: syncBatchID,
stagingBatchID: stagingBatchID,
unchangedToastColumns: utils.KeysToString(r.UnchangedToastColumns),
@@ -500,7 +501,7 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
// 4. _peerdb_match_data - json of `r.Items`
// json.Marshal converts bytes in Hex automatically to BASE64 string.
- itemsJSON, err := json.Marshal(r.Items)
+ itemsJSON, err := r.Items.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to create items to json: %v", err)
}
@@ -511,9 +512,9 @@ func (c *BigQueryConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.S
timestamp: time.Now(),
timestampNanos: time.Now().UnixNano(),
destinationTableName: r.DestinationTableName,
- data: string(itemsJSON),
+ data: itemsJSON,
recordType: 2,
- matchData: string(itemsJSON),
+ matchData: itemsJSON,
batchID: syncBatchID,
stagingBatchID: stagingBatchID,
unchangedToastColumns: utils.KeysToString(r.UnchangedToastColumns),
@@ -823,7 +824,7 @@ func (c *BigQueryConnector) SetupNormalizedTable(
for colName, genericColType := range sourceSchema.Columns {
columns[idx] = &bigquery.FieldSchema{
Name: colName,
- Type: getBigQueryColumnTypeForGenericColType(genericColType),
+ Type: qValueKindToBigQueryType(genericColType),
}
idx++
}
@@ -923,45 +924,6 @@ func (c *BigQueryConnector) truncateTable(tableIdentifier string) error {
return nil
}
-func getBigQueryColumnTypeForGenericColType(colType string) bigquery.FieldType {
- switch colType {
- // boolean
- case model.ColumnTypeBoolean:
- return bigquery.BooleanFieldType
- // integer types
- case model.ColumnTypeInt16, model.ColumnTypeInt32, model.ColumnTypeInt64:
- return bigquery.IntegerFieldType
- // decimal types
- case model.ColumnTypeFloat16, model.ColumnTypeFloat32, model.ColumnTypeFloat64:
- return bigquery.FloatFieldType
- case model.ColumnTypeNumeric:
- return bigquery.NumericFieldType
- // string related
- case model.ColumnTypeString:
- return bigquery.StringFieldType
- // json also is stored as string for now
- case model.ColumnTypeJSON:
- return bigquery.StringFieldType
- // time related
- case model.ColumnTypeTimestamp, model.ColumnTypeTimeStampWithTimeZone:
- return bigquery.TimestampFieldType
- case model.ColumnTypeTime:
- return bigquery.TimeFieldType
- case model.ColumnTypeTimeWithTimeZone:
- return bigquery.StringFieldType
- case model.ColumnTypeDate:
- return bigquery.TimestampFieldType
- case model.ColumnTypeInterval:
- return bigquery.IntervalFieldType
- // bytes
- case model.ColumnHexBytes, model.ColumnHexBit:
- return bigquery.BytesFieldType
- // rest will be strings
- default:
- return bigquery.StringFieldType
- }
-}
-
type MergeStmtGenerator struct {
// dataset of all the tables
Dataset string
@@ -1003,40 +965,42 @@ func (m *MergeStmtGenerator) generateFlattenedCTE() string {
// statement.
flattenedProjs := make([]string, 0)
for colName, colType := range m.NormalizedTableSchema.Columns {
- bqType := getBigQueryColumnTypeForGenericColType(colType)
+ bqType := qValueKindToBigQueryType(colType)
// CAST doesn't work for FLOAT, so rewrite it to FLOAT64.
if bqType == bigquery.FloatFieldType {
bqType = "FLOAT64"
}
var castStmt string
- switch colType {
- case model.ColumnTypeJSON:
+ switch qvalue.QValueKind(colType) {
+ case qvalue.QValueKindJSON:
//if the type is JSON, then just extract JSON
castStmt = fmt.Sprintf("CAST(JSON_EXTRACT(_peerdb_data, '$.%s') AS %s) AS %s",
colName, bqType, colName)
// expecting data in BASE64 format
- case model.ColumnHexBytes:
+ case qvalue.QValueKindBytes:
castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s')) AS %s",
colName, colName)
// MAKE_INTERVAL(years INT64, months INT64, days INT64, hours INT64, minutes INT64, seconds INT64)
// Expecting interval to be in the format of {"Microseconds":2000000,"Days":0,"Months":0,"Valid":true}
// json.Marshal in SyncRecords for Postgres already does this - once new data-stores are added,
// this needs to be handled again
- case model.ColumnTypeInterval:
- castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+
- "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+
- "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s",
- colName, colName, colName, colName)
- case model.ColumnHexBit:
+ // TODO add interval types again
+ // case model.ColumnTypeInterval:
+ // castStmt = fmt.Sprintf("MAKE_INTERVAL(0,CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Months') AS INT64),"+
+ // "CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Days') AS INT64),0,0,"+
+ // "CAST(CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Microseconds') AS INT64)/1000000 AS INT64)) AS %s",
+ // colName, colName, colName, colName)
+ case qvalue.QValueKindBit:
// sample raw data for BIT {"a":{"Bytes":"oA==","Len":3,"Valid":true},"id":1}
// need to check correctness TODO
castStmt = fmt.Sprintf("FROM_BASE64(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s.Bytes')) AS %s",
colName, colName)
- case model.ColumnTypeTime:
- castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+
- " AS int64))) AS %s",
- colName, colName)
+ // TODO add proper granularity for time types, then restore this
+ // case model.ColumnTypeTime:
+ // castStmt = fmt.Sprintf("time(timestamp_micros(CAST(JSON_EXTRACT(_peerdb_data, '$.%s.Microseconds')"+
+ // " AS int64))) AS %s",
+ // colName, colName)
default:
castStmt = fmt.Sprintf("CAST(JSON_EXTRACT_SCALAR(_peerdb_data, '$.%s') AS %s) AS %s",
colName, bqType, colName)
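Note: `RecordItems.ToJSON()` lives in the `model` package and is not shown in this diff. Purely as a sketch of the shape it plausibly takes, assuming `RecordItems` maps column names to `qvalue.QValue` (all names below are stand-ins, not the actual implementation):

```go
package model

import (
	"encoding/json"
	"fmt"
)

// QValue pairs a kind tag with a decoded value; a stand-in for the
// real qvalue.QValue introduced by this PR.
type QValue struct {
	Kind  string
	Value interface{}
}

// RecordItems maps column names to typed values.
type RecordItems map[string]QValue

// ToJSON serializes only the underlying values, so the output stays
// compatible with what json.Marshal produced for the old
// map[string]interface{} items.
func (r RecordItems) ToJSON() (string, error) {
	plain := make(map[string]interface{}, len(r))
	for col, qv := range r {
		plain[col] = qv.Value
	}
	b, err := json.Marshal(plain)
	if err != nil {
		return "", fmt.Errorf("failed to marshal items: %w", err)
	}
	return string(b), nil
}
```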
diff --git a/flow/connectors/bigquery/qrecord_value_saver.go b/flow/connectors/bigquery/qrecord_value_saver.go
index 7bb82d8f94..672f7ee0b3 100644
--- a/flow/connectors/bigquery/qrecord_value_saver.go
+++ b/flow/connectors/bigquery/qrecord_value_saver.go
@@ -63,10 +63,17 @@ func (q QRecordValueSaver) Save() (map[string]bigquery.Value, string, error) {
}
bqValues[k] = val
+ case qvalue.QValueKindInt16:
+ val, ok := v.Value.(int16)
+ if !ok {
+ return nil, "", fmt.Errorf("failed to convert %v to int16", v.Value)
+ }
+ bqValues[k] = val
+
case qvalue.QValueKindInt32:
val, ok := v.Value.(int32)
if !ok {
- return nil, "", fmt.Errorf("failed to convert %v to int64", v.Value)
+ return nil, "", fmt.Errorf("failed to convert %v to int32", v.Value)
}
bqValues[k] = val
@@ -91,12 +98,12 @@ func (q QRecordValueSaver) Save() (map[string]bigquery.Value, string, error) {
}
bqValues[k] = val
- case qvalue.QValueKindETime:
- val, ok := v.Value.(*qvalue.ExtendedTime)
- if !ok {
- return nil, "", fmt.Errorf("failed to convert %v to ExtendedTime", v.Value)
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindDate, qvalue.QValueKindTime:
+ var err error
+ bqValues[k], err = v.GoTimeConvert()
+ if err != nil {
+ return nil, "", fmt.Errorf("failed to convert parse %v into time.Time", v)
}
- bqValues[k] = val.Time
case qvalue.QValueKindNumeric:
val, ok := v.Value.(*big.Rat)
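`GoTimeConvert` is likewise added to the `qvalue` package outside this diff. Assuming the timestamp/date/time kinds now store a plain `time.Time` in `Value` (which the Postgres conversion code later in this diff suggests), it plausibly reduces to a checked type assertion, sketched here with a stand-in `QValue` type:

```go
package qvalue

import (
	"fmt"
	"time"
)

// QValue mirrors the shape used in this PR (illustrative stand-in).
type QValue struct {
	Kind  string
	Value interface{}
}

// GoTimeConvert returns the underlying time.Time for time-like kinds.
// Sketch only, under the assumption that timestamp/date/time QValues
// carry a plain time.Time in Value.
func (v QValue) GoTimeConvert() (time.Time, error) {
	t, ok := v.Value.(time.Time)
	if !ok {
		return time.Time{}, fmt.Errorf("%v is not a time.Time", v.Value)
	}
	return t, nil
}
```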
diff --git a/flow/connectors/bigquery/qrep_avro_sync.go b/flow/connectors/bigquery/qrep_avro_sync.go
index 09703625fc..58d19fbc8b 100644
--- a/flow/connectors/bigquery/qrep_avro_sync.go
+++ b/flow/connectors/bigquery/qrep_avro_sync.go
@@ -216,13 +216,13 @@ func GetAvroType(bqField *bigquery.FieldSchema) (interface{}, error) {
}, nil
case bigquery.DateFieldType:
return map[string]string{
- "type": "int",
- "logicalType": "date",
+ "type": "long",
+ "logicalType": "timestamp-micros",
}, nil
case bigquery.TimeFieldType:
return map[string]string{
"type": "long",
- "logicalType": "time-micros",
+ "logicalType": "timestamp-micros",
}, nil
case bigquery.DateTimeFieldType:
return map[string]interface{}{
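Per the Avro spec, `timestamp-micros` annotates a `long` counting microseconds since the Unix epoch, so DATE and TIME values are now written as epoch micros instead of `date` (days since epoch) and `time-micros` (micros since midnight). A minimal illustration of producing that representation (the helper name is illustrative):

```go
package main

import (
	"fmt"
	"time"
)

// toAvroTimestampMicros converts a time.Time into the long value
// expected by Avro's timestamp-micros logical type: microseconds
// since the Unix epoch.
func toAvroTimestampMicros(t time.Time) int64 {
	return t.UnixMicro()
}

func main() {
	d := time.Date(2023, 7, 1, 0, 0, 0, 0, time.UTC)
	fmt.Println(toAvroTimestampMicros(d)) // 1688169600000000
}
```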
diff --git a/flow/connectors/bigquery/qrep_sync_method.go b/flow/connectors/bigquery/qrep_sync_method.go
index 8dd998c973..92a9d4142d 100644
--- a/flow/connectors/bigquery/qrep_sync_method.go
+++ b/flow/connectors/bigquery/qrep_sync_method.go
@@ -94,7 +94,11 @@ func (s *QRepStagingTableSync) SyncQRepRecords(
// col names for the destination table joined by comma
colNames := []string{}
for _, col := range dstTableMetadata.Schema {
- colNames = append(colNames, col.Name)
+ if strings.ToLower(col.Name) == "from" {
+ colNames = append(colNames, "`from`")
+ } else {
+ colNames = append(colNames, col.Name)
+ }
}
colNamesStr := strings.Join(colNames, ", ")
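The check above only special-cases `from`; any other reserved keyword used as a column name would still produce invalid SQL. A hypothetical generalization is to backtick-quote every identifier, which BigQuery accepts for reserved and ordinary names alike (`quoteColumn` is illustrative):

```go
package main

import (
	"fmt"
	"strings"
)

// quoteColumn backtick-quotes a BigQuery identifier, escaping any
// embedded backticks with a backslash. Quoted identifiers are valid
// for both reserved and ordinary column names.
func quoteColumn(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "\\`") + "`"
}

func main() {
	fmt.Println(quoteColumn("from")) // `from`
	fmt.Println(quoteColumn("id"))   // `id`
}
```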
diff --git a/flow/connectors/bigquery/qvalue_convert.go b/flow/connectors/bigquery/qvalue_convert.go
new file mode 100644
index 0000000000..fabdfb97f2
--- /dev/null
+++ b/flow/connectors/bigquery/qvalue_convert.go
@@ -0,0 +1,76 @@
+package connbigquery
+
+import (
+ "fmt"
+
+ "cloud.google.com/go/bigquery"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
+)
+
+func qValueKindToBigQueryType(colType string) bigquery.FieldType {
+ switch qvalue.QValueKind(colType) {
+ // boolean
+ case qvalue.QValueKindBoolean:
+ return bigquery.BooleanFieldType
+ // integer types
+ case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64:
+ return bigquery.IntegerFieldType
+ // decimal types
+ case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64:
+ return bigquery.FloatFieldType
+ case qvalue.QValueKindNumeric:
+ return bigquery.NumericFieldType
+ // string related
+ case qvalue.QValueKindString:
+ return bigquery.StringFieldType
+ // json also is stored as string for now
+ case qvalue.QValueKindJSON:
+ return bigquery.StringFieldType
+ // time related
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ:
+ return bigquery.TimestampFieldType
+ // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - DATE support is incomplete
+ case qvalue.QValueKindDate:
+ return bigquery.DateFieldType
+ // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - TIME/TIMETZ support is incomplete
+ case qvalue.QValueKindTime, qvalue.QValueKindTimeTZ:
+ return bigquery.TimeFieldType
+ // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle INTERVAL types again
+ // bytes
+ case qvalue.QValueKindBit, qvalue.QValueKindBytes:
+ return bigquery.BytesFieldType
+ // rest will be strings
+ default:
+ return bigquery.StringFieldType
+ }
+}
+
+// BigQueryTypeToQValueKind converts a bigquery FieldType to a QValueKind.
+func BigQueryTypeToQValueKind(fieldType bigquery.FieldType) (qvalue.QValueKind, error) {
+ switch fieldType {
+ case bigquery.StringFieldType:
+ return qvalue.QValueKindString, nil
+ case bigquery.BytesFieldType:
+ return qvalue.QValueKindBytes, nil
+ case bigquery.IntegerFieldType:
+ return qvalue.QValueKindInt64, nil
+ case bigquery.FloatFieldType:
+ return qvalue.QValueKindFloat64, nil
+ case bigquery.BooleanFieldType:
+ return qvalue.QValueKindBoolean, nil
+ case bigquery.TimestampFieldType:
+ return qvalue.QValueKindTimestamp, nil
+ case bigquery.DateFieldType:
+ return qvalue.QValueKindDate, nil
+ case bigquery.TimeFieldType:
+ return qvalue.QValueKindTime, nil
+ case bigquery.RecordFieldType:
+ return qvalue.QValueKindStruct, nil
+ case bigquery.NumericFieldType:
+ return qvalue.QValueKindNumeric, nil
+ case bigquery.GeographyFieldType:
+ return qvalue.QValueKindString, nil
+ default:
+ return "", fmt.Errorf("unsupported bigquery field type: %v", fieldType)
+ }
+}
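One property worth keeping in mind: the two mappings are not inverses. `Int16`/`Int32`/`Int64` all map to `IntegerFieldType` and come back as `Int64`, and `Float32` widens to `Float64`. A package-internal test sketch (hypothetical, assuming it lives beside this file) that pins down the behavior:

```go
package connbigquery

import (
	"testing"

	"github.com/PeerDB-io/peer-flow/model/qvalue"
)

// TestKindRoundTripWidens documents that Kind -> FieldType -> Kind is
// lossy: small integer kinds come back as Int64.
func TestKindRoundTripWidens(t *testing.T) {
	ft := qValueKindToBigQueryType(string(qvalue.QValueKindInt16))
	kind, err := BigQueryTypeToQValueKind(ft)
	if err != nil {
		t.Fatal(err)
	}
	if kind != qvalue.QValueKindInt64 {
		t.Fatalf("expected Int64 after round trip, got %s", kind)
	}
}
```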
diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go
index 411e9ee41b..a71e229887 100644
--- a/flow/connectors/postgres/cdc.go
+++ b/flow/connectors/postgres/cdc.go
@@ -7,6 +7,7 @@ import (
"time"
"github.com/PeerDB-io/peer-flow/model"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
@@ -386,34 +387,34 @@ It takes a tuple and a relation message as input and returns
func (p *PostgresCDCSource) convertTupleToMap(
tuple *pglogrepl.TupleData,
rel *pglogrepl.RelationMessage,
-) (map[string]interface{}, map[string]bool, error) {
+) (model.RecordItems, map[string]bool, error) {
// if the tuple is nil, return an empty map
if tuple == nil {
- return make(map[string]interface{}), make(map[string]bool), nil
+ return make(model.RecordItems), make(map[string]bool), nil
}
// create empty map of string to interface{}
- items := make(map[string]interface{})
+ items := make(model.RecordItems)
unchangedToastColumns := make(map[string]bool)
for idx, col := range tuple.Columns {
colName := rel.Columns[idx].Name
switch col.DataType {
case 'n': // null
- items[colName] = nil
+ items[colName] = qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil}
case 't': // text
/* bytea also appears here as a hex */
- data, err := p.decodeTextColumnData(col.Data, rel.Columns[idx].DataType)
+ data, err := p.decodeColumnData(col.Data, rel.Columns[idx].DataType, pgtype.TextFormatCode)
if err != nil {
return nil, nil, fmt.Errorf("error decoding text column data: %w", err)
}
- items[colName] = data
+ items[colName] = *data
case 'b': // binary
- data, err := p.decodeBinaryColumnData(col.Data, rel.Columns[idx].DataType)
+ data, err := p.decodeColumnData(col.Data, rel.Columns[idx].DataType, pgtype.BinaryFormatCode)
if err != nil {
return nil, nil, fmt.Errorf("error decoding binary column data: %w", err)
}
- items[colName] = data
+ items[colName] = *data
case 'u': // unchanged toast
unchangedToastColumns[colName] = true
default:
@@ -423,21 +424,25 @@ func (p *PostgresCDCSource) convertTupleToMap(
return items, unchangedToastColumns, nil
}
-func (p *PostgresCDCSource) decodeTextColumnData(data []byte, dataType uint32) (interface{}, error) {
+func (p *PostgresCDCSource) decodeColumnData(data []byte, dataType uint32, formatCode int16) (*qvalue.QValue, error) {
+ var parsedData any
+ var err error
if dt, ok := p.typeMap.TypeForOID(dataType); ok {
if dt.Name == "uuid" {
// below is required to decode uuid to string
- return dt.Codec.DecodeDatabaseSQLValue(p.typeMap, dataType, pgtype.TextFormatCode, data)
- }
- return dt.Codec.DecodeValue(p.typeMap, dataType, pgtype.TextFormatCode, data)
- }
- return string(data), nil
-}
+ parsedData, err = dt.Codec.DecodeDatabaseSQLValue(p.typeMap, dataType, pgtype.TextFormatCode, data)
-// decodeBinaryColumnData decodes the binary data for a column
-func (p *PostgresCDCSource) decodeBinaryColumnData(data []byte, dataType uint32) (interface{}, error) {
- if dt, ok := p.typeMap.TypeForOID(dataType); ok {
- return dt.Codec.DecodeValue(p.typeMap, dataType, pgtype.BinaryFormatCode, data)
+ } else {
+ parsedData, err = dt.Codec.DecodeValue(p.typeMap, dataType, formatCode, data)
+ }
+ if err != nil {
+ return nil, err
+ }
+ retVal, err := parseFieldFromPostgresOID(dataType, parsedData)
+ if err != nil {
+ return nil, err
+ }
+ return retVal, nil
}
- return string(data), nil
+ return &qvalue.QValue{Kind: qvalue.QValueKindString, Value: string(data)}, nil
}
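With both the 't' (text) and 'b' (binary) branches funneled into one `decodeColumnData`, the only per-branch difference is the pgtype format code. For reference, a standalone sketch of the underlying pgx v5 decode call for a text-format int4 column:

```go
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5/pgtype"
)

func main() {
	m := pgtype.NewMap()
	dt, ok := m.TypeForOID(pgtype.Int4OID)
	if !ok {
		panic("int4 not registered")
	}
	// Text-format wire data for an int4 column, as logical replication
	// delivers it for 't' tuples.
	v, err := dt.Codec.DecodeValue(m, pgtype.Int4OID, pgtype.TextFormatCode, []byte("42"))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%T %v\n", v, v) // int32 42
}
```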
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index d9d228f8b8..909423a3ae 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -7,6 +7,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
@@ -271,17 +272,9 @@ func (c *PostgresConnector) GetTableSchema(req *protos.GetTableSchemaInput) (*pr
return nil, err
}
- relID, err := c.getRelIDForTable(schemaTable)
- if err != nil {
- return nil, err
- }
-
// Get the column names and types
rows, err := c.pool.Query(c.ctx,
- `SELECT a.attname, t.typname FROM pg_attribute a
- JOIN pg_type t ON t.oid = a.atttypid
- WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attrelid = $1`,
- relID)
+ fmt.Sprintf(`SELECT * FROM %s LIMIT 0`, req.TableIdentifier))
if err != nil {
return nil, fmt.Errorf("error getting table schema for table %s: %w", schemaTable, err)
}
@@ -298,20 +291,13 @@ func (c *PostgresConnector) GetTableSchema(req *protos.GetTableSchemaInput) (*pr
PrimaryKeyColumn: pkey,
}
- for rows.Next() {
- var colName string
- var colType string
- err = rows.Scan(&colName, &colType)
- if err != nil {
- return nil, fmt.Errorf("error scanning table schema: %w", err)
- }
-
- colType, err = convertPostgresColumnTypeToGeneric(colType)
- if err != nil {
- return nil, fmt.Errorf("error converting postgres column type: %w", err)
+ for _, fieldDescription := range rows.FieldDescriptions() {
+ genericColType := getQValueKindForPostgresOID(fieldDescription.DataTypeOID)
+ if genericColType == qvalue.QValueKindInvalid {
+ return nil, fmt.Errorf("error converting Postgres OID to QValueKind")
}
- res.Columns[colName] = colType
+ res.Columns[fieldDescription.Name] = string(genericColType)
}
if err = rows.Err(); err != nil {
@@ -450,77 +436,6 @@ func (c *PostgresConnector) getPrimaryKeyColumn(schemaTable *SchemaTable) (strin
return pkCol, nil
}
-func convertPostgresColumnTypeToGeneric(colType string) (string, error) {
- switch colType {
- case "int2":
- return model.ColumnTypeInt16, nil
- case "int4":
- return model.ColumnTypeInt32, nil
- case "int8":
- return model.ColumnTypeInt64, nil
- case "float4":
- return model.ColumnTypeFloat32, nil
- case "float8":
- return model.ColumnTypeFloat64, nil
- case "bool":
- return model.ColumnTypeBoolean, nil
- case "text":
- return model.ColumnTypeString, nil
- case "date":
- return model.ColumnTypeDate, nil
- case "timestamp":
- return model.ColumnTypeTimestamp, nil
- case "timestamptz":
- return model.ColumnTypeTimeStampWithTimeZone, nil
- case "varchar":
- return model.ColumnTypeString, nil
- case "char":
- return model.ColumnTypeString, nil
- case "bpchar":
- return model.ColumnTypeString, nil
- case "numeric":
- return model.ColumnTypeNumeric, nil
- case "uuid":
- return model.ColumnTypeString, nil
- case "json":
- return model.ColumnTypeJSON, nil
- case "jsonb":
- return model.ColumnTypeJSON, nil
- case "xml":
- return model.ColumnTypeString, nil
- case "tsvector":
- return model.ColumnTypeString, nil
- case "tsquery":
- return model.ColumnTypeString, nil
- case "bytea":
- return model.ColumnHexBytes, nil
- case "bit":
- return model.ColumnHexBit, nil
- case "varbit":
- return model.ColumnHexBit, nil
- case "cidr":
- return model.ColumnTypeString, nil
- case "inet":
- return model.ColumnTypeString, nil
- case "interval":
- return model.ColumnTypeInterval, nil
- case "macaddr":
- return model.ColumnTypeString, nil
- case "money":
- return model.ColumnTypeString, nil
- case "oid":
- return model.ColumnTypeInt64, nil
- case "time":
- return model.ColumnTypeTime, nil
- case "timetz":
- return model.ColumnTypeTimeWithTimeZone, nil
- case "txid_snapshot":
- return model.ColumnTypeString, nil
- default:
- return "", fmt.Errorf("unsupported column type: %s", colType)
- }
-}
-
func (c *PostgresConnector) tableExists(schemaTable *SchemaTable) (bool, error) {
var exists bool
err := c.pool.QueryRow(c.ctx,
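The `LIMIT 0` probe neatly reuses pgx's field descriptions instead of querying `pg_attribute`, but note it interpolates `req.TableIdentifier` into the SQL unquoted. A hedged sketch of hardening that with pgx's identifier sanitizer (`schemaQuery` is hypothetical; it assumes the identifier is `schema.table`):

```go
package main

import (
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5"
)

// schemaQuery builds the LIMIT 0 probe with each part of the
// identifier double-quoted, so reserved words and mixed-case
// names are safe.
func schemaQuery(tableIdentifier string) string {
	parts := strings.Split(tableIdentifier, ".")
	return fmt.Sprintf("SELECT * FROM %s LIMIT 0", pgx.Identifier(parts).Sanitize())
}

func main() {
	fmt.Println(schemaQuery("public.user")) // SELECT * FROM "public"."user" LIMIT 0
}
```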
diff --git a/flow/connectors/postgres/postgres_cdc_test.go b/flow/connectors/postgres/postgres_cdc_test.go
index e45d0f117f..af32f112fc 100644
--- a/flow/connectors/postgres/postgres_cdc_test.go
+++ b/flow/connectors/postgres/postgres_cdc_test.go
@@ -9,6 +9,7 @@ import (
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/stretchr/testify/suite"
)
@@ -133,11 +134,11 @@ func (suite *PostgresCDCTestSuite) validateInsertedToastRecords(records []model.
suite.Equal(dstTableName, insertRecord.DestinationTableName)
suite.Equal(5, len(insertRecord.Items))
- suite.Equal(int32(idx+1), insertRecord.Items["id"].(int32))
- suite.Equal(32768, len(insertRecord.Items["n_t"].(string)))
- suite.Equal(32768, len(insertRecord.Items["lz4_t"].(string)))
- suite.Equal(32768, len(insertRecord.Items["n_b"].([]byte)))
- suite.Equal(32768, len(insertRecord.Items["lz4_b"].([]byte)))
+ suite.Equal(int32(idx+1), insertRecord.Items["id"].Value.(int32))
+ suite.Equal(32768, len(insertRecord.Items["n_t"].Value.(string)))
+ suite.Equal(32768, len(insertRecord.Items["lz4_t"].Value.(string)))
+ suite.Equal(32768, len(insertRecord.Items["n_b"].Value.([]byte)))
+ suite.Equal(32768, len(insertRecord.Items["lz4_b"].Value.([]byte)))
}
}
@@ -175,8 +176,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R
suite.Equal(srcTableName, updateRecord.SourceTableName)
suite.Equal(dstTableName, updateRecord.DestinationTableName)
suite.Equal(2, len(updateRecord.NewItems))
- suite.Equal(int32(1), updateRecord.NewItems["id"].(int32))
- suite.Equal(65536, len(updateRecord.NewItems["n_t"].(string)))
+ suite.Equal(int32(1), updateRecord.NewItems["id"].Value.(int32))
+ suite.Equal(65536, len(updateRecord.NewItems["n_t"].Value.(string)))
suite.Equal(3, len(updateRecord.UnchangedToastColumns))
suite.True(updateRecord.UnchangedToastColumns["lz4_t"])
suite.True(updateRecord.UnchangedToastColumns["n_b"])
@@ -187,8 +188,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R
suite.Equal(srcTableName, updateRecord.SourceTableName)
suite.Equal(dstTableName, updateRecord.DestinationTableName)
suite.Equal(2, len(updateRecord.NewItems))
- suite.Equal(int32(2), updateRecord.NewItems["id"].(int32))
- suite.Equal(65536, len(updateRecord.NewItems["lz4_b"].([]byte)))
+ suite.Equal(int32(2), updateRecord.NewItems["id"].Value.(int32))
+ suite.Equal(65536, len(updateRecord.NewItems["lz4_b"].Value.([]byte)))
suite.Equal(3, len(updateRecord.UnchangedToastColumns))
suite.True(updateRecord.UnchangedToastColumns["lz4_t"])
suite.True(updateRecord.UnchangedToastColumns["n_b"])
@@ -199,8 +200,8 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R
suite.Equal(srcTableName, updateRecord.SourceTableName)
suite.Equal(dstTableName, updateRecord.DestinationTableName)
suite.Equal(2, len(updateRecord.NewItems))
- suite.Equal(int32(3), updateRecord.NewItems["id"].(int32))
- suite.Equal(65536, len(updateRecord.NewItems["n_b"].([]byte)))
+ suite.Equal(int32(3), updateRecord.NewItems["id"].Value.(int32))
+ suite.Equal(65536, len(updateRecord.NewItems["n_b"].Value.([]byte)))
suite.Equal(3, len(updateRecord.UnchangedToastColumns))
suite.True(updateRecord.UnchangedToastColumns["lz4_t"])
suite.True(updateRecord.UnchangedToastColumns["n_t"])
@@ -211,7 +212,7 @@ func (suite *PostgresCDCTestSuite) validateMutatedToastRecords(records []model.R
suite.Equal(srcTableName, deleteRecord.SourceTableName)
suite.Equal(dstTableName, deleteRecord.DestinationTableName)
suite.Equal(5, len(deleteRecord.Items))
- suite.Equal(int32(3), deleteRecord.Items["id"].(int32))
+ suite.Equal(int32(3), deleteRecord.Items["id"].Value.(int32))
suite.Nil(deleteRecord.Items["n_t"])
suite.Nil(deleteRecord.Items["lz4_t"])
suite.Nil(deleteRecord.Items["n_b"])
@@ -397,8 +398,8 @@ func (suite *PostgresCDCTestSuite) TestErrorForTableNotExist() {
tableNameSchemaMapping[nonExistentFlowDstTableName] = &protos.TableSchema{
TableIdentifier: nonExistentFlowSrcTableName,
Columns: map[string]string{
- "id": model.ColumnTypeInt32,
- "name": model.ColumnTypeString,
+ "id": string(qvalue.QValueKindInt32),
+ "name": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
}
@@ -494,8 +495,8 @@ func (suite *PostgresCDCTestSuite) TestSimpleHappyFlow() {
suite.Equal(&protos.TableSchema{
TableIdentifier: simpleHappyFlowSrcTableName,
Columns: map[string]string{
- "id": model.ColumnTypeInt32,
- "name": model.ColumnTypeString,
+ "id": string(qvalue.QValueKindInt32),
+ "name": string(qvalue.QValueKindString),
},
PrimaryKeyColumn: "id",
}, tableNameSchema)
@@ -603,36 +604,36 @@ func (suite *PostgresCDCTestSuite) TestAllTypesHappyFlow() {
suite.Equal(&protos.TableSchema{
TableIdentifier: allTypesHappyFlowSrcTableName,
Columns: map[string]string{
- "id": model.ColumnTypeInt64,
- "c1": model.ColumnTypeInt64,
+ "id": string(qvalue.QValueKindInt64),
+ "c1": string(qvalue.QValueKindInt64),
"c2": model.ColumnHexBit,
"c3": model.ColumnHexBit,
- "c4": model.ColumnTypeBoolean,
- "c6": model.ColumnHexBytes,
- "c7": model.ColumnTypeString,
- "c8": model.ColumnTypeString,
- "c9": model.ColumnTypeString,
- "c11": model.ColumnTypeDate,
- "c12": model.ColumnTypeFloat64,
- "c13": model.ColumnTypeFloat64,
+ "c4": string(qvalue.QValueKindBoolean),
+ "c6": string(qvalue.QValueKindBytes),
+ "c7": string(qvalue.QValueKindString),
+ "c8": string(qvalue.QValueKindString),
+ "c9": string(qvalue.QValueKindString),
+ "c11": string(qvalue.QValueKindDate),
+ "c12": string(qvalue.QValueKindFloat64),
+ "c13": string(qvalue.QValueKindFloat64),
"c14": model.ColumnTypeString,
- "c15": model.ColumnTypeInt32,
+ "c15": string(qvalue.QValueKindInt32),
"c16": model.ColumnTypeInterval,
- "c17": model.ColumnTypeJSON,
- "c18": model.ColumnTypeJSON,
+ "c17": string(qvalue.QValueKindJSON),
+ "c18": string(qvalue.QValueKindJSON),
"c21": model.ColumnTypeString,
"c22": model.ColumnTypeString,
- "c23": model.ColumnTypeNumeric,
- "c24": model.ColumnTypeInt64,
- "c28": model.ColumnTypeFloat32,
- "c29": model.ColumnTypeInt16,
- "c30": model.ColumnTypeInt16,
- "c31": model.ColumnTypeInt32,
+ "c23": string(qvalue.QValueKindNumeric),
+ "c24": string(qvalue.QValueKindInt64),
+ "c28": string(qvalue.QValueKindFloat32),
+ "c29": string(qvalue.QValueKindInt16),
+ "c30": string(qvalue.QValueKindInt16),
+ "c31": string(qvalue.QValueKindInt32),
"c32": model.ColumnTypeString,
- "c33": model.ColumnTypeTimestamp,
- "c34": model.ColumnTypeTimeStampWithTimeZone,
- "c35": model.ColumnTypeTime,
- "c36": model.ColumnTypeTimeWithTimeZone,
+ "c33": string(qvalue.QValueKindTimestamp),
+ "c34": string(qvalue.QValueKindTimestampTZ),
+ "c35": string(qvalue.QValueKindTime),
+ "c36": string(qvalue.QValueKindTimeTZ),
"c37": model.ColumnTypeString,
"c38": model.ColumnTypeString,
"c39": model.ColumnTypeString,
@@ -714,11 +715,11 @@ func (suite *PostgresCDCTestSuite) TestToastHappyFlow() {
suite.Equal(&protos.TableSchema{
TableIdentifier: toastHappyFlowSrcTableName,
Columns: map[string]string{
- "id": model.ColumnTypeInt32,
- "n_t": model.ColumnTypeString,
- "lz4_t": model.ColumnTypeString,
- "n_b": model.ColumnHexBytes,
- "lz4_b": model.ColumnHexBytes,
+ "id": string(qvalue.QValueKindInt32),
+ "n_t": string(qvalue.QValueKindString),
+ "lz4_t": string(qvalue.QValueKindString),
+ "n_b": string(qvalue.QValueKindBytes),
+ "lz4_b": string(qvalue.QValueKindBytes),
},
PrimaryKeyColumn: "id",
}, tableNameSchema)
diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go
index aa74bd9e21..10f1d4249b 100644
--- a/flow/connectors/postgres/qrep.go
+++ b/flow/connectors/postgres/qrep.go
@@ -28,6 +28,8 @@ func (c *PostgresConnector) GetQRepPartitions(
switch v := minValue.(type) {
case int32, int64:
maxValue := maxValue.(int64) + 1
+ fmt.Println("minValue", minValue)
+ fmt.Println("maxValue", maxValue)
partitions, err = c.getIntPartitions(v.(int64), maxValue, config.BatchSizeInt)
case time.Time:
maxValue := maxValue.(time.Time).Add(time.Microsecond)
@@ -286,7 +288,8 @@ func (c *PostgresConnector) getIntPartitions(
for start <= end {
partitionEnd := start + batchSize
- if partitionEnd > end {
+ // safeguard against integer overflow
+ if partitionEnd > end || partitionEnd < start {
partitionEnd = end
}
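The new guard matters because `start + batchSize` can wrap past `math.MaxInt64`; Go integer overflow wraps silently, so the old `partitionEnd > end` check alone would miss it. A quick demonstration:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	start := int64(math.MaxInt64 - 10)
	batchSize := int64(100)

	partitionEnd := start + batchSize // wraps around to a negative value
	fmt.Println(partitionEnd > start) // false: the old check alone misses this
	fmt.Println(partitionEnd < start) // true: the new guard clamps to end
}
```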
diff --git a/flow/connectors/postgres/qrep_query_executor.go b/flow/connectors/postgres/qrep_query_executor.go
index 5dd4cc1ab4..8bf3c2373f 100644
--- a/flow/connectors/postgres/qrep_query_executor.go
+++ b/flow/connectors/postgres/qrep_query_executor.go
@@ -2,18 +2,11 @@ package connpostgres
import (
"context"
- "database/sql"
- "errors"
"fmt"
- "math"
- "math/big"
- "time"
"github.com/PeerDB-io/peer-flow/model"
- "github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
- "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
)
@@ -38,50 +31,12 @@ func (qe *QRepQueryExecutor) ExecuteQuery(query string, args ...interface{}) (pg
return rows, nil
}
-func fieldDescriptionToQValueKind(fd pgconn.FieldDescription) qvalue.QValueKind {
- switch fd.DataTypeOID {
- case pgtype.BoolOID:
- return qvalue.QValueKindBoolean
- case pgtype.Int2OID:
- return qvalue.QValueKindInt16
- case pgtype.Int4OID:
- return qvalue.QValueKindInt32
- case pgtype.Int8OID:
- return qvalue.QValueKindInt64
- case pgtype.Float4OID:
- return qvalue.QValueKindFloat32
- case pgtype.Float8OID:
- return qvalue.QValueKindFloat64
- case pgtype.TextOID, pgtype.VarcharOID:
- return qvalue.QValueKindString
- case pgtype.ByteaOID:
- return qvalue.QValueKindBytes
- case pgtype.JSONOID, pgtype.JSONBOID:
- return qvalue.QValueKindJSON
- case pgtype.UUIDOID:
- return qvalue.QValueKindUUID
- case pgtype.TimestampOID, pgtype.TimestamptzOID, pgtype.DateOID, pgtype.TimeOID:
- return qvalue.QValueKindETime
- case pgtype.NumericOID:
- return qvalue.QValueKindNumeric
- default:
- typeName, ok := pgtype.NewMap().TypeForOID(fd.DataTypeOID)
- if !ok {
- log.Warnf("failed to get type name for oid: %v", fd.DataTypeOID)
- return qvalue.QValueKindInvalid
- } else {
- log.Warnf("unsupported field type: %v - type name - %s", fd.DataTypeOID, typeName.Name)
- return qvalue.QValueKindInvalid
- }
- }
-}
-
// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema.
func fieldDescriptionsToSchema(fds []pgconn.FieldDescription) *model.QRecordSchema {
qfields := make([]*model.QField, len(fds))
for i, fd := range fds {
cname := fd.Name
- ctype := fieldDescriptionToQValueKind(fd)
+ ctype := getQValueKindForPostgresOID(fd.DataTypeOID)
// there isn't a way to know if a column is nullable or not
// TODO fix this.
cnullable := true
@@ -146,216 +101,22 @@ func (qe *QRepQueryExecutor) ExecuteAndProcessQuery(
return batch, nil
}
-func mapRowToQRecord(row pgx.Row, fds []pgconn.FieldDescription) (*model.QRecord, error) {
+func mapRowToQRecord(row pgx.Rows, fds []pgconn.FieldDescription) (*model.QRecord, error) {
// make vals an empty array of QValue of size len(fds)
record := model.NewQRecord(len(fds))
- scanArgs := make([]interface{}, len(fds))
- for i := range scanArgs {
- switch fds[i].DataTypeOID {
- case pgtype.BoolOID:
- scanArgs[i] = new(pgtype.Bool)
- case pgtype.TimestampOID:
- scanArgs[i] = new(pgtype.Timestamp)
- case pgtype.TimestamptzOID:
- scanArgs[i] = new(pgtype.Timestamptz)
- case pgtype.Int4OID:
- scanArgs[i] = new(pgtype.Int4)
- case pgtype.Int8OID:
- scanArgs[i] = new(pgtype.Int8)
- case pgtype.Float4OID:
- scanArgs[i] = new(pgtype.Float4)
- case pgtype.Float8OID:
- scanArgs[i] = new(pgtype.Float8)
- case pgtype.TextOID:
- scanArgs[i] = new(pgtype.Text)
- case pgtype.VarcharOID:
- scanArgs[i] = new(pgtype.Text)
- case pgtype.NumericOID:
- scanArgs[i] = new(pgtype.Numeric)
- case pgtype.UUIDOID:
- scanArgs[i] = new(pgtype.UUID)
- case pgtype.ByteaOID:
- scanArgs[i] = new(sql.RawBytes)
- case pgtype.DateOID:
- scanArgs[i] = new(pgtype.Date)
- default:
- scanArgs[i] = new(pgtype.Text)
- }
- }
-
- err := row.Scan(scanArgs...)
+ values, err := row.Values()
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
for i, fd := range fds {
- tmp, err := parseField(fd.DataTypeOID, scanArgs[i])
+ tmp, err := parseFieldFromPostgresOID(fd.DataTypeOID, values[i])
if err != nil {
return nil, fmt.Errorf("failed to parse field: %w", err)
}
- record.Set(i, tmp)
+ record.Set(i, *tmp)
}
return record, nil
}
-
-func parseField(oid uint32, value interface{}) (qvalue.QValue, error) {
- var val qvalue.QValue
-
- switch oid {
- case pgtype.TimestampOID:
- timestamp := value.(*pgtype.Timestamp)
- var et *qvalue.ExtendedTime
- if timestamp.Valid {
- var err error
- et, err = qvalue.NewExtendedTime(timestamp.Time, qvalue.DateTimeKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
- }
- val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et}
- case pgtype.TimestamptzOID:
- timestamp := value.(*pgtype.Timestamptz)
- var et *qvalue.ExtendedTime
- if timestamp.Valid {
- var err error
- et, err = qvalue.NewExtendedTime(timestamp.Time, qvalue.DateTimeKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
- }
- val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et}
- case pgtype.DateOID:
- date := value.(*pgtype.Date)
- var et *qvalue.ExtendedTime
- if date.Valid {
- var err error
- et, err = qvalue.NewExtendedTime(date.Time, qvalue.DateKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
- }
- val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et}
- case pgtype.TimeOID:
- timeVal := value.(*pgtype.Text)
- var et *qvalue.ExtendedTime
- if timeVal.Valid {
- t, err := time.Parse("15:04:05.999999", timeVal.String)
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to parse time: %w", err)
- }
- et, err = qvalue.NewExtendedTime(t, qvalue.TimeKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
- }
- val = qvalue.QValue{Kind: qvalue.QValueKindETime, Value: et}
- case pgtype.BoolOID:
- boolVal := value.(*pgtype.Bool)
- if boolVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: boolVal.Bool}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: nil}
- }
- case pgtype.JSONOID, pgtype.JSONBOID:
- // TODO: improve JSON support
- strVal := value.(*pgtype.Text)
- if strVal != nil {
- val = qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: strVal.String}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: nil}
- }
- case pgtype.Int2OID:
- intVal := value.(*pgtype.Int2)
- if intVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: intVal.Int16}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: nil}
- }
- case pgtype.Int4OID:
- intVal := value.(*pgtype.Int4)
- if intVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: intVal.Int32}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: nil}
- }
- case pgtype.Int8OID:
- intVal := value.(*pgtype.Int8)
- if intVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: intVal.Int64}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: nil}
- }
- case pgtype.Float4OID:
- floatVal := value.(*pgtype.Float4)
- if floatVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: floatVal.Float32}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: nil}
- }
- case pgtype.Float8OID:
- floatVal := value.(*pgtype.Float8)
- if floatVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal.Float64}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: nil}
- }
- case pgtype.TextOID, pgtype.VarcharOID:
- textVal := value.(*pgtype.Text)
- if textVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal.String}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindString, Value: nil}
- }
- case pgtype.UUIDOID:
- uuidVal := value.(*pgtype.UUID)
- if uuidVal.Valid {
- val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: uuidVal.Bytes}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: nil}
- }
- case pgtype.ByteaOID:
- rawBytes := value.(*sql.RawBytes)
- val = qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: []byte(*rawBytes)}
- case pgtype.NumericOID:
- numVal := value.(*pgtype.Numeric)
- rat, err := numericToRat(numVal)
- if err != nil {
- log.Warnf("failed to convert numeric [%v] to rat: %v", value, err)
- val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: nil}
- } else {
- val = qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat}
- }
- default:
- typ, _ := pgtype.NewMap().TypeForOID(oid)
- fmt.Printf("QValueKindInvalid => oid: %v, typename: %v\n", oid, typ)
- val = qvalue.QValue{Kind: qvalue.QValueKindInvalid, Value: nil}
- }
-
- return val, nil
-}
-
-func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
- if numVal.Valid {
- if numVal.NaN {
- return nil, errors.New("numeric value is NaN")
- }
-
- switch numVal.InfinityModifier {
- case pgtype.NegativeInfinity, pgtype.Infinity:
- return nil, errors.New("numeric value is infinity")
- }
-
- rat := new(big.Rat)
-
- rat.SetInt(numVal.Int)
- divisor := new(big.Rat).SetFloat64(math.Pow10(int(-numVal.Exp)))
- rat.Quo(rat, divisor)
-
- return rat, nil
- }
-
- // handle invalid numeric
- return nil, errors.New("invalid numeric")
-}
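Taking `pgx.Rows` instead of `pgx.Row` lets `Rows.Values()` decode each column through the registered codecs, which is what made the hand-maintained `scanArgs` table above removable. The calling pattern, as a self-contained sketch (`dumpRows` is illustrative):

```go
package example

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"
)

// dumpRows shows the pattern the rewritten executor relies on:
// Rows.Values() decodes every column via the registered codecs,
// replacing the old hand-built scanArgs table.
func dumpRows(ctx context.Context, pool *pgxpool.Pool, query string) error {
	rows, err := pool.Query(ctx, query)
	if err != nil {
		return err
	}
	defer rows.Close()

	fds := rows.FieldDescriptions()
	for rows.Next() {
		values, err := rows.Values() // one decoded Go value per column
		if err != nil {
			return err
		}
		for i, fd := range fds {
			fmt.Printf("%s = %v\n", fd.Name, values[i])
		}
	}
	return rows.Err()
}
```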
diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go
new file mode 100644
index 0000000000..d9b821f2a2
--- /dev/null
+++ b/flow/connectors/postgres/qvalue_convert.go
@@ -0,0 +1,195 @@
+package connpostgres
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "math"
+ "math/big"
+ "strings"
+ "time"
+
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
+ "github.com/jackc/pgx/v5/pgtype"
+ log "github.com/sirupsen/logrus"
+)
+
+func getQValueKindForPostgresOID(oid uint32) qvalue.QValueKind {
+ switch oid {
+ case pgtype.BoolOID:
+ return qvalue.QValueKindBoolean
+ case pgtype.Int2OID:
+ return qvalue.QValueKindInt16
+ case pgtype.Int4OID:
+ return qvalue.QValueKindInt32
+ case pgtype.Int8OID:
+ return qvalue.QValueKindInt64
+ case pgtype.Float4OID:
+ return qvalue.QValueKindFloat32
+ case pgtype.Float8OID:
+ return qvalue.QValueKindFloat64
+ case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID:
+ return qvalue.QValueKindString
+ case pgtype.ByteaOID:
+ return qvalue.QValueKindBytes
+ case pgtype.JSONOID, pgtype.JSONBOID:
+ return qvalue.QValueKindJSON
+ case pgtype.UUIDOID:
+ return qvalue.QValueKindUUID
+ case pgtype.TimeOID:
+ return qvalue.QValueKindTime
+ case pgtype.DateOID:
+ return qvalue.QValueKindDate
+ case pgtype.TimestampOID:
+ return qvalue.QValueKindTimestamp
+ case pgtype.TimestamptzOID:
+ return qvalue.QValueKindTimestampTZ
+ case pgtype.NumericOID:
+ return qvalue.QValueKindNumeric
+ default:
+ typeName, ok := pgtype.NewMap().TypeForOID(oid)
+ if !ok {
+ // workaround for TIMETZ not being defined by this pgtype
+ if oid == 1266 {
+ return qvalue.QValueKindTimeTZ
+ }
+ log.Warnf("failed to get type name for oid: %v", oid)
+ return qvalue.QValueKindInvalid
+ } else {
+ log.Warnf("unsupported field type: %v - type name - %s", oid, typeName.Name)
+ return qvalue.QValueKindInvalid
+ }
+ }
+}
+
+func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (*qvalue.QValue, error) {
+ var val *qvalue.QValue = nil
+
+ switch qvalueKind {
+ case qvalue.QValueKindTimestamp:
+ timestamp := value.(time.Time)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: timestamp}
+ case qvalue.QValueKindTimestampTZ:
+ timestamp := value.(time.Time)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindTimestampTZ, Value: timestamp}
+ case qvalue.QValueKindDate:
+ date := value.(time.Time)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindDate, Value: date}
+ case qvalue.QValueKindTime:
+ timeVal := value.(pgtype.Time)
+ if timeVal.Valid {
+ var timeValStr any
+ timeValStr, err := timeVal.Value()
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse time: %w", err)
+ }
+ // edge case, only Postgres supports this extreme value for time
+ timeValStr = strings.Replace(timeValStr.(string), "24:00:00.000000", "23:59:59.999999", 1)
+ t, err := time.Parse("15:04:05.999999", timeValStr.(string))
+ t = t.AddDate(1970, 0, 0)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse time: %w", err)
+ }
+ val = &qvalue.QValue{Kind: qvalue.QValueKindTime, Value: t}
+ }
+ case qvalue.QValueKindTimeTZ:
+ timeVal := value.(string)
+ // edge case, Postgres supports this extreme value for time
+ timeVal = strings.Replace(timeVal, "24:00:00.000000", "23:59:59.999999", 1)
+ t, err := time.Parse("15:04:05.999999-0700", timeVal)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse time: %w", err)
+ }
+ t = t.AddDate(1970, 0, 0)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindTime, Value: t}
+
+ case qvalue.QValueKindBoolean:
+ boolVal := value.(bool)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: boolVal}
+ case qvalue.QValueKindJSON:
+ // TODO: improve JSON support
+ jsonVal := value.(map[string]interface{})
+ jsonValStr, err := json.Marshal(jsonVal)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse json: %w", err)
+ }
+ val = &qvalue.QValue{Kind: qvalue.QValueKindJSON, Value: string(jsonValStr)}
+ case qvalue.QValueKindInt16:
+ intVal := value.(int16)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindInt16, Value: intVal}
+ case qvalue.QValueKindInt32:
+ intVal := value.(int32)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindInt32, Value: intVal}
+ case qvalue.QValueKindInt64:
+ intVal := value.(int64)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: intVal}
+ case qvalue.QValueKindFloat32:
+ floatVal := value.(float32)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindFloat32, Value: floatVal}
+ case qvalue.QValueKindFloat64:
+ floatVal := value.(float64)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindFloat64, Value: floatVal}
+ case qvalue.QValueKindString:
+ textVal := value.(string)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindString, Value: textVal}
+ case qvalue.QValueKindUUID:
+ switch value.(type) {
+ case string:
+ val = &qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value}
+ case [16]byte:
+ val = &qvalue.QValue{Kind: qvalue.QValueKindUUID, Value: value}
+ default:
+ return nil, fmt.Errorf("failed to parse UUID: %v", value)
+ }
+ case qvalue.QValueKindBytes:
+ rawBytes := value.([]byte)
+ val = &qvalue.QValue{Kind: qvalue.QValueKindBytes, Value: rawBytes}
+ // TODO: check for handling of QValueKindBit
+ case qvalue.QValueKindNumeric:
+ numVal := value.(pgtype.Numeric)
+ if numVal.Valid {
+ rat, err := numericToRat(&numVal)
+ if err != nil {
+ return nil, fmt.Errorf("failed to convert numeric [%v] to rat: %w", value, err)
+ }
+ val = &qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: rat}
+ }
+ default:
+ log.Errorf("unhandled QValueKind => %v\n", qvalueKind)
+ return nil, fmt.Errorf("unhandled QValueKind => %v", qvalueKind)
+ }
+
+ // parsing into pgtype failed.
+ if val == nil {
+ return nil, fmt.Errorf("failed to parse value %v into QValueKind %v", value, qvalueKind)
+ }
+ return val, nil
+}
+
+func parseFieldFromPostgresOID(oid uint32, value interface{}) (*qvalue.QValue, error) {
+ return parseFieldFromQValueKind(getQValueKindForPostgresOID(oid), value)
+}
+
+func numericToRat(numVal *pgtype.Numeric) (*big.Rat, error) {
+ if numVal.Valid {
+ if numVal.NaN {
+ return nil, errors.New("numeric value is NaN")
+ }
+
+ switch numVal.InfinityModifier {
+ case pgtype.NegativeInfinity, pgtype.Infinity:
+ return nil, errors.New("numeric value is infinity")
+ }
+
+ rat := new(big.Rat)
+
+ rat.SetInt(numVal.Int)
+ divisor := new(big.Rat).SetFloat64(math.Pow10(int(-numVal.Exp)))
+ rat.Quo(rat, divisor)
+
+ return rat, nil
+ }
+
+ // handle invalid numeric
+ return nil, errors.New("invalid numeric")
+}
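On the `AddDate(1970, 0, 0)` calls: `time.Parse` with a clock-only layout defaults the date to January 1 of year 0, and shifting it to 1970-01-01 keeps later epoch-based conversions (for example Avro's timestamp-micros) non-negative. Illustration:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	t, err := time.Parse("15:04:05.999999", "23:59:59.999999")
	if err != nil {
		panic(err)
	}
	fmt.Println(t.Year()) // 0: clock-only layouts default the date

	shifted := t.AddDate(1970, 0, 0)
	fmt.Println(shifted.Year())      // 1970
	fmt.Println(shifted.UnixMicro()) // 86399999999: a sane epoch offset
}
```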
diff --git a/flow/connectors/snowflake/avro_file_writer_test.go b/flow/connectors/snowflake/avro_file_writer_test.go
index c5f187e2e7..41b76513c6 100644
--- a/flow/connectors/snowflake/avro_file_writer_test.go
+++ b/flow/connectors/snowflake/avro_file_writer_test.go
@@ -19,7 +19,7 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
switch kind {
case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64:
value = int64(placeHolder)
- case qvalue.QValueKindFloat16, qvalue.QValueKindFloat32:
+ case qvalue.QValueKindFloat32:
value = float32(placeHolder)
case qvalue.QValueKindFloat64:
value = float64(placeHolder)
@@ -27,10 +27,9 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
value = placeHolder%2 == 0
case qvalue.QValueKindString:
value = fmt.Sprintf("string%d", placeHolder)
- case qvalue.QValueKindETime:
- et, err := qvalue.NewExtendedTime(time.Now(), qvalue.TimeKindType, "")
- require.NoError(t, err)
- value = et
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindTime,
+ qvalue.QValueKindTimeTZ, qvalue.QValueKindDate:
+ value = time.Now()
case qvalue.QValueKindNumeric:
// create a new big.Rat for numeric data
value = big.NewRat(int64(placeHolder), 1)
@@ -58,7 +57,6 @@ func createQValue(t *testing.T, kind qvalue.QValueKind, placeHolder int) qvalue.
func generateRecords(t *testing.T, nullable bool, numRows uint32) *model.QRecordBatch {
allQValueKinds := []qvalue.QValueKind{
- qvalue.QValueKindFloat16,
qvalue.QValueKindFloat32,
qvalue.QValueKindFloat64,
qvalue.QValueKindInt16,
@@ -68,7 +66,11 @@ func generateRecords(t *testing.T, nullable bool, numRows uint32) *model.QRecord
// qvalue.QValueKindArray,
// qvalue.QValueKindStruct,
qvalue.QValueKindString,
- qvalue.QValueKindETime,
+ qvalue.QValueKindTimestamp,
+ qvalue.QValueKindTimestampTZ,
+ qvalue.QValueKindTime,
+ qvalue.QValueKindTimeTZ,
+ qvalue.QValueKindDate,
qvalue.QValueKindNumeric,
qvalue.QValueKindBytes,
qvalue.QValueKindUUID,
diff --git a/flow/connectors/snowflake/client.go b/flow/connectors/snowflake/client.go
index 910ca285e6..ed64a9adb6 100644
--- a/flow/connectors/snowflake/client.go
+++ b/flow/connectors/snowflake/client.go
@@ -213,15 +213,12 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
Value: ratVal,
}, nil
}
- case qvalue.QValueKindETime:
- if v, ok := val.(*time.Time); ok && v != nil {
- etimeVal, err := qvalue.NewExtendedTime(*v, qvalue.DateTimeKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate,
+ qvalue.QValueKindTime, qvalue.QValueKindTimeTZ:
+ if t, ok := val.(*time.Time); ok && t != nil {
return qvalue.QValue{
- Kind: qvalue.QValueKindETime,
- Value: etimeVal,
+ Kind: kind,
+ Value: *t,
}, nil
}
case qvalue.QValueKindBytes:
@@ -234,35 +231,8 @@ func toQValue(kind qvalue.QValueKind, val interface{}) (qvalue.QValue, error) {
return qvalue.QValue{}, fmt.Errorf("[snowflakeclient] unsupported type %T for kind %s", val, kind)
}
-// databaseTypeNameToQValueKind converts a database type name to a QValueKind.
-func databaseTypeNameToQValueKind(name string) (qvalue.QValueKind, error) {
- switch name {
- case "INT":
- return qvalue.QValueKindInt32, nil
- case "BIGINT":
- return qvalue.QValueKindInt64, nil
- case "FLOAT":
- return qvalue.QValueKindFloat32, nil
- case "DOUBLE", "REAL":
- return qvalue.QValueKindFloat64, nil
- case "VARCHAR", "CHAR", "TEXT":
- return qvalue.QValueKindString, nil
- case "BOOLEAN":
- return qvalue.QValueKindBoolean, nil
- case "DATETIME", "TIMESTAMP", "TIMESTAMP_LTZ", "TIMESTAMP_NTZ", "TIMESTAMP_TZ":
- return qvalue.QValueKindETime, nil
- case "BLOB", "BYTEA", "BINARY":
- return qvalue.QValueKindBytes, nil
- case "FIXED", "NUMBER":
- return qvalue.QValueKindNumeric, nil
- default:
- // If type is unsupported, return an error
- return "", fmt.Errorf("unsupported database type name: %s", name)
- }
-}
-
func columnTypeToQField(ct *sql.ColumnType) (*model.QField, error) {
- qvKind, err := databaseTypeNameToQValueKind(ct.DatabaseTypeName())
+ qvKind, err := snowflakeTypeToQValueKind(ct.DatabaseTypeName())
if err != nil {
return nil, err
}
@@ -311,7 +281,8 @@ func (s *SnowflakeClient) ExecuteAndProcessQuery(query string) (*model.QRecordBa
values := make([]interface{}, len(columns))
for i := range values {
switch qfields[i].Type {
- case qvalue.QValueKindETime:
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindTime,
+ qvalue.QValueKindTimeTZ, qvalue.QValueKindDate:
values[i] = new(time.Time)
case qvalue.QValueKindInt16:
values[i] = new(int16)
@@ -375,10 +346,7 @@ func (s *SnowflakeClient) ExecuteAndProcessQuery(query string) (*model.QRecordBa
func (s *SnowflakeClient) CreateTable(schema *model.QRecordSchema, schemaName string, tableName string) error {
var fields []string
for _, field := range schema.Fields {
- snowflakeType, err := qValueKindToSnowflakeColTypeString(field.Type)
- if err != nil {
- return err
- }
+ snowflakeType := qValueKindToSnowflakeType(string(field.Type))
fields = append(fields, fmt.Sprintf(`"%s" %s`, field.Name, snowflakeType))
}
@@ -392,24 +360,3 @@ func (s *SnowflakeClient) CreateTable(schema *model.QRecordSchema, schemaName st
return nil
}
-
-func qValueKindToSnowflakeColTypeString(val qvalue.QValueKind) (string, error) {
- switch val {
- case qvalue.QValueKindInt32, qvalue.QValueKindInt64:
- return "INT", nil
- case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64:
- return "FLOAT", nil
- case qvalue.QValueKindString:
- return "STRING", nil
- case qvalue.QValueKindBoolean:
- return "BOOLEAN", nil
- case qvalue.QValueKindETime:
- return "TIMESTAMP_LTZ", nil
- case qvalue.QValueKindBytes:
- return "BINARY", nil
- case qvalue.QValueKindNumeric:
- return "NUMERIC(38,32)", nil
- default:
- return "", fmt.Errorf("unsupported QValueKind: %v", val)
- }
-}
diff --git a/flow/connectors/snowflake/qvalue_convert.go b/flow/connectors/snowflake/qvalue_convert.go
new file mode 100644
index 0000000000..558a9b8373
--- /dev/null
+++ b/flow/connectors/snowflake/qvalue_convert.go
@@ -0,0 +1,84 @@
+package connsnowflake
+
+import (
+ "fmt"
+
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
+)
+
+func qValueKindToSnowflakeType(colType string) string {
+ switch qvalue.QValueKind(colType) {
+ case qvalue.QValueKindBoolean:
+ return "BOOLEAN"
+ // integer types
+ case qvalue.QValueKindInt16, qvalue.QValueKindInt32, qvalue.QValueKindInt64:
+ return "INTEGER"
+ // decimal types
+ // The names FLOAT, FLOAT4, and FLOAT8 are for compatibility with other systems
+ // Snowflake treats all three as 64-bit floating-point numbers.
+ case qvalue.QValueKindFloat32, qvalue.QValueKindFloat64:
+ return "FLOAT"
+ case qvalue.QValueKindNumeric:
+ return "NUMBER(38, 9)"
+ // string-related: STRING, TEXT, NVARCHAR,
+ // NVARCHAR2, CHAR VARYING, and NCHAR VARYING
+ // are all synonymous with VARCHAR.
+ case qvalue.QValueKindString:
+ return "STRING"
+ // json also is stored as string for now
+ case qvalue.QValueKindJSON:
+ return "STRING"
+ // time related
+ case qvalue.QValueKindTimestamp:
+ return "TIMESTAMP_NTZ"
+ case qvalue.QValueKindTimestampTZ:
+ return "TIMESTAMP_TZ"
+ case qvalue.QValueKindTime:
+ return "TIME"
+ case qvalue.QValueKindDate:
+ return "DATE"
+ // TODO: handle INTERVAL types again
+ // case model.ColumnTypeTimeWithTimeZone, model.ColumnTypeInterval:
+ // return "STRING"
+ // bytes
+ case qvalue.QValueKindBit, qvalue.QValueKindBytes:
+ return "BINARY"
+ // rest will be strings
+ default:
+ return "STRING"
+ }
+}
+
+// snowflakeTypeToQValueKind converts a database type name to a QValueKind.
+func snowflakeTypeToQValueKind(name string) (qvalue.QValueKind, error) {
+ switch name {
+ case "INT":
+ return qvalue.QValueKindInt32, nil
+ case "BIGINT":
+ return qvalue.QValueKindInt64, nil
+ case "FLOAT":
+ return qvalue.QValueKindFloat32, nil
+ case "DOUBLE", "REAL":
+ return qvalue.QValueKindFloat64, nil
+ case "VARCHAR", "CHAR", "TEXT":
+ return qvalue.QValueKindString, nil
+ case "BOOLEAN":
+ return qvalue.QValueKindBoolean, nil
+ // assuming TIMESTAMP is an alias to TIMESTAMP_NTZ, which is the default.
+ case "DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ":
+ return qvalue.QValueKindTimestamp, nil
+ case "TIMESTAMP_TZ":
+ return qvalue.QValueKindTimestampTZ, nil
+ case "TIME":
+ return qvalue.QValueKindTime, nil
+ case "DATE":
+ return qvalue.QValueKindDate, nil
+ case "BLOB", "BYTEA", "BINARY":
+ return qvalue.QValueKindBytes, nil
+ case "FIXED", "NUMBER", "DECIMAL", "NUMERIC":
+ return qvalue.QValueKindNumeric, nil
+ default:
+ // If type is unsupported, return an error
+ return "", fmt.Errorf("unsupported database type name: %s", name)
+ }
+}
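As with the BigQuery pair, these two mappings are not a clean round trip: `QValueKindFloat64` maps to `FLOAT`, but `FLOAT` maps back to `QValueKindFloat32`, even though the comment above notes Snowflake stores all floats as 64-bit. A test sketch pinning that down (hypothetical, assuming it sits in this package):

```go
package connsnowflake

import (
	"testing"

	"github.com/PeerDB-io/peer-flow/model/qvalue"
)

// TestFloatRoundTripNarrows documents the Float64 -> FLOAT -> Float32
// asymmetry between the two mapping functions above.
func TestFloatRoundTripNarrows(t *testing.T) {
	sfType := qValueKindToSnowflakeType(string(qvalue.QValueKindFloat64))
	if sfType != "FLOAT" {
		t.Fatalf("expected FLOAT, got %s", sfType)
	}
	kind, err := snowflakeTypeToQValueKind(sfType)
	if err != nil {
		t.Fatal(err)
	}
	if kind != qvalue.QValueKindFloat32 {
		t.Fatalf("expected Float32 after round trip, got %s", kind)
	}
}
```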
diff --git a/flow/connectors/snowflake/snowflake.go b/flow/connectors/snowflake/snowflake.go
index 49300adf8c..0ce38070e1 100644
--- a/flow/connectors/snowflake/snowflake.go
+++ b/flow/connectors/snowflake/snowflake.go
@@ -13,6 +13,7 @@ import (
"github.com/PeerDB-io/peer-flow/connectors/utils"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
util "github.com/PeerDB-io/peer-flow/utils"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
@@ -92,7 +93,6 @@ type snowflakeRawRecord struct {
recordType int
matchData string
batchID int64
- items map[string]interface{}
unchangedToastColumns string
}
@@ -380,7 +380,7 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.
switch typedRecord := record.(type) {
case *model.InsertRecord:
// json.Marshal converts bytes in Hex automatically to BASE64 string.
- itemsJSON, err := json.Marshal(typedRecord.Items)
+ itemsJSON, err := typedRecord.Items.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to serialize insert record items to JSON: %w", err)
}
@@ -390,19 +390,18 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.
uid: uuid.New().String(),
timestamp: time.Now().UnixNano(),
destinationTableName: typedRecord.DestinationTableName,
- data: string(itemsJSON),
+ data: itemsJSON,
recordType: 0,
matchData: "",
batchID: syncBatchID,
- items: typedRecord.Items,
unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns),
})
case *model.UpdateRecord:
- newItemsJSON, err := json.Marshal(typedRecord.NewItems)
+ newItemsJSON, err := typedRecord.NewItems.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to serialize update record new items to JSON: %w", err)
}
- oldItemsJSON, err := json.Marshal(typedRecord.OldItems)
+ oldItemsJSON, err := typedRecord.OldItems.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to serialize update record old items to JSON: %w", err)
}
@@ -412,15 +411,14 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.
uid: uuid.New().String(),
timestamp: time.Now().UnixNano(),
destinationTableName: typedRecord.DestinationTableName,
- data: string(newItemsJSON),
+ data: newItemsJSON,
recordType: 1,
- matchData: string(oldItemsJSON),
+ matchData: oldItemsJSON,
batchID: syncBatchID,
- items: typedRecord.NewItems,
unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns),
})
case *model.DeleteRecord:
- itemsJSON, err := json.Marshal(typedRecord.Items)
+ itemsJSON, err := typedRecord.Items.ToJSON()
if err != nil {
return nil, fmt.Errorf("failed to serialize delete record items to JSON: %w", err)
}
@@ -430,11 +428,10 @@ func (c *SnowflakeConnector) SyncRecords(req *model.SyncRecordsRequest) (*model.
uid: uuid.New().String(),
timestamp: time.Now().UnixNano(),
destinationTableName: typedRecord.DestinationTableName,
- data: string(itemsJSON),
+ data: itemsJSON,
recordType: 2,
- matchData: string(itemsJSON),
+ matchData: itemsJSON,
batchID: syncBatchID,
- items: typedRecord.Items,
unchangedToastColumns: utils.KeysToString(typedRecord.UnchangedToastColumns),
})
default:
@@ -669,57 +666,15 @@ func (c *SnowflakeConnector) checkIfTableExists(schemaIdentifier string, tableId
return result, nil
}
-func getSnowflakeTypeForGenericColumnType(colType string) string {
- switch colType {
- case model.ColumnTypeBoolean:
- return "BOOLEAN"
- // integer types
- case model.ColumnTypeInt16, model.ColumnTypeInt32, model.ColumnTypeInt64:
- return "INTEGER"
- // decimal types
- // The names FLOAT, FLOAT4, and FLOAT8 are for compatibility with other systems
- // Snowflake treats all three as 64-bit floating-point numbers.
- case model.ColumnTypeFloat16, model.ColumnTypeFloat32, model.ColumnTypeFloat64:
- return "FLOAT"
- case model.ColumnTypeNumeric:
- return "NUMBER"
- // string related STRING , TEXT , NVARCHAR ,
- // NVARCHAR2 , CHAR VARYING , NCHAR VARYING
- //Synonymous with VARCHAR.
- case model.ColumnTypeString:
- return "STRING"
- // json also is stored as string for now
- case model.ColumnTypeJSON:
- return "STRING"
- // time related
- case model.ColumnTypeTimestamp:
- return "TIMESTAMP_NTZ"
- case model.ColumnTypeTimeStampWithTimeZone:
- return "TIMESTAMP_TZ"
- case model.ColumnTypeTime:
- return "TIME"
- case model.ColumnTypeDate:
- return "TIMESTAMP_NTZ"
- case model.ColumnTypeTimeWithTimeZone, model.ColumnTypeInterval:
- return "STRING"
- // bytes
- case model.ColumnHexBytes, model.ColumnHexBit:
- return "BINARY"
- // rest will be strings
- default:
- return "STRING"
- }
-}
-
func generateCreateTableSQLForNormalizedTable(sourceTableIdentifier string, sourceTableSchema *protos.TableSchema) string {
createTableSQLArray := make([]string, 0, len(sourceTableSchema.Columns))
for columnName, genericColumnType := range sourceTableSchema.Columns {
if sourceTableSchema.PrimaryKeyColumn == strings.ToLower(columnName) {
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("%s %s PRIMARY KEY,",
- columnName, getSnowflakeTypeForGenericColumnType(genericColumnType)))
+ columnName, qValueKindToSnowflakeType(genericColumnType)))
} else {
createTableSQLArray = append(createTableSQLArray, fmt.Sprintf("%s %s,", columnName,
- getSnowflakeTypeForGenericColumnType(genericColumnType)))
+ qValueKindToSnowflakeType(genericColumnType)))
}
}
return fmt.Sprintf(createNormalizedTableSQL, sourceTableIdentifier,
@@ -768,19 +723,20 @@ func (c *SnowflakeConnector) generateAndExecuteMergeStatement(destinationTableId
flattenedCastsSQLArray := make([]string, 0, len(normalizedTableSchema.Columns))
for columnName, genericColumnType := range normalizedTableSchema.Columns {
- sfType := getSnowflakeTypeForGenericColumnType(genericColumnType)
- switch genericColumnType {
- case model.ColumnHexBytes:
+ sfType := qValueKindToSnowflakeType(genericColumnType)
+ switch qvalue.QValueKind(genericColumnType) {
+ case qvalue.QValueKindBytes:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:%s) "+
"AS %s,", toVariantColumnName, columnName, columnName))
- case model.ColumnHexBit:
+ case qvalue.QValueKindBit:
// "c2": {"Bytes": "gA==", "Len": 1,"Valid": true}
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("BASE64_DECODE_BINARY(%s:%s:Bytes) "+
"AS %s,", toVariantColumnName, columnName, columnName))
- case model.ColumnTypeTime:
- flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
- "Microseconds*1000) "+
- "AS %s,", toVariantColumnName, columnName, columnName))
+ // TODO: https://github.com/PeerDB-io/peerdb/issues/189 - handle time types and interval types
+ // case model.ColumnTypeTime:
+ // flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("TIME_FROM_PARTS(0,0,0,%s:%s:"+
+ // "Microseconds*1000) "+
+ // "AS %s,", toVariantColumnName, columnName, columnName))
default:
flattenedCastsSQLArray = append(flattenedCastsSQLArray, fmt.Sprintf("CAST(%s:%s AS %s) AS %s,",
toVariantColumnName,
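
A self-contained sketch of the SQL fragments the rewritten cast loop emits, one per normalized-table column. The _peerdb_data variant column name and the INTEGER mapping are illustrative assumptions; the real values come from toVariantColumnName and qValueKindToSnowflakeType above:

    package main

    import "fmt"

    func main() {
    	toVariant := "_peerdb_data" // assumed variant column name
    	// column name -> QValueKind, as stored in the normalized table schema
    	cols := [][2]string{{"id", "int64"}, {"tx_hash", "bytes"}}
    	for _, c := range cols {
    		name, kind := c[0], c[1]
    		switch kind {
    		case "bytes":
    			fmt.Printf("BASE64_DECODE_BINARY(%s:%s) AS %s,\n", toVariant, name, name)
    		default:
    			// "INTEGER" stands in for qValueKindToSnowflakeType(kind) here
    			fmt.Printf("CAST(%s:%s AS %s) AS %s,\n", toVariant, name, "INTEGER", name)
    		}
    	}
    }
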
diff --git a/flow/e2e/bigquery_helper.go b/flow/e2e/bigquery_helper.go
index d6cc6e2333..b15bf04e38 100644
--- a/flow/e2e/bigquery_helper.go
+++ b/flow/e2e/bigquery_helper.go
@@ -10,6 +10,7 @@ import (
"time"
"cloud.google.com/go/bigquery"
+ "cloud.google.com/go/civil"
peer_bq "github.com/PeerDB-io/peer-flow/connectors/bigquery"
"github.com/PeerDB-io/peer-flow/generated/protos"
"github.com/PeerDB-io/peer-flow/model"
@@ -209,15 +210,10 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) {
return qvalue.QValue{Kind: qvalue.QValueKindString, Value: v}, nil
case bool:
return qvalue.QValue{Kind: qvalue.QValueKindBoolean, Value: v}, nil
+ case civil.Date:
+		return qvalue.QValue{Kind: qvalue.QValueKindDate, Value: v.In(time.UTC)}, nil
case time.Time:
- val, err := qvalue.NewExtendedTime(v, qvalue.DateTimeKindType, "")
- if err != nil {
- return qvalue.QValue{}, fmt.Errorf("failed to create ExtendedTime: %w", err)
- }
- return qvalue.QValue{
- Kind: qvalue.QValueKindETime,
- Value: val,
- }, nil
+ return qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: v}, nil
case *big.Rat:
return qvalue.QValue{Kind: qvalue.QValueKindNumeric, Value: v}, nil
case []uint8:
@@ -228,38 +224,8 @@ func toQValue(bqValue bigquery.Value) (qvalue.QValue, error) {
}
}
-// bqFieldTypeToQValueKind converts a bigquery FieldType to a QValueKind.
-func bqFieldTypeToQValueKind(fieldType bigquery.FieldType) (qvalue.QValueKind, error) {
- switch fieldType {
- case bigquery.StringFieldType:
- return qvalue.QValueKindString, nil
- case bigquery.BytesFieldType:
- return qvalue.QValueKindBytes, nil
- case bigquery.IntegerFieldType:
- return qvalue.QValueKindInt64, nil
- case bigquery.FloatFieldType:
- return qvalue.QValueKindFloat64, nil
- case bigquery.BooleanFieldType:
- return qvalue.QValueKindBoolean, nil
- case bigquery.TimestampFieldType:
- return qvalue.QValueKindETime, nil
- case bigquery.RecordFieldType:
- return qvalue.QValueKindStruct, nil
- case bigquery.DateFieldType:
- return qvalue.QValueKindETime, nil
- case bigquery.TimeFieldType:
- return qvalue.QValueKindETime, nil
- case bigquery.NumericFieldType:
- return qvalue.QValueKindNumeric, nil
- case bigquery.GeographyFieldType:
- return qvalue.QValueKindString, nil
- default:
- return "", fmt.Errorf("unsupported bigquery field type: %v", fieldType)
- }
-}
-
func bqFieldSchemaToQField(fieldSchema *bigquery.FieldSchema) (*model.QField, error) {
- qValueKind, err := bqFieldTypeToQValueKind(fieldSchema.Type)
+ qValueKind, err := peer_bq.BigQueryTypeToQValueKind(fieldSchema.Type)
if err != nil {
return nil, err
}
@@ -395,7 +361,7 @@ func qValueKindToBqColTypeString(val qvalue.QValueKind) (string, error) {
return "STRING", nil
case qvalue.QValueKindBoolean:
return "BOOL", nil
- case qvalue.QValueKindETime:
+ case qvalue.QValueKindTimestamp:
return "TIMESTAMP", nil
case qvalue.QValueKindBytes:
return "BYTES", nil
diff --git a/flow/e2e/qrep_flow_test.go b/flow/e2e/qrep_flow_test.go
index 6b0f7e6ddc..0a210d1656 100644
--- a/flow/e2e/qrep_flow_test.go
+++ b/flow/e2e/qrep_flow_test.go
@@ -94,10 +94,10 @@ func getOwnersSchema() *model.QRecordSchema {
Fields: []*model.QField{
{Name: "id", Type: qvalue.QValueKindString, Nullable: true},
{Name: "card_id", Type: qvalue.QValueKindString, Nullable: true},
- {Name: "from", Type: qvalue.QValueKindETime, Nullable: true},
+ {Name: "from", Type: qvalue.QValueKindTimestamp, Nullable: true},
{Name: "price", Type: qvalue.QValueKindNumeric, Nullable: true},
- {Name: "created_at", Type: qvalue.QValueKindETime, Nullable: true},
- {Name: "updated_at", Type: qvalue.QValueKindETime, Nullable: true},
+ {Name: "created_at", Type: qvalue.QValueKindTimestamp, Nullable: true},
+ {Name: "updated_at", Type: qvalue.QValueKindTimestamp, Nullable: true},
{Name: "transaction_hash", Type: qvalue.QValueKindBytes, Nullable: true},
{Name: "ownerable_type", Type: qvalue.QValueKindString, Nullable: true},
{Name: "ownerable_id", Type: qvalue.QValueKindString, Nullable: true},
@@ -116,9 +116,9 @@ func getOwnersSchema() *model.QRecordSchema {
{Name: "asset_id", Type: qvalue.QValueKindNumeric, Nullable: true},
{Name: "status", Type: qvalue.QValueKindInt64, Nullable: true},
{Name: "transaction_id", Type: qvalue.QValueKindString, Nullable: true},
- {Name: "settled_at", Type: qvalue.QValueKindETime, Nullable: true},
+ {Name: "settled_at", Type: qvalue.QValueKindTimestamp, Nullable: true},
{Name: "reference_id", Type: qvalue.QValueKindString, Nullable: true},
- {Name: "settle_at", Type: qvalue.QValueKindETime, Nullable: true},
+ {Name: "settle_at", Type: qvalue.QValueKindTimestamp, Nullable: true},
{Name: "settlement_delay_reason", Type: qvalue.QValueKindInt64, Nullable: true},
},
}
diff --git a/flow/go.mod b/flow/go.mod
index 8bacdde3ce..e815d6132b 100644
--- a/flow/go.mod
+++ b/flow/go.mod
@@ -13,7 +13,6 @@ require (
github.com/jackc/pgx/v5 v5.3.1
github.com/jmoiron/sqlx v1.3.5
github.com/joho/godotenv v1.5.1
- github.com/linkedin/goavro v2.1.0+incompatible
github.com/linkedin/goavro/v2 v2.12.0
github.com/sirupsen/logrus v1.9.3
github.com/snowflakedb/gosnowflake v1.6.21
diff --git a/flow/model/column.go b/flow/model/column.go
deleted file mode 100644
index deff60597f..0000000000
--- a/flow/model/column.go
+++ /dev/null
@@ -1,23 +0,0 @@
-package model
-
-// ColumnType is an enum for the column type, which are generic across all connectors.
-const (
- ColumnTypeInt16 = "int16"
- ColumnTypeInt32 = "int32"
- ColumnTypeInt64 = "int64"
- ColumnHexBytes = "bytes"
- ColumnHexBit = "bit"
- ColumnTypeBoolean = "bool"
- ColumnTypeFloat16 = "float16"
- ColumnTypeFloat32 = "float32"
- ColumnTypeFloat64 = "float64"
- ColumnTypeString = "string"
- ColumnTypeNumeric = "numeric"
- ColumnTypeJSON = "json"
- ColumnTypeInterval = "interval"
- ColumnTypeTimestamp = "timestamp"
- ColumnTypeTimeStampWithTimeZone = "timestamptz"
- ColumnTypeTime = "time"
- ColumnTypeTimeWithTimeZone = "timetz"
- ColumnTypeDate = "date"
-)
diff --git a/flow/model/model.go b/flow/model/model.go
index 56d2d09aea..e8e985826a 100644
--- a/flow/model/model.go
+++ b/flow/model/model.go
@@ -1,9 +1,11 @@
package model
import (
+ "encoding/json"
"time"
"github.com/PeerDB-io/peer-flow/generated/protos"
+ "github.com/PeerDB-io/peer-flow/model/qvalue"
)
type PullRecordsRequest struct {
@@ -29,7 +31,31 @@ type Record interface {
// get table name
GetTableName() string
// get columns and values for the record
- GetItems() map[string]interface{}
+ GetItems() RecordItems
+}
+
+type RecordItems map[string]qvalue.QValue
+
+func (r RecordItems) ToJSON() (string, error) {
+ jsonStruct := make(map[string]interface{})
+ for k, v := range r {
+ var err error
+ switch v.Kind {
+ case qvalue.QValueKindTimestamp, qvalue.QValueKindTimestampTZ, qvalue.QValueKindDate,
+ qvalue.QValueKindTime, qvalue.QValueKindTimeTZ:
+ jsonStruct[k], err = v.GoTimeConvert()
+ if err != nil {
+ return "", err
+ }
+ default:
+ jsonStruct[k] = v.Value
+ }
+ }
+ jsonBytes, err := json.Marshal(jsonStruct)
+ if err != nil {
+ return "", err
+ }
+ return string(jsonBytes), nil
}
type InsertRecord struct {
@@ -42,7 +68,7 @@ type InsertRecord struct {
// CommitID is the ID of the commit corresponding to this record.
CommitID int64
// Items is a map of column name to value.
- Items map[string]interface{}
+ Items RecordItems
// unchanged toast columns
UnchangedToastColumns map[string]bool
}
@@ -56,7 +82,7 @@ func (r *InsertRecord) GetTableName() string {
return r.DestinationTableName
}
-func (r *InsertRecord) GetItems() map[string]interface{} {
+func (r *InsertRecord) GetItems() RecordItems {
return r.Items
}
@@ -68,9 +94,9 @@ type UpdateRecord struct {
// Name of the destination table
DestinationTableName string
// OldItems is a map of column name to value.
- OldItems map[string]interface{}
+ OldItems RecordItems
// NewItems is a map of column name to value.
- NewItems map[string]interface{}
+ NewItems RecordItems
// unchanged toast columns
UnchangedToastColumns map[string]bool
}
@@ -85,7 +111,7 @@ func (r *UpdateRecord) GetTableName() string {
return r.DestinationTableName
}
-func (r *UpdateRecord) GetItems() map[string]interface{} {
+func (r *UpdateRecord) GetItems() RecordItems {
return r.NewItems
}
@@ -97,7 +123,7 @@ type DeleteRecord struct {
// CheckPointID is the ID of the record.
CheckPointID int64
// Items is a map of column name to value.
- Items map[string]interface{}
+ Items RecordItems
// unchanged toast columns
UnchangedToastColumns map[string]bool
}
@@ -111,7 +137,7 @@ func (r *DeleteRecord) GetTableName() string {
return r.SourceTableName
}
-func (r *DeleteRecord) GetItems() map[string]interface{} {
+func (r *DeleteRecord) GetItems() RecordItems {
return r.Items
}
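
A runnable sketch of the new RecordItems.ToJSON path (module paths as in flow/go.mod). Time-like kinds are routed through GoTimeConvert, so they serialize as formatted strings rather than Go's default RFC 3339 encoding:

    package main

    import (
    	"fmt"
    	"time"

    	"github.com/PeerDB-io/peer-flow/model"
    	"github.com/PeerDB-io/peer-flow/model/qvalue"
    )

    func main() {
    	items := model.RecordItems{
    		"id":         qvalue.QValue{Kind: qvalue.QValueKindInt64, Value: int64(42)},
    		"created_at": qvalue.QValue{Kind: qvalue.QValueKindTimestamp, Value: time.Date(2023, 7, 1, 12, 0, 0, 0, time.UTC)},
    	}
    	j, err := items.ToJSON()
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(j) // {"created_at":"2023-07-01 12:00:00","id":42}
    }
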
diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go
index 3675b92fcb..385a16f980 100644
--- a/flow/model/qrecord_batch.go
+++ b/flow/model/qrecord_batch.go
@@ -3,6 +3,7 @@ package model
import (
"fmt"
"math/big"
+ "time"
"github.com/PeerDB-io/peer-flow/model/qvalue"
"github.com/google/uuid"
@@ -73,7 +74,7 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
values := make([]interface{}, numEntries)
for i, qValue := range record.Entries {
switch qValue.Kind {
- case qvalue.QValueKindFloat16, qvalue.QValueKindFloat32:
+ case qvalue.QValueKindFloat32:
v, ok := qValue.Value.(float32)
if !ok {
src.err = fmt.Errorf("invalid float32 value")
@@ -111,15 +112,23 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
case qvalue.QValueKindString:
values[i] = qValue.Value.(string)
- case qvalue.QValueKindETime:
- et, ok := qValue.Value.(*qvalue.ExtendedTime)
+ case qvalue.QValueKindTimestamp:
+ t, ok := qValue.Value.(time.Time)
if !ok {
-			src.err = fmt.Errorf("invalid ExtendedTime value")
+			src.err = fmt.Errorf("invalid timestamp value")
return nil, src.err
}
- timestamp := pgtype.Timestamp{Time: et.Time, Valid: true}
+ timestamp := pgtype.Timestamp{Time: t, Valid: true}
values[i] = timestamp
+ case qvalue.QValueKindTimestampTZ:
+ t, ok := qValue.Value.(time.Time)
+ if !ok {
+				src.err = fmt.Errorf("invalid timestamptz value")
+ return nil, src.err
+ }
+ timestampTZ := pgtype.Timestamptz{Time: t, Valid: true}
+ values[i] = timestampTZ
case qvalue.QValueKindUUID:
if qValue.Value == nil {
values[i] = nil
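
For reference, a minimal sketch of the pgtype values the copy-from source now emits for the two timestamp kinds, using the jackc/pgx/v5 dependency already in flow/go.mod:

    package main

    import (
    	"fmt"
    	"time"

    	"github.com/jackc/pgx/v5/pgtype"
    )

    func main() {
    	t := time.Now()
    	ts := pgtype.Timestamp{Time: t, Valid: true}     // QValueKindTimestamp
    	tstz := pgtype.Timestamptz{Time: t, Valid: true} // QValueKindTimestampTZ
    	fmt.Println(ts.Valid, tstz.Valid)
    }
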
diff --git a/flow/model/qvalue/avro_converter.go b/flow/model/qvalue/avro_converter.go
index a391393819..8dcaca542f 100644
--- a/flow/model/qvalue/avro_converter.go
+++ b/flow/model/qvalue/avro_converter.go
@@ -4,9 +4,10 @@ import (
"errors"
"fmt"
"math/big"
+ "time"
"github.com/google/uuid"
- "github.com/linkedin/goavro"
+ "github.com/linkedin/goavro/v2"
)
// QValueKindAvroSchema defines a structure for representing Avro schemas.
@@ -41,7 +42,7 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr
return &QValueKindAvroSchema{
AvroLogicalSchema: "long",
}, nil
- case QValueKindFloat16, QValueKindFloat32:
+ case QValueKindFloat32:
return &QValueKindAvroSchema{
AvroLogicalSchema: "float",
}, nil
@@ -66,11 +67,10 @@ func GetAvroSchemaFromQValueKind(kind QValueKind, nullable bool) (*QValueKindAvr
"scale": 9,
},
}, nil
- case QValueKindETime:
+ case QValueKindTime, QValueKindTimeTZ, QValueKindDate, QValueKindTimestamp, QValueKindTimestampTZ:
return &QValueKindAvroSchema{
AvroLogicalSchema: map[string]string{
- "type": "long",
- "logicalType": "timestamp-micros",
+ "type": "string",
},
}, nil
case QValueKindJSON, QValueKindArray, QValueKindStruct, QValueKindBit:
@@ -98,21 +98,33 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
switch c.Value.Kind {
case QValueKindInvalid:
return nil, fmt.Errorf("invalid QValueKind: %v", c.Value)
- case QValueKindETime:
- t, err := c.processExtendedTime()
+ case QValueKindTime, QValueKindTimeTZ, QValueKindDate, QValueKindTimestamp, QValueKindTimestampTZ:
+ t, err := c.processGoTime()
if err != nil || t == nil {
return t, err
}
+ if c.TargetDWH == QDWHTypeSnowflake {
+ if c.Nullable {
+ return c.processNullableUnion("string", t.(string))
+ } else {
+ return t.(string), nil
+ }
+ }
if c.Nullable {
- return goavro.Union("long.timestamp-micros", t), nil
+ return goavro.Union("long.timestamp-micros", t.(int64)), nil
} else {
- return t, nil
+ return t.(int64), nil
}
case QValueKindString:
return c.processNullableUnion("string", c.Value.Value)
- case QValueKindFloat16, QValueKindFloat32:
+ case QValueKindFloat32:
return c.processNullableUnion("float", c.Value.Value)
case QValueKindFloat64:
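+		// assumption: some sources hand us float32 values for float64 columns; widen them so the Avro "double" union encodes cleanly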
+ if c.TargetDWH == QDWHTypeSnowflake || c.TargetDWH == QDWHTypeBigQuery {
+ if f32Val, ok := c.Value.Value.(float32); ok {
+ return c.processNullableUnion("double", float64(f32Val))
+ }
+ }
return c.processNullableUnion("double", c.Value.Value)
case QValueKindInt16, QValueKindInt32, QValueKindInt64:
return c.processNullableUnion("long", c.Value.Value)
@@ -139,32 +151,23 @@ func (c *QValueAvroConverter) ToAvroValue() (interface{}, error) {
}
}
-func (c *QValueAvroConverter) processExtendedTime() (interface{}, error) {
+func (c *QValueAvroConverter) processGoTime() (interface{}, error) {
if c.Value.Value == nil && c.Nullable {
return nil, nil
}
- et, ok := c.Value.Value.(*ExtendedTime)
+ t, ok := c.Value.Value.(time.Time)
if !ok {
- return nil, fmt.Errorf("invalid ExtendedTime value")
+ return nil, fmt.Errorf("invalid Time value")
}
- if et == nil {
- return nil, nil
- }
-
- switch et.NestedKind.Type {
- case DateTimeKindType, DateKindType, TimeKindType:
- ret := et.Time.UnixMicro()
- // Snowflake has issues with avro timestamp types
- // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file
- if c.TargetDWH == QDWHTypeSnowflake {
- ret = ret / 1000000
- }
- return ret, nil
- default:
- return nil, fmt.Errorf("unsupported ExtendedTimeKindType: %s", et.NestedKind.Type)
+ ret := t.UnixMicro()
+	// Snowflake has issues with Avro timestamp types, so we return the micros as the string form of the int64
+ // See: https://stackoverflow.com/questions/66104762/snowflake-date-column-have-incorrect-date-from-avro-file
+ if c.TargetDWH == QDWHTypeSnowflake {
+ return fmt.Sprint(ret), nil
}
+ return ret, nil
}
func (c *QValueAvroConverter) processNullableUnion(
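
A small sketch of what processGoTime now produces: UnixMicro for the generic Avro long path, and the same value stringified for the Snowflake workaround described above:

    package main

    import (
    	"fmt"
    	"time"
    )

    func main() {
    	t := time.Date(2023, 7, 1, 12, 0, 0, 0, time.UTC)
    	micros := t.UnixMicro()
    	fmt.Println(micros)             // 1688212800000000 -> Avro long
    	fmt.Println(fmt.Sprint(micros)) // "1688212800000000" -> string for Snowflake
    }
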
diff --git a/flow/model/qvalue/etime.go b/flow/model/qvalue/etime.go
deleted file mode 100644
index df5f8bbeb2..0000000000
--- a/flow/model/qvalue/etime.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package qvalue
-
-import (
- "errors"
- "time"
-)
-
-type ExtendedTimeKindType string
-
-const (
- DateTimeKindType ExtendedTimeKindType = "datetime"
- DateKindType ExtendedTimeKindType = "date"
- TimeKindType ExtendedTimeKindType = "time"
-)
-
-type ExtendedTime struct {
- time.Time
- NestedKind NestedKind
-}
-
-type NestedKind struct {
- Type ExtendedTimeKindType
- Format string
-}
-
-var (
- DateTime = NestedKind{
- Type: DateTimeKindType,
- Format: time.RFC3339Nano,
- }
-
- Date = NestedKind{
- Type: DateKindType,
- Format: "2006-01-02",
- }
-
- Time = NestedKind{
- Type: TimeKindType,
- Format: "15:04:05.999999",
- }
-)
-
-func NewExtendedTime(
- t time.Time,
- kindType ExtendedTimeKindType,
- originalFormat string,
-) (*ExtendedTime, error) {
- var nk NestedKind
-
- switch kindType {
- case DateTimeKindType:
- nk = DateTime
- case DateKindType:
- nk = Date
- case TimeKindType:
- nk = Time
- default:
- return nil, errors.New("invalid ExtendedTimeKindType")
- }
-
- if originalFormat != "" {
- nk = NestedKind{
- Type: nk.Type,
- Format: originalFormat,
- }
- }
-
- return &ExtendedTime{
- Time: t,
- NestedKind: nk,
- }, nil
-}
diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go
index 032bb63153..be1545d30b 100644
--- a/flow/model/qvalue/kind.go
+++ b/flow/model/qvalue/kind.go
@@ -3,21 +3,24 @@ package qvalue
type QValueKind string
const (
- QValueKindInvalid QValueKind = "invalid"
- QValueKindFloat16 QValueKind = "float16"
- QValueKindFloat32 QValueKind = "float32"
- QValueKindFloat64 QValueKind = "float64"
- QValueKindInt16 QValueKind = "int16"
- QValueKindInt32 QValueKind = "int32"
- QValueKindInt64 QValueKind = "int64"
- QValueKindBoolean QValueKind = "bool"
- QValueKindArray QValueKind = "array"
- QValueKindStruct QValueKind = "struct"
- QValueKindString QValueKind = "string"
- QValueKindETime QValueKind = "extended_time"
- QValueKindNumeric QValueKind = "numeric"
- QValueKindBytes QValueKind = "bytes"
- QValueKindUUID QValueKind = "uuid"
- QValueKindJSON QValueKind = "json"
- QValueKindBit QValueKind = "bit"
+ QValueKindInvalid QValueKind = "invalid"
+ QValueKindFloat32 QValueKind = "float32"
+ QValueKindFloat64 QValueKind = "float64"
+ QValueKindInt16 QValueKind = "int16"
+ QValueKindInt32 QValueKind = "int32"
+ QValueKindInt64 QValueKind = "int64"
+ QValueKindBoolean QValueKind = "bool"
+ QValueKindArray QValueKind = "array"
+ QValueKindStruct QValueKind = "struct"
+ QValueKindString QValueKind = "string"
+ QValueKindTimestamp QValueKind = "timestamp"
+ QValueKindTimestampTZ QValueKind = "timestamptz"
+ QValueKindDate QValueKind = "date"
+ QValueKindTime QValueKind = "time"
+ QValueKindTimeTZ QValueKind = "timetz"
+ QValueKindNumeric QValueKind = "numeric"
+ QValueKindBytes QValueKind = "bytes"
+ QValueKindUUID QValueKind = "uuid"
+ QValueKindJSON QValueKind = "json"
+ QValueKindBit QValueKind = "bit"
)
diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go
index 5be793ae9e..cf0b6f11bc 100644
--- a/flow/model/qvalue/qvalue.go
+++ b/flow/model/qvalue/qvalue.go
@@ -3,9 +3,11 @@ package qvalue
import (
"bytes"
"encoding/json"
+ "fmt"
"math/big"
"reflect"
"strconv"
+ "time"
"github.com/google/uuid"
)
@@ -19,8 +21,6 @@ func (q *QValue) Equals(other *QValue) bool {
switch q.Kind {
case QValueKindInvalid:
return false // both are invalid we always return false
- case QValueKindFloat16:
- return compareFloat32(q.Value, other.Value)
case QValueKindFloat32:
return compareFloat32(q.Value, other.Value)
case QValueKindFloat64:
@@ -39,8 +39,10 @@ func (q *QValue) Equals(other *QValue) bool {
return compareStruct(q.Value, other.Value)
case QValueKindString:
return compareString(q.Value, other.Value)
- case QValueKindETime:
- return compareETime(q.Value, other.Value)
+ // all internally represented as a Golang time.Time
+ case QValueKindTime, QValueKindTimeTZ, QValueKindDate,
+ QValueKindTimestamp, QValueKindTimestampTZ:
+ return compareGoTime(q.Value, other.Value)
case QValueKindNumeric:
return compareNumeric(q.Value, other.Value)
case QValueKindBytes:
@@ -56,6 +58,22 @@ func (q *QValue) Equals(other *QValue) bool {
return false
}
+func (q *QValue) GoTimeConvert() (string, error) {
+	switch q.Kind {
+	case QValueKindTime:
+		return q.Value.(time.Time).Format("15:04:05.999999"), nil
+	case QValueKindTimeTZ:
+		return q.Value.(time.Time).Format("15:04:05.999999-0700"), nil
+	case QValueKindDate:
+		return q.Value.(time.Time).Format("2006-01-02"), nil
+	case QValueKindTimestamp:
+		return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999"), nil
+	case QValueKindTimestampTZ:
+		return q.Value.(time.Time).Format("2006-01-02 15:04:05.999999-0700"), nil
+	default:
+		return "", fmt.Errorf("unsupported QValueKind: %s", q.Kind)
+	}
+}
+
func compareInt16(value1, value2 interface{}) bool {
int1, ok1 := getInt16(value1)
int2, ok2 := getInt16(value2)
@@ -86,9 +104,9 @@ func compareFloat64(value1, value2 interface{}) bool {
return ok1 && ok2 && float1 == float2
}
-func compareETime(value1, value2 interface{}) bool {
- et1, ok1 := value1.(*ExtendedTime)
- et2, ok2 := value2.(*ExtendedTime)
+func compareGoTime(value1, value2 interface{}) bool {
+ et1, ok1 := value1.(time.Time)
+ et2, ok2 := value2.(time.Time)
if !ok1 || !ok2 {
return false
@@ -96,8 +114,8 @@ func compareETime(value1, value2 interface{}) bool {
// TODO: this is a hack, we should be comparing the actual time values
// currently this is only used for testing so that is OK.
- t1 := et1.Time.UnixMilli() / 1000
- t2 := et2.Time.UnixMilli() / 1000
+ t1 := et1.UnixMicro()
+ t2 := et2.UnixMicro()
return t1 == t2
}
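
For a concrete feel of the layout strings used by GoTimeConvert above, a sketch with a fixed UTC instant:

    package main

    import (
    	"fmt"
    	"time"
    )

    func main() {
    	t := time.Date(2023, 7, 1, 12, 34, 56, 789000000, time.UTC)
    	fmt.Println(t.Format("15:04:05.999999"))                 // 12:34:56.789 (QValueKindTime)
    	fmt.Println(t.Format("2006-01-02"))                      // 2023-07-01 (QValueKindDate)
    	fmt.Println(t.Format("2006-01-02 15:04:05.999999-0700")) // 2023-07-01 12:34:56.789+0000 (QValueKindTimestampTZ)
    }
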
diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go
index b5db2b8968..0131bde450 100644
--- a/flow/workflows/setup_flow.go
+++ b/flow/workflows/setup_flow.go
@@ -203,11 +203,12 @@ func (s *SetupFlowExecution) fetchTableSchemaAndSetupNormalizedTables(
if err := fSrcTableSchema.Get(ctx, &srcTableSchema); err != nil {
return nil, fmt.Errorf("failed to fetch schema for source table %s: %w", srcTableName, err)
}
- s.logger.Info("fetched schema for table %s for peer flow %s ", srcTableSchema, s.PeerFlowName)
+	s.logger.Info(fmt.Sprintf("fetched schema for source table %s for peer flow %s", srcTableName, s.PeerFlowName))
tableNameSchemaMapping[flowConnectionConfigs.TableNameMapping[srcTableName]] = srcTableSchema
- s.logger.Info("setting up normalized table for table %s for peer flow - ", srcTableSchema, s.PeerFlowName)
+ s.logger.Info(fmt.Sprintf("setting up normalized table for table %s for peer flow - %s",
+ srcTableName, s.PeerFlowName))
// now setup the normalized tables on the destination peer
setupConfig := &protos.SetupNormalizedTableInput{
diff --git a/images/banner.jpg b/images/banner.jpg
new file mode 100755
index 0000000000..6037c9e7bb
Binary files /dev/null and b/images/banner.jpg differ
diff --git a/images/logo-light-transparent_copy_2.png b/images/logo-light-transparent_copy_2.png
deleted file mode 100644
index 81b2021dca..0000000000
Binary files a/images/logo-light-transparent_copy_2.png and /dev/null differ
diff --git a/nexus/Cargo.lock b/nexus/Cargo.lock
index 4f41a9e804..7e78f44185 100644
--- a/nexus/Cargo.lock
+++ b/nexus/Cargo.lock
@@ -34,6 +34,7 @@ dependencies = [
"async-trait",
"catalog",
"flow-rs",
+ "lazy_static",
"pem 1.1.1",
"pt",
"serde_json",
diff --git a/nexus/analyzer/Cargo.toml b/nexus/analyzer/Cargo.toml
index 626753becc..1680d9dcb9 100644
--- a/nexus/analyzer/Cargo.toml
+++ b/nexus/analyzer/Cargo.toml
@@ -10,6 +10,7 @@ anyhow = "1.0"
async-trait = "0.1"
catalog = { path = "../catalog" }
flow-rs = { path = "../flow-rs" }
+lazy_static = "1.4"
pem = "1.1.0"
pt = { path = "../pt" }
sqlparser = { path = "../sqlparser-rs", features = ["visitor"] }
diff --git a/nexus/analyzer/src/lib.rs b/nexus/analyzer/src/lib.rs
index 66b09cec71..3cb13e3f04 100644
--- a/nexus/analyzer/src/lib.rs
+++ b/nexus/analyzer/src/lib.rs
@@ -11,12 +11,11 @@ use flow_rs::{FlowJob, FlowJobTableMapping, QRepFlowJob};
use pt::peers::{
peer::Config, BigqueryConfig, DbType, MongoConfig, Peer, PostgresConfig, SnowflakeConfig,
};
-use serde_json::Number;
+use qrep::process_options;
+use sqlparser::ast::CreateMirror::{Select, CDC};
use sqlparser::ast::{visit_relations, visit_statements, FetchDirection, SqlOption, Statement};
-use sqlparser::ast::{
- CreateMirror::{Select, CDC},
- Value,
-};
+
+mod qrep;
pub trait StatementAnalyzer {
type Output;
@@ -107,76 +106,6 @@ pub enum PeerDDL {
},
}
-impl PeerDDLAnalyzer {
- fn parse_string_for_options(
- raw_options: &HashMap<&str, &Value>,
-        processed_options: &mut HashMap<String, serde_json::Value>,
- key: &str,
- is_required: bool,
- accepted_values: Option<&[&str]>,
- ) -> anyhow::Result<()> {
- if raw_options.get(key).is_none() {
- if is_required {
- anyhow::bail!("{} is required", key);
- } else {
- Ok(())
- }
- } else {
- let raw_value = *raw_options.get(key).unwrap();
- match raw_value {
- sqlparser::ast::Value::SingleQuotedString(str) => {
- if accepted_values.is_some() {
- let accepted_values = accepted_values.unwrap();
- if !accepted_values.contains(&str.as_str()) {
- anyhow::bail!("{} must be one of {:?}", key, accepted_values);
- }
- }
- processed_options
- .insert(key.to_string(), serde_json::Value::String(str.clone()));
- Ok(())
- }
- _ => {
- anyhow::bail!("invalid value for {}", key);
- }
- }
- }
- }
-
- fn parse_number_for_options(
- raw_options: &HashMap<&str, &Value>,
-        processed_options: &mut HashMap<String, serde_json::Value>,
- key: &str,
- min_value: u32,
- default_value: u32,
- ) -> anyhow::Result<()> {
- if raw_options.get(key).is_none() {
- processed_options.insert(
- key.to_string(),
- serde_json::Value::Number(Number::from(default_value)),
- );
- Ok(())
- } else {
- let raw_value = *raw_options.get(key).unwrap();
- match raw_value {
- sqlparser::ast::Value::Number(str, _) => {
-                    let value = str.parse::<u32>()?;
- if value < min_value {
- anyhow::bail!("{} must be greater than {}", key, min_value - 1);
- }
- processed_options.insert(
- key.to_string(),
- serde_json::Value::Number(Number::from(value)),
- );
- Ok(())
- }
- _ => {
- anyhow::bail!("invalid value for {}", key);
- }
- }
- }
- }
-}
-
impl StatementAnalyzer for PeerDDLAnalyzer {
    type Output = Option<PeerDDL>;
@@ -234,119 +163,7 @@ impl StatementAnalyzer for PeerDDLAnalyzer {
raw_options.insert(&option.name.value as &str, &option.value);
}
- let mut processed_options = HashMap::new();
-
- // processing options that are REQUIRED and take a string value.
- for key in [
- "destination_table_name",
- "watermark_column",
- "watermark_table_name",
- ] {
- PeerDDLAnalyzer::parse_string_for_options(
- &raw_options,
- &mut processed_options,
- key,
- true,
- None,
- )?;
- }
- PeerDDLAnalyzer::parse_string_for_options(
- &raw_options,
- &mut processed_options,
- "mode",
- true,
- Some(&["append", "upsert"]),
- )?;
- // processing options that are OPTIONAL and take a string value.
- PeerDDLAnalyzer::parse_string_for_options(
- &raw_options,
- &mut processed_options,
- "unique_key_columns",
- false,
- None,
- )?;
- PeerDDLAnalyzer::parse_string_for_options(
- &raw_options,
- &mut processed_options,
- "sync_data_format",
- false,
- Some(&["default", "avro"]),
- )?;
- PeerDDLAnalyzer::parse_string_for_options(
- &raw_options,
- &mut processed_options,
- "staging_path",
- false,
- None,
- )?;
- // processing options that are OPTIONAL and take a number value which a minimum and default value.
- PeerDDLAnalyzer::parse_number_for_options(
- &raw_options,
- &mut processed_options,
- "parallelism",
- 1,
- 2,
- )?;
- PeerDDLAnalyzer::parse_number_for_options(
- &raw_options,
- &mut processed_options,
- "refresh_interval",
- 10,
- 10,
- )?;
- PeerDDLAnalyzer::parse_number_for_options(
- &raw_options,
- &mut processed_options,
- "batch_size_int",
- 1,
- 10000,
- )?;
- PeerDDLAnalyzer::parse_number_for_options(
- &raw_options,
- &mut processed_options,
- "batch_duration_timestamp",
- 1,
- 60,
- )?;
-
- if !processed_options.contains_key("sync_data_format") {
- processed_options.insert(
- "sync_data_format".to_string(),
- serde_json::Value::String("default".to_string()),
- );
- }
-
- // unique_key_columns should only be specified if mode is upsert
- if processed_options.contains_key("unique_key_columns")
- ^ (processed_options.get("mode").unwrap() == "upsert")
- {
- if processed_options.get("mode").unwrap() == "upsert" {
- anyhow::bail!(
- "unique_key_columns should be specified if mode is upsert"
- );
- } else {
- anyhow::bail!(
- "mode should be upsert if unique_key_columns is specified"
- );
- }
- }
-
- if processed_options.contains_key("unique_key_columns") {
- processed_options.insert(
- "unique_key_columns".to_string(),
- serde_json::Value::Array(
- processed_options
- .get("unique_key_columns")
- .unwrap()
- .as_str()
- .unwrap()
- .split(',')
- .map(|s| serde_json::Value::String(s.to_string()))
- .collect(),
- ),
- );
- }
-
+ let processed_options = process_options(raw_options)?;
let qrep_flow_job = QRepFlowJob {
name: select.mirror_name.to_string().to_lowercase(),
source_peer: select.source_peer.to_string().to_lowercase(),
diff --git a/nexus/analyzer/src/qrep.rs b/nexus/analyzer/src/qrep.rs
new file mode 100644
index 0000000000..46653cb9a0
--- /dev/null
+++ b/nexus/analyzer/src/qrep.rs
@@ -0,0 +1,167 @@
+use std::collections::HashMap;
+
+use serde_json::Value;
+use sqlparser::ast::Value as SqlValue;
+
+enum QRepOptionType {
+ String {
+ name: &'static str,
+ default_val: Option<&'static str>,
+ required: bool,
+        accepted_values: Option<Vec<&'static str>>,
+ },
+ Int {
+ name: &'static str,
+        min_value: Option<u32>,
+ default_value: u32,
+ required: bool,
+ },
+ StringArray {
+ name: &'static str,
+ },
+}
+
+lazy_static::lazy_static! {
+    static ref QREP_OPTIONS: Vec<QRepOptionType> = {
+ vec![
+ QRepOptionType::String {
+ name: "destination_table_name",
+ default_val: None,
+ required: true,
+ accepted_values: None,
+ },
+ QRepOptionType::String {
+ name: "watermark_column",
+ default_val: None,
+ required: true,
+ accepted_values: None,
+ },
+ QRepOptionType::String {
+ name: "watermark_table_name",
+ default_val: None,
+ required: true,
+ accepted_values: None,
+ },
+ QRepOptionType::String {
+ name: "mode",
+ default_val: Some("append"),
+ required: false,
+ accepted_values: Some(vec!["upsert", "append"]),
+ },
+ QRepOptionType::StringArray {
+ name: "unique_key_columns",
+ },
+ QRepOptionType::String {
+ name: "sync_data_format",
+ default_val: Some("default"),
+ required: false,
+ accepted_values: Some(vec!["default", "avro"]),
+ },
+ QRepOptionType::String {
+ name: "staging_path",
+ default_val: None,
+ required: false,
+ accepted_values: None,
+ },
+ QRepOptionType::Int {
+ name: "parallelism",
+ min_value: Some(1),
+ default_value: 2,
+ required: false,
+ },
+ QRepOptionType::Int {
+ name: "refresh_interval",
+ min_value: Some(10),
+ default_value: 10,
+ required: false,
+ },
+ QRepOptionType::Int {
+ name: "batch_size_int",
+ min_value: Some(1),
+ default_value: 1000,
+ required: false,
+ },
+ QRepOptionType::Int {
+ name: "batch_duration_timestamp",
+ min_value: Some(1),
+ default_value: 60,
+ required: false,
+ },
+ ]
+ };
+}
+
+pub fn process_options(
+ raw_opts: HashMap<&str, &SqlValue>,
+) -> anyhow::Result<HashMap<String, Value>> {
+    let mut opts: HashMap<String, Value> = HashMap::new();
+
+ for opt_type in &*QREP_OPTIONS {
+ match opt_type {
+ QRepOptionType::String {
+ name,
+ default_val,
+ required,
+ accepted_values,
+ } => {
+ if let Some(raw_value) = raw_opts.get(*name) {
+ if let SqlValue::SingleQuotedString(str) = raw_value {
+ if let Some(values) = accepted_values {
+ if !values.contains(&str.as_str()) {
+ anyhow::bail!("{} must be one of {:?}", name, values);
+ }
+ }
+ opts.insert((*name).to_string(), Value::String(str.clone()));
+ } else {
+ anyhow::bail!("Invalid value for {}", name);
+ }
+ } else if *required {
+ anyhow::bail!("{} is required", name);
+ } else if let Some(default) = default_val {
+ opts.insert((*name).to_string(), Value::String(default.to_string()));
+ }
+ }
+ QRepOptionType::Int {
+ name,
+ min_value,
+ default_value,
+ required,
+ } => {
+ if let Some(raw_value) = raw_opts.get(*name) {
+ if let SqlValue::Number(num_str, _) = raw_value {
+                    let num = num_str.parse::<u32>()?;
+ if let Some(min) = min_value {
+ if num < *min {
+ anyhow::bail!("{} must be greater than {}", name, min);
+ }
+ }
+ opts.insert((*name).to_string(), Value::Number(num.into()));
+ } else {
+ anyhow::bail!("Invalid value for {}", name);
+ }
+ } else if *required {
+ anyhow::bail!("{} is required", name);
+ } else {
+ let v = *default_value;
+ opts.insert((*name).to_string(), Value::Number(v.into()));
+ }
+ }
+ QRepOptionType::StringArray { name } => {
+ // read it as a string and split on comma
+ if let Some(raw_value) = raw_opts.get(*name) {
+ if let SqlValue::SingleQuotedString(str) = raw_value {
+                    let values: Vec<Value> = str
+ .split(',')
+ .map(|s| Value::String(s.trim().to_string()))
+ .collect();
+ opts.insert((*name).to_string(), Value::Array(values));
+ } else {
+ anyhow::bail!("Invalid value for {}", name);
+ }
+ }
+ }
+ }
+ }
+
+ Ok(opts)
+}
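
The new module replaces the per-key helper functions with a single table-driven validator. A rough Go analogue of the same pattern, illustrative only and not part of this PR:

    package main

    import (
    	"fmt"
    	"strconv"
    )

    // optSpec plays the role of QRepOptionType: one row per option.
    type optSpec struct {
    	name     string
    	required bool
    	isInt    bool
    	minValue int
    	defaultV interface{}
    	accepted []string
    }

    var specs = []optSpec{
    	{name: "mode", defaultV: "append", accepted: []string{"append", "upsert"}},
    	{name: "parallelism", isInt: true, minValue: 1, defaultV: 2},
    }

    func process(raw map[string]string) (map[string]interface{}, error) {
    	out := make(map[string]interface{})
    	for _, s := range specs {
    		v, ok := raw[s.name]
    		if !ok {
    			if s.required {
    				return nil, fmt.Errorf("%s is required", s.name)
    			}
    			out[s.name] = s.defaultV
    			continue
    		}
    		if s.isInt {
    			n, err := strconv.Atoi(v)
    			if err != nil || n < s.minValue {
    				return nil, fmt.Errorf("invalid value for %s", s.name)
    			}
    			out[s.name] = n
    			continue
    		}
    		if len(s.accepted) > 0 && !contains(s.accepted, v) {
    			return nil, fmt.Errorf("%s must be one of %v", s.name, s.accepted)
    		}
    		out[s.name] = v
    	}
    	return out, nil
    }

    func contains(xs []string, v string) bool {
    	for _, x := range xs {
    		if x == v {
    			return true
    		}
    	}
    	return false
    }

    func main() {
    	fmt.Println(process(map[string]string{"mode": "upsert"}))
    }
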
diff --git a/nexus/flow-rs/src/lib.rs b/nexus/flow-rs/src/lib.rs
index 4906d05beb..70bddd9cd9 100644
--- a/nexus/flow-rs/src/lib.rs
+++ b/nexus/flow-rs/src/lib.rs
@@ -31,7 +31,7 @@ pub struct QRepFlowJob {
pub target_peer: String,
pub query_string: String,
    pub flow_options: HashMap<String, Value>,
- pub description: String
+ pub description: String,
}
impl FlowHandler {
diff --git a/nexus/peer-bigquery/src/ast.rs b/nexus/peer-bigquery/src/ast.rs
index 6a72f95d75..d6b47fe631 100644
--- a/nexus/peer-bigquery/src/ast.rs
+++ b/nexus/peer-bigquery/src/ast.rs
@@ -3,7 +3,7 @@ use std::ops::ControlFlow;
use sqlparser::ast::Value::Number;
use sqlparser::ast::{
- visit_expressions_mut, visit_function_arg, visit_function_arg_mut, visit_relations_mut,
+ visit_expressions_mut, visit_function_arg_mut, visit_relations_mut,
visit_setexpr_mut, Array, BinaryOperator, DataType, DateTimeField, Expr, Function, FunctionArg,
FunctionArgExpr, Ident, ObjectName, Query, SetExpr, SetOperator, SetQuantifier, TimezoneInfo,
};
diff --git a/nexus/peer-bigquery/src/lib.rs b/nexus/peer-bigquery/src/lib.rs
index 4534ec5564..4ed659f750 100644
--- a/nexus/peer-bigquery/src/lib.rs
+++ b/nexus/peer-bigquery/src/lib.rs
@@ -11,7 +11,11 @@ use peer_cursor::{CursorModification, QueryExecutor, QueryOutput, SchemaRef};
use pgerror::PgError;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pt::peers::BigqueryConfig;
-use sqlparser::ast::{CloseCursor, Expr, FetchDirection, Statement, Value};
+use sqlparser::{
+ ast::{CloseCursor, Expr, FetchDirection, Statement, Value},
+ dialect::GenericDialect,
+ parser,
+};
use stream::{BqRecordStream, BqSchema};
mod ast;
@@ -41,7 +45,7 @@ pub async fn bq_client_from_config(config: BigqueryConfig) -> anyhow::Result<Client> {
+    async fn is_connection_valid(&self) -> anyhow::Result<bool> {
+ let sql = "SELECT 1;";
+ let test_stmt = parser::Parser::parse_sql(&GenericDialect {}, sql).unwrap();
+ let _ = self.execute(&test_stmt[0]).await?;
+ Ok(true)
+ }
}
diff --git a/nexus/peer-cursor/src/lib.rs b/nexus/peer-cursor/src/lib.rs
index 37065cf163..7d2525a7df 100644
--- a/nexus/peer-cursor/src/lib.rs
+++ b/nexus/peer-cursor/src/lib.rs
@@ -51,4 +51,6 @@ pub trait QueryExecutor: Send + Sync {
    async fn execute(&self, stmt: &Statement) -> PgWireResult<QueryOutput>;
    async fn describe(&self, stmt: &Statement) -> PgWireResult<Option<SchemaRef>>