-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdriver.go
143 lines (119 loc) · 4.13 KB
/
driver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package postgres
import (
	"bytes"
	"fmt"
	"net/url"
	"strings"

	"github.com/Sirupsen/logrus"
	_ "github.com/jackc/pgx/stdlib"
	"github.com/jmoiron/sqlx"
	"github.com/segment-sources/sqlsource/domain"
	"github.com/segment-sources/sqlsource/driver"
)
// chunkSize is the maximum number of rows a single Scan query requests
// (it becomes the query's LIMIT): one million rows.
const chunkSize = 1000 * 1000
// tableDescriptionRow is one result row of the Describe query: a single column
// of a table that has a primary key constraint. The db tags name the query's
// output columns so sqlx's StructScan can populate the struct.
type tableDescriptionRow struct {
	Catalog    string `db:"table_catalog"`  // catalog (database) the table belongs to
	SchemaName string `db:"table_schema"`   // schema containing the table
	TableName  string `db:"table_name"`     // name of the table
	ColumnName string `db:"column_name"`    // name of this column
	IsPrimary  bool   `db:"is_primary_key"` // true when the column is part of the primary key
}
// Postgres provides the Init, Scan, Transform and Describe methods for a
// PostgreSQL database. Connection is nil until Init succeeds; Scan and Describe
// query through it.
type Postgres struct {
	Connection *sqlx.DB // opened by Init via the "pgx" database/sql driver
}
// Init opens a database connection described by c using the "pgx" driver and
// stores it in p.Connection. It returns the error from sqlx.Connect, if any,
// and leaves p.Connection untouched in that case.
//
// The username, password and database name are percent-encoded before being
// placed into the postgres:// URL. Unencoded, a credential containing a reserved
// URL character such as '@', ':', '/', '?', '#' or '%' would corrupt the
// connection string. NOTE(review): a password that was previously hand-encoded
// to work around this will now be encoded twice; such configs need the raw value.
//
// c.ExtraOptions are appended verbatim as the URL query string, joined with
// '&'. They are presumably "key=value" pairs that are already URL-safe; confirm
// with callers.
func (p *Postgres) Init(c *domain.Config) error {
	var extraOptions bytes.Buffer
	if len(c.ExtraOptions) > 0 {
		extraOptions.WriteRune('?')
		extraOptions.WriteString(strings.Join(c.ExtraOptions, "&"))
	}

	// Userinfo.String yields the percent-encoded "user:password" form and
	// PathEscape encodes the database name. Both are the encodings url.Parse
	// reverses, which URL-style DSN parsers presumably rely on.
	userInfo := url.UserPassword(c.Username, c.Password).String()
	connectionString := fmt.Sprintf(
		"postgres://%s@%s:%s/%s%s",
		userInfo, c.Hostname, c.Port, url.PathEscape(c.Database), extraOptions.String(),
	)

	db, err := sqlx.Connect("pgx", connectionString)
	if err != nil {
		return err
	}
	p.Connection = db
	return nil
}
// Scan returns the next chunk of at most chunkSize rows of table t, ordered by
// the primary key columns in t.PrimaryKeys (in slice order).
//
// When lastPkValues is empty, the chunk starts at the beginning of the table.
// Otherwise lastPkValues must hold exactly one value per primary key column, in
// the same order. Only rows whose key sorts strictly after that tuple are
// returned, and the values are bound as query parameters $1..$n. Passing the key
// of the last row of one chunk therefore yields the following chunk.
//
// Schema, table and primary key names are emitted as PostgreSQL quoted
// identifiers with embedded double quotes doubled. Go's %q is not used for
// this, because its backslash escapes are not PostgreSQL identifier syntax.
//
// On error the returned rows value is a true nil interface.
func (p *Postgres) Scan(t *domain.Table, lastPkValues []interface{}) (driver.SqlRows, error) {
	quoteIdent := func(name string) string {
		return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
	}

	// Every $n placeholder below refers to lastPkValues[n-1]. Reject a length
	// mismatch here rather than letting it surface as an opaque bind error.
	if len(lastPkValues) > 0 && len(lastPkValues) != len(t.PrimaryKeys) {
		return nil, fmt.Errorf("scan %s.%s: got %d last primary key values, want %d",
			t.SchemaName, t.TableName, len(lastPkValues), len(t.PrimaryKeys))
	}

	quotedPks := make([]string, 0, len(t.PrimaryKeys))
	for _, column := range t.PrimaryKeys {
		quotedPks = append(quotedPks, quoteIdent(column))
	}

	// In most cases whereClause simply looks like "id" > $1. Because compound
	// primary keys are supported, every key column must take part. For a table
	// with the 3-column key (a, b, c):
	//   a | b | c
	//  ---+---+---
	//   1 | 1 | 1
	//   1 | 1 | 2
	//   1 | 2 | 1
	//   1 | 2 | 2
	//   2 | 1 | 1
	//
	// selecting the rows after (1, 1, 1) needs the lexicographic comparison
	//   "a" > $1 OR "a" = $1 AND "b" > $2 OR "a" = $1 AND "b" = $2 AND "c" > $3
	// AND binds tighter than OR, so no parentheses are needed.
	whereClause := "true"
	if len(lastPkValues) > 0 {
		whereOrList := make([]string, 0, len(quotedPks))
		for i := range quotedPks {
			// Equality on every earlier key column, strict ">" on column i.
			choiceAndList := make([]string, 0, i+1)
			for j := 0; j < i; j++ {
				choiceAndList = append(choiceAndList, fmt.Sprintf("%s = $%d", quotedPks[j], j+1))
			}
			choiceAndList = append(choiceAndList, fmt.Sprintf("%s > $%d", quotedPks[i], i+1))
			whereOrList = append(whereOrList, strings.Join(choiceAndList, " AND "))
		}
		whereClause = strings.Join(whereOrList, " OR ")
	}

	query := fmt.Sprintf("SELECT %s FROM %s.%s WHERE %s ORDER BY %s LIMIT %d",
		t.ColumnToSQL(), quoteIdent(t.SchemaName), quoteIdent(t.TableName),
		whereClause, strings.Join(quotedPks, ", "), chunkSize)

	logger := logrus.WithFields(logrus.Fields{
		"query": query,
		"args":  lastPkValues,
	})
	logger.Debugf("Executing query")

	// Do not return Queryx's results directly: on error its nil *sqlx.Rows
	// would become a non-nil driver.SqlRows interface value.
	rows, err := p.Connection.Queryx(query, lastPkValues...)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// Transform is an identity hook for Postgres: it hands back the very same row
// map it was given, unmodified and uncopied.
func (*Postgres) Transform(row map[string]interface{}) map[string]interface{} {
	return row // no per-row conversion is applied for this driver
}
// Describe lists every column of every table that has a primary key
// constraint. It flags which columns belong to that key and collects them into
// a domain.Description. It returns any query or scan error.
//
// Tables in the pg_catalog and information_schema schemas are excluded. Since
// PostgreSQL 14 the system catalogs carry primary key constraints of their own
// and would otherwise be reported alongside user tables.
//
// Rows are ordered by schema, table and column position, so AddColumn is
// called in the same order on every run.
func (p *Postgres) Describe() (*domain.Description, error) {
	// pk_tables: one row per primary key constraint, with the owning table's
	// schema/name and conkey, the key's column numbers.
	// NOTE(review): matching conkey against ordinal_position relies on
	// PostgreSQL's information_schema deriving ordinal_position from
	// pg_attribute.attnum; confirm if targeting other PostgreSQL-compatible servers.
	describeQuery := `
WITH pk_tables AS (
	SELECT
		s.nspname AS table_schema,
		t.relname AS table_name,
		con.conkey AS column_positions
	FROM pg_catalog.pg_constraint con
	JOIN pg_catalog.pg_class t ON con.conrelid = t.oid
	JOIN pg_catalog.pg_namespace s ON t.relnamespace = s.oid
	WHERE con.contype = 'p'
)
SELECT c.table_catalog, c.table_schema, c.table_name, c.column_name,
	CASE WHEN c.ordinal_position = ANY(pk_tables.column_positions) THEN true ELSE false END AS "is_primary_key"
FROM pk_tables
INNER JOIN information_schema.columns c
	ON pk_tables.table_schema = c.table_schema
	AND pk_tables.table_name = c.table_name
WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema')
ORDER BY c.table_schema, c.table_name, c.ordinal_position;
`
	res := domain.NewDescription()
	rows, err := p.Connection.Queryx(describeQuery)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		row := &tableDescriptionRow{}
		if err := rows.StructScan(row); err != nil {
			return nil, err
		}
		res.AddColumn(&domain.Column{Name: row.ColumnName, Schema: row.SchemaName, Table: row.TableName, IsPrimaryKey: row.IsPrimary})
	}
	// rows.Next returning false may hide an iteration error; surface it.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return res, nil
}