-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathconn.go
115 lines (97 loc) · 2.4 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package testdb
import (
"database/sql/driver"
"errors"
)
// conn provides the connection-level driver methods of this package
// (Prepare, Close, Begin, Query, Exec). Answers come either from the
// override functions below or from stubs registered in queries.
type conn struct {
	// queries maps getQueryHash(sql) to the stub registered for that SQL.
	// NOTE(review): Prepare/Query/Exec read this table through the
	// package-level d.conn rather than through their receiver.
	queries map[string]query

	// queryFunc and execFunc, when non-nil, take precedence over the stubs
	// in queries for both direct and prepared calls.
	queryFunc func(query string, args []driver.Value) (driver.Rows, error)
	execFunc  func(query string, args []driver.Value) (driver.Result, error)

	// beginFunc, when non-nil, replaces Begin's default behaviour entirely.
	beginFunc func() (driver.Tx, error)

	// commitFunc and rollbackFunc are installed on the *Tx that Begin builds
	// when no beginFunc is set.
	commitFunc, rollbackFunc func() error
}
// newConn returns a conn with an empty stub table and no override
// functions installed.
func newConn() *conn {
	c := new(conn)
	c.queries = map[string]query{}
	return c
}
// Prepare returns a statement for query.
//
// Override functions set on the connection (queryFunc / execFunc) take
// precedence. Otherwise the statement is backed by the stub registered for
// query: a stub with rows (or with only an error) answers Query, and a stub
// with a result (or with only an error) answers Exec. An error-only stub
// yields its error when the statement is executed, matching what conn.Query
// and conn.Exec return for the same stub. Prepare fails only when neither an
// override nor a usable stub exists.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	s := new(stmt)

	if c.queryFunc != nil {
		s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
			return c.queryFunc(query, args)
		}
	}
	if c.execFunc != nil {
		s.execFunc = func(args []driver.Value) (driver.Result, error) {
			return c.execFunc(query, args)
		}
	}

	// NOTE(review): stubs are read from the package-level d.conn, not from
	// the receiver c, mirroring Query and Exec; confirm this is intentional.
	if q, ok := d.conn.queries[getQueryHash(query)]; ok {
		// Previously the closure was installed only when q.rows != nil, so its
		// error branch was dead and error-only stubs reported "not stubbed".
		if s.queryFunc == nil && (q.rows != nil || q.err != nil) {
			s.queryFunc = func(args []driver.Value) (driver.Rows, error) {
				if q.rows == nil {
					return nil, q.err
				}
				// Package-built rows are cloned per call, presumably so each
				// execution gets its own cursor.
				if r, isRows := q.rows.(*rows); isRows {
					return r.clone(), nil
				}
				return q.rows, nil
			}
		}
		if s.execFunc == nil && (q.result != nil || q.err != nil) {
			s.execFunc = func(args []driver.Value) (driver.Result, error) {
				if q.result == nil {
					return nil, q.err
				}
				return q.result, nil
			}
		}
	}

	if s.queryFunc == nil && s.execFunc == nil {
		return new(stmt), errors.New("Query not stubbed: " + query)
	}
	return s, nil
}
// Close implements driver.Conn's Close. There is nothing to release, so it
// always reports success.
func (*conn) Close() error { return nil }
// Begin starts a transaction. When beginFunc is set it is called and its
// result returned unchanged. Otherwise a fresh *Tx is returned, carrying
// whichever commit/rollback overrides are configured on the connection.
func (c *conn) Begin() (driver.Tx, error) {
	if begin := c.beginFunc; begin != nil {
		return begin()
	}

	tx := new(Tx)
	if commit := c.commitFunc; commit != nil {
		tx.SetCommitFunc(commit)
	}
	if rollback := c.rollbackFunc; rollback != nil {
		tx.SetRollbackFunc(rollback)
	}
	return tx, nil
}
// Query answers a direct (unprepared) query. A configured queryFunc wins.
// Otherwise the stub registered for query supplies both the rows and the
// error. Unknown SQL produces a "Query not stubbed" error.
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	if c.queryFunc != nil {
		return c.queryFunc(query, args)
	}

	stub, found := d.conn.queries[getQueryHash(query)]
	if !found {
		return nil, errors.New("Query not stubbed: " + query)
	}

	// Package-built rows are handed out as a clone, presumably so every
	// caller gets an independent cursor.
	if r, isRows := stub.rows.(*rows); isRows {
		return r.clone(), stub.err
	}
	return stub.rows, stub.err
}
// Exec answers a direct (unprepared) statement. A configured execFunc wins.
// Otherwise a matching stub's result is returned, or its error when it has no
// result. Anything else produces an "Exec call not stubbed" error.
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if c.execFunc != nil {
		return c.execFunc(query, args)
	}

	stub, found := d.conn.queries[getQueryHash(query)]
	switch {
	case found && stub.result != nil:
		return stub.result, nil
	case found && stub.err != nil:
		return nil, stub.err
	default:
		// Unknown SQL, or a stub carrying neither a result nor an error
		// (for example one registered only with rows).
		return nil, errors.New("Exec call not stubbed: " + query)
	}
}