diff --git a/.travis.yml b/.travis.yml index e41a73b..50806b4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,8 +6,9 @@ notifications: language: go go: - - 1.10.x - 1.11.x + - 1.12.x + - 1.13.x env: global: diff --git a/session.go b/session.go index 2694f8d..28a9900 100644 --- a/session.go +++ b/session.go @@ -33,8 +33,8 @@ type SQLBackend interface { } type Backend interface { - sqlbuilder.SQLBuilder db.Database + sqlbuilder.SQLBuilder SetTxOptions(sql.TxOptions) TxOptions() *sql.TxOptions @@ -43,6 +43,8 @@ type Backend interface { type Session interface { Backend + Conn() sqlbuilder.Database + Store(collectionName string) Store ResolveStore(interface{}) Store @@ -53,13 +55,18 @@ type Session interface { Context() context.Context SessionTx(context.Context, func(tx Session) error) error + NewTx(context.Context) (sqlbuilder.Tx, error) + NewSessionTx(context.Context) (Session, error) + + TxCommit() error + TxRollback() error } type session struct { Backend - stores map[string]*store - storesLock sync.Mutex + stores map[string]*store + mu sync.Mutex } // Open connects to a database. @@ -68,7 +75,9 @@ func Open(adapter string, url db.ConnectionURL) (Session, error) { if err != nil { return nil, err } - return New(conn), nil + + sess := New(conn) + return sess, nil } // New returns a new session. @@ -76,6 +85,10 @@ func New(conn Backend) Session { return &session{Backend: conn, stores: make(map[string]*store)} } +func (s *session) Conn() sqlbuilder.Database { + return s.Backend.(sqlbuilder.Database) +} + func (s *session) WithContext(ctx context.Context) Session { var backendCtx Backend switch t := s.Backend.(type) { @@ -122,6 +135,39 @@ func Bind(adapter string, backend SQLBackend) (Session, error) { return &session{Backend: conn, stores: make(map[string]*store)}, nil } +func (s *session) NewTx(ctx context.Context) (sqlbuilder.Tx, error) { + return s.Conn().NewTx(ctx) +} + +func (s *session) NewSessionTx(ctx context.Context) (Session, error) { + tx, err := s.NewTx(ctx) + if err != nil { + return nil, err + } + return &session{ + Backend: tx, + stores: make(map[string]*store), + }, nil +} + +func (s *session) TxCommit() error { + tx, ok := s.Backend.(sqlbuilder.Tx) + if !ok { + return errors.Errorf("bond: session is not a tx") + } + defer tx.Close() + return tx.Commit() +} + +func (s *session) TxRollback() error { + tx, ok := s.Backend.(sqlbuilder.Tx) + if !ok { + return errors.Errorf("bond: session is not a tx") + } + defer tx.Close() + return tx.Rollback() +} + func (s *session) SessionTx(ctx context.Context, fn func(sess Session) error) error { txFn := func(sess sqlbuilder.Tx) error { return fn(&session{ @@ -177,8 +223,8 @@ func (s *session) Store(collectionName string) Store { return &store{session: s} } - s.storesLock.Lock() - defer s.storesLock.Unlock() + s.mu.Lock() + defer s.mu.Unlock() if store, ok := s.stores[collectionName]; ok { return store