Skip to content

Commit

Permalink
Merge pull request #33 from upper/sessiontx
Browse files Browse the repository at this point in the history
add Conn, NewTx, NewSessionTx, TxCommit, TxRollback on bond.Session
  • Loading branch information
pkieltyka authored Sep 6, 2019
2 parents 2702bd1 + c93a052 commit 086276d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ notifications:
language: go

go:
- 1.10.x
- 1.11.x
- 1.12.x
- 1.13.x

env:
global:
Expand Down
58 changes: 52 additions & 6 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ type SQLBackend interface {
}

type Backend interface {
sqlbuilder.SQLBuilder
db.Database
sqlbuilder.SQLBuilder

SetTxOptions(sql.TxOptions)
TxOptions() *sql.TxOptions
Expand All @@ -43,6 +43,8 @@ type Backend interface {
type Session interface {
Backend

Conn() sqlbuilder.Database

Store(collectionName string) Store
ResolveStore(interface{}) Store

Expand All @@ -53,13 +55,18 @@ type Session interface {
Context() context.Context

SessionTx(context.Context, func(tx Session) error) error
NewTx(context.Context) (sqlbuilder.Tx, error)
NewSessionTx(context.Context) (Session, error)

TxCommit() error
TxRollback() error
}

type session struct {
Backend

stores map[string]*store
storesLock sync.Mutex
stores map[string]*store
mu sync.Mutex
}

// Open connects to a database.
Expand All @@ -68,14 +75,20 @@ func Open(adapter string, url db.ConnectionURL) (Session, error) {
if err != nil {
return nil, err
}
return New(conn), nil

sess := New(conn)
return sess, nil
}

// New returns a new session.
func New(conn Backend) Session {
return &session{Backend: conn, stores: make(map[string]*store)}
}

func (s *session) Conn() sqlbuilder.Database {
return s.Backend.(sqlbuilder.Database)
}

func (s *session) WithContext(ctx context.Context) Session {
var backendCtx Backend
switch t := s.Backend.(type) {
Expand Down Expand Up @@ -122,6 +135,39 @@ func Bind(adapter string, backend SQLBackend) (Session, error) {
return &session{Backend: conn, stores: make(map[string]*store)}, nil
}

func (s *session) NewTx(ctx context.Context) (sqlbuilder.Tx, error) {
return s.Conn().NewTx(ctx)
}

func (s *session) NewSessionTx(ctx context.Context) (Session, error) {
tx, err := s.NewTx(ctx)
if err != nil {
return nil, err
}
return &session{
Backend: tx,
stores: make(map[string]*store),
}, nil
}

func (s *session) TxCommit() error {
tx, ok := s.Backend.(sqlbuilder.Tx)
if !ok {
return errors.Errorf("bond: session is not a tx")
}
defer tx.Close()
return tx.Commit()
}

func (s *session) TxRollback() error {
tx, ok := s.Backend.(sqlbuilder.Tx)
if !ok {
return errors.Errorf("bond: session is not a tx")
}
defer tx.Close()
return tx.Rollback()
}

func (s *session) SessionTx(ctx context.Context, fn func(sess Session) error) error {
txFn := func(sess sqlbuilder.Tx) error {
return fn(&session{
Expand Down Expand Up @@ -177,8 +223,8 @@ func (s *session) Store(collectionName string) Store {
return &store{session: s}
}

s.storesLock.Lock()
defer s.storesLock.Unlock()
s.mu.Lock()
defer s.mu.Unlock()

if store, ok := s.stores[collectionName]; ok {
return store
Expand Down

0 comments on commit 086276d

Please sign in to comment.