forked from Shopify/ghostferry
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbatch_writer.go
229 lines (194 loc) · 6 KB
/
batch_writer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
package ghostferry
import (
"fmt"
sql "github.com/Shopify/ghostferry/sqlwrapper"
"github.com/sirupsen/logrus"
)
// BatchWriterVerificationFailed is the error returned when the inline
// verifier detects rows whose fingerprints differ between the source and
// the target for a batch that was about to be written.
type BatchWriterVerificationFailed struct {
	mismatchedPaginationKeys []uint64
	table                    string
}

// Error renders the mismatched pagination keys and the affected table as a
// human-readable message.
func (v BatchWriterVerificationFailed) Error() string {
	return fmt.Sprintf("row fingerprints for paginationKeys %v on %v do not match", v.mismatchedPaginationKeys, v.table)
}
// BatchWriter writes RowBatches to the target database, optionally running
// inline fingerprint verification and recording row-copy progress in the
// StateTracker. Initialize must be called before WriteRowBatch.
type BatchWriter struct {
	DB *sql.DB

	// InlineVerifier, when non-nil, has CheckFingerprintInline invoked on
	// every InsertRowBatch inside the write transaction.
	InlineVerifier *InlineVerifier

	// StateTracker, when non-nil, receives row-copy position and
	// table-completion updates; its SQL is queued in the same transaction
	// as the batch, while its in-memory updates happen after commit.
	StateTracker *StateTracker

	// DatabaseRewrites and TableRewrites map source names to target names;
	// names absent from the maps are written through unchanged.
	DatabaseRewrites map[string]string
	TableRewrites    map[string]string

	// WriteRetries is the number of attempts WriteRowBatch makes via
	// WithRetries before giving up.
	WriteRetries int

	stmtCache *StmtCache    // cache of prepared statements, keyed by query text
	logger    *logrus.Entry // tagged "batch_writer"; set in Initialize
}
// Initialize sets up the BatchWriter's tagged logger and its prepared
// statement cache. It must be called once before WriteRowBatch is used.
func (w *BatchWriter) Initialize() {
	w.logger = logrus.WithField("tag", "batch_writer")
	w.stmtCache = NewStmtCache()
}
// WriteRowBatch writes one RowBatch to the target database inside a single
// transaction, retrying the whole operation up to w.WriteRetries times via
// WithRetries.
//
// Per attempt it: rewrites the source database/table names to their target
// names, queues the batch's SQL on a fresh transaction, lets
// handleInsertRowBatch queue verification and row-copy-position statements
// for InsertRowBatches, queues the row-copy-done statement when the batch
// marks the table complete, and commits only if at least one statement was
// actually queued.
//
// NOTE(review): the closure uses a named return `err`; the deferred
// functions below read it at exit time — rollback fires whenever tx was not
// nil-ed out by a successful commit, and the in-memory StateTracker updates
// run only when the attempt ends with err == nil (i.e. after commit).
func (w *BatchWriter) WriteRowBatch(batch RowBatch) error {
	return WithRetries(w.WriteRetries, 0, w.logger, "write batch to target", func() (err error) {
		// Apply the configured rewrites to obtain the target schema name.
		db := batch.TableSchema().Schema
		if targetDbName, exists := w.DatabaseRewrites[db]; exists {
			db = targetDbName
		}

		// Apply the configured rewrites to obtain the target table name.
		table := batch.TableSchema().Name
		if targetTableName, exists := w.TableRewrites[table]; exists {
			table = targetTableName
		}

		// Nothing to write; skip starting a transaction entirely.
		if batch.Size() == 0 {
			w.logger.Debugf("ignoring empty row-batch for %s.%s", db, table)
			return
		}

		// txInUse tracks whether any statement was actually queued on the
		// transaction; an untouched transaction is discarded, not committed.
		txInUse := false
		tx, dbErr := w.DB.Begin()
		if dbErr != nil {
			err = fmt.Errorf("unable to begin transaction in BatchWriter: %v", dbErr)
			return
		}
		// make sure the transaction gets abandoned if we didn't commit it
		defer func() {
			if tx != nil {
				w.logger.Debugf("rolling back transaction: %s", err)
				tx.Rollback()
			}
		}()

		query, args, dbErr := batch.AsSQLQuery(db, table)
		if dbErr != nil {
			err = fmt.Errorf("during generating sql batch query: %v", dbErr)
			return
		}

		txUpdated, dbErr := w.queueStatement(tx, query, args)
		if dbErr != nil {
			err = dbErr
			return
		}
		if txUpdated {
			txInUse = true
		}

		// Note that the state tracker expects us the track based on the original
		// database and table names as opposed to the target ones.
		stateTableName := batch.TableSchema().String()

		switch b := batch.(type) {
		case InsertRowBatch:
			endPaginationKeypos, txUpdated, insertErr := w.handleInsertRowBatch(tx, b, db, table)
			if insertErr != nil {
				err = insertErr
				return
			}
			if txUpdated {
				txInUse = true
			}
			if w.StateTracker != nil && endPaginationKeypos != nil {
				// Update the in-memory position only after a successful
				// commit (err == nil at closure exit).
				defer func() {
					if err == nil {
						w.StateTracker.UpdateLastSuccessfulPaginationKey(stateTableName, endPaginationKeypos)
					}
				}()
			}
		}

		if batch.IsTableComplete() && w.StateTracker != nil {
			// Queue the row-copy-done marker in the same transaction so it
			// lands atomically with the last batch of the table.
			query, args, stateErr := w.StateTracker.GetStoreRowCopyDoneSql(stateTableName)
			if stateErr != nil {
				err = fmt.Errorf("during generating row-copy done: %v", stateErr)
				return
			}
			txUpdated, dbErr := w.queueStatement(tx, query, args)
			if dbErr != nil {
				err = dbErr
				return
			}
			if txUpdated {
				txInUse = true
			}
			// Mark completion in memory only after a successful commit.
			defer func() {
				if err == nil {
					w.StateTracker.MarkTableAsCompleted(stateTableName)
				}
			}()
		}

		if txInUse {
			err = tx.Commit()
			if err != nil {
				err = fmt.Errorf("during row-copy commit (%s): %v", query, err)
			} else {
				// avoid rolling it back (too late anyways) on function exit
				tx = nil
			}
		} else {
			// we never really added any statement to the transaction - no need
			// to commit it. This should practically never happen, but let's be
			// on the safe side
			w.logger.Debug("discarding empty transaction")
		}

		return
	})
}
// handleInsertRowBatch performs the InsertRowBatch-specific work on tx:
// optional inline fingerprint verification against the target, and (when a
// StateTracker is configured and the table has a pagination key) queueing of
// the row-copy position statement.
//
// It returns the pagination key data of the batch's last row (nil when the
// table has no pagination key), whether any statement was queued on tx, and
// an error on verification or statement generation/queueing failure.
func (w *BatchWriter) handleInsertRowBatch(tx *sql.Tx, batch InsertRowBatch, db, table string) (endPaginationKeypos *PaginationKeyData, txUpdated bool, err error) {
	var startPaginationKeypos *PaginationKeyData
	paginationKey := batch.TableSchema().PaginationKey
	if paginationKey != nil {
		values := batch.Values()
		startPaginationKeypos, err = NewPaginationKeyDataFromRow(values[0], paginationKey)
		if err != nil {
			return
		}
		endPaginationKeypos, err = NewPaginationKeyDataFromRow(values[len(values)-1], paginationKey)
		if err != nil {
			return
		}
	}

	if w.InlineVerifier != nil {
		mismatches, verifierErr := w.InlineVerifier.CheckFingerprintInline(tx, db, table, batch)
		// BUG FIX: this previously tested `err` (always nil here) instead of
		// the verifier's error, silently dropping verification failures.
		if verifierErr != nil {
			err = fmt.Errorf("during fingerprint checking for paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, verifierErr)
			return
		}
		if len(mismatches) > 0 {
			err = BatchWriterVerificationFailed{mismatches, batch.TableSchema().String()}
			return
		}
	}

	if w.StateTracker != nil && endPaginationKeypos != nil {
		// Note that the state tracker expects us the track based on the original
		// database and table names as opposed to the target ones.
		//
		// NOTE: We would update the state within a transaction, but we cannot
		// do that in case of transaction failures. Since it's safe to copy rows
		// multiple times, it's vital that we commit the transaction before
		// updating the state tracker
		query, args, stateErr := w.StateTracker.GetStoreRowCopyPositionSql(batch.TableSchema().String(), endPaginationKeypos)
		if stateErr != nil {
			err = fmt.Errorf("during generating row-copy position for paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, stateErr)
			return
		}
		txUpdatedForState, stateErr := w.queueStatement(tx, query, args)
		if stateErr != nil {
			// Distinct message from the generation failure above so the two
			// failure modes are distinguishable in logs.
			err = fmt.Errorf("during queueing row-copy position for paginationKey %v -> %v: %v", startPaginationKeypos, endPaginationKeypos, stateErr)
			return
		}
		if txUpdatedForState {
			txUpdated = true
		}
	}

	return
}
// queueStatement executes one statement on tx, using a prepared statement
// obtained from (and cached by) w.stmtCache. An empty query is a no-op.
// txUpdated reports whether a statement was actually executed on tx, so the
// caller can decide whether the transaction needs committing.
func (w *BatchWriter) queueStatement(tx *sql.Tx, query string, args []interface{}) (txUpdated bool, err error) {
	// Empty queries (e.g. from a disabled state tracker) queue nothing.
	if query == "" {
		return
	}

	cachedStmt, cacheErr := w.stmtCache.StmtFor(w.DB, query)
	if cacheErr != nil {
		err = cacheErr
		return
	}

	if IncrediblyVerboseLogging {
		w.logger.Debugf("Applying copy statements: %s (%v)", query, args)
	}

	// Rebind the cached statement to this transaction before executing.
	if _, execErr := tx.Stmt(cachedStmt).Exec(args...); execErr != nil {
		err = fmt.Errorf("during copy statement: %v", execErr)
		return
	}

	txUpdated = true
	return
}