diff --git a/.gitignore b/.gitignore index 34b0df6..379a88d 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ _testmain.go # OSX *.DS_Store +*.db diff --git a/example/server/main.go b/example/server/main.go index 2ad0b6c..bb46810 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -27,7 +27,7 @@ func init() { func main() { manager := manage.NewDefaultManager() // token store - manager.MapTokenStorage(store.NewMemoryTokenStore(0)) + manager.MustTokenStorage(store.NewMemoryTokenStore()) // client store manager.MapClientStorage(store.NewTestClientStore(&models.Client{ ID: "222222", diff --git a/generates/access.go b/generates/access.go index c22f6cf..64a7fae 100644 --- a/generates/access.go +++ b/generates/access.go @@ -6,7 +6,7 @@ import ( "strconv" "strings" - "github.com/LyricTian/go.uuid" + "github.com/satori/go.uuid" "gopkg.in/oauth2.v3" ) diff --git a/generates/authorize.go b/generates/authorize.go index 8eafdbf..96d1c95 100644 --- a/generates/authorize.go +++ b/generates/authorize.go @@ -5,7 +5,7 @@ import ( "encoding/base64" "strings" - "github.com/LyricTian/go.uuid" + "github.com/satori/go.uuid" "gopkg.in/oauth2.v3" ) diff --git a/manage/manage_test.go b/manage/manage_test.go index ba39ff9..0107ff6 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -22,7 +22,7 @@ func TestManager(t *testing.T) { }) Convey("Memory store test", func() { - manager.MapTokenStorage(store.NewMemoryTokenStore(0)) + manager.MustTokenStorage(store.NewMemoryTokenStore()) testManager(manager) }) }) diff --git a/manage/manager.go b/manage/manager.go index d8a74ce..741ba17 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -1,12 +1,10 @@ package manage import ( - "time" - - "github.com/LyricTian/inject" - "reflect" + "time" + "github.com/codegangsta/inject" "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/errors" "gopkg.in/oauth2.v3/generates" diff --git a/server/server_test.go b/server/server_test.go index 9ed8583..6976c4e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -27,7 +27,7 @@ var ( func init() { manager = manage.NewDefaultManager() - manager.MapTokenStorage(store.NewMemoryTokenStore(0)) + manager.MustTokenStorage(store.NewMemoryTokenStore()) } func clientStore(domain string) oauth2.ClientStore { diff --git a/store/token.go b/store/token.go index e6794b8..5d47564 100644 --- a/store/token.go +++ b/store/token.go @@ -1,170 +1,126 @@ package store import ( - "container/list" - "strconv" - "sync" + "encoding/json" "time" + "github.com/satori/go.uuid" + "github.com/tidwall/buntdb" "gopkg.in/oauth2.v3" + "gopkg.in/oauth2.v3/models" ) // NewMemoryTokenStore Create a token store instance based on memory -// gcInterval Perform garbage collection intervals(The default is 30 seconds) -func NewMemoryTokenStore(gcInterval time.Duration) oauth2.TokenStore { - if gcInterval == 0 { - gcInterval = time.Second * 30 - } - store := &MemoryTokenStore{ - gcInterval: gcInterval, - basicList: list.New(), - data: make(map[string]oauth2.TokenInfo), - access: make(map[string]string), - refresh: make(map[string]string), +func NewMemoryTokenStore() (store oauth2.TokenStore, err error) { + store, err = NewFileTokenStore(":memory:") + return +} + +// NewFileTokenStore Create a token store instance based on file +func NewFileTokenStore(filename string) (store oauth2.TokenStore, err error) { + db, err := buntdb.Open(filename) + if err != nil { + return } - go store.gc() - return store + store = &TokenStore{db: db} + return } -// MemoryTokenStore Memory storage for token -type MemoryTokenStore struct { - gcInterval time.Duration - globalID int64 - lock sync.RWMutex - data map[string]oauth2.TokenInfo - access map[string]string - refresh map[string]string - basicList *list.List - listLock sync.RWMutex +// TokenStore Token storage based on buntdb(https://github.com/tidwall/buntdb) +type TokenStore struct { + db *buntdb.DB } -func (mts *MemoryTokenStore) gc() { - time.AfterFunc(mts.gcInterval, func() { - defer mts.gc() - rmeles := make([]*list.Element, 0, 32) - mts.listLock.RLock() - ele := mts.basicList.Front() - mts.listLock.RUnlock() - for ele != nil { - if rm := mts.gcElement(ele); rm { - rmeles = append(rmeles, ele) +// Create Create and store the new token information +func (ts *TokenStore) Create(info oauth2.TokenInfo) (err error) { + ct := time.Now() + jv, err := json.Marshal(info) + if err != nil { + return + } + basicID := uuid.NewV4().String() + aexp := info.GetAccessExpiresIn() + rexp := aexp + + err = ts.db.Update(func(tx *buntdb.Tx) (err error) { + if refresh := info.GetRefresh(); refresh != "" { + rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct) + if aexp.Seconds() > rexp.Seconds() { + aexp = rexp + } + _, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: true, TTL: rexp}) + if err != nil { + return } - mts.listLock.RLock() - ele = ele.Next() - mts.listLock.RUnlock() } - - for _, e := range rmeles { - mts.listLock.Lock() - mts.basicList.Remove(e) - mts.listLock.Unlock() + _, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: true, TTL: rexp}) + if err != nil { + return } + _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: true, TTL: aexp}) + return }) + return } -func (mts *MemoryTokenStore) gcElement(ele *list.Element) (rm bool) { - basicID := ele.Value.(string) - mts.lock.RLock() - ti, ok := mts.data[basicID] - mts.lock.RUnlock() - if !ok { - rm = true +// remove key +func (ts *TokenStore) remove(key string) (err error) { + verr := ts.db.Update(func(tx *buntdb.Tx) (err error) { + _, err = tx.Delete(key) + return + }) + if verr == buntdb.ErrNotFound { return } - ct := time.Now() - if refresh := ti.GetRefresh(); refresh != "" && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - mts.lock.RLock() - delete(mts.access, ti.GetAccess()) - delete(mts.refresh, refresh) - delete(mts.data, basicID) - mts.lock.RUnlock() - rm = true - } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - mts.lock.RLock() - delete(mts.access, ti.GetAccess()) - if refresh := ti.GetRefresh(); refresh == "" { - delete(mts.data, basicID) - rm = true - } - mts.lock.RUnlock() - } + err = verr return } -func (mts *MemoryTokenStore) getBasicID(id int64, info oauth2.TokenInfo) string { - return info.GetClientID() + "_" + strconv.FormatInt(id, 10) +// RemoveByAccess Use the access token to delete the token information +func (ts *TokenStore) RemoveByAccess(access string) (err error) { + err = ts.remove(access) + return } -// Create Create and store the new token information -func (mts *MemoryTokenStore) Create(info oauth2.TokenInfo) (err error) { - mts.lock.Lock() - defer mts.lock.Unlock() - mts.globalID++ - basicID := mts.getBasicID(mts.globalID, info) - mts.data[basicID] = info - mts.access[info.GetAccess()] = basicID - if refresh := info.GetRefresh(); refresh != "" { - mts.refresh[refresh] = basicID - } - - mts.listLock.Lock() - mts.basicList.PushBack(basicID) - mts.listLock.Unlock() +// RemoveByRefresh Use the refresh token to delete the token information +func (ts *TokenStore) RemoveByRefresh(refresh string) (err error) { + err = ts.remove(refresh) return } -// RemoveByAccess Use the access token to delete the token information -func (mts *MemoryTokenStore) RemoveByAccess(access string) (err error) { - mts.lock.RLock() - v, ok := mts.access[access] - if !ok { - mts.lock.RUnlock() +func (ts *TokenStore) get(key string) (ti oauth2.TokenInfo, err error) { + verr := ts.db.View(func(tx *buntdb.Tx) (err error) { + basicID, err := tx.Get(key) + if err != nil { + return + } + jv, err := tx.Get(basicID) + if err != nil { + return + } + var tm models.Token + err = json.Unmarshal([]byte(jv), &tm) + if err != nil { + return + } + ti = &tm + return + }) + if verr == buntdb.ErrNotFound { return } - info := mts.data[v] - mts.lock.RUnlock() - - mts.lock.Lock() - defer mts.lock.Unlock() - delete(mts.access, access) - if refresh := info.GetRefresh(); refresh == "" { - delete(mts.data, v) - } - return -} - -// RemoveByRefresh Use the refresh token to delete the token information -func (mts *MemoryTokenStore) RemoveByRefresh(refresh string) (err error) { - mts.lock.Lock() - defer mts.lock.Unlock() - delete(mts.refresh, refresh) - + err = verr return } // GetByAccess Use the access token for token information data -func (mts *MemoryTokenStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { - mts.lock.RLock() - v, ok := mts.access[access] - if !ok { - mts.lock.RUnlock() - return - } - ti = mts.data[v] - mts.lock.RUnlock() +func (ts *TokenStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { + ti, err = ts.get(access) return } // GetByRefresh Use the refresh token for token information data -func (mts *MemoryTokenStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { - mts.lock.RLock() - v, ok := mts.refresh[refresh] - if !ok { - mts.lock.RUnlock() - return - } - ti = mts.data[v] - mts.lock.RUnlock() +func (ts *TokenStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { + ti, err = ts.get(refresh) return } diff --git a/store/token_test.go b/store/token_test.go index 84232fc..ed9707c 100644 --- a/store/token_test.go +++ b/store/token_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "gopkg.in/oauth2.v3" "gopkg.in/oauth2.v3/models" "gopkg.in/oauth2.v3/store" @@ -12,78 +13,88 @@ import ( func TestTokenStore(t *testing.T) { Convey("Test memory store", t, func() { - store := store.NewMemoryTokenStore(time.Second * 1) + store, err := store.NewMemoryTokenStore() + So(err, ShouldBeNil) + testToken(store) + }) + + Convey("Test file store", t, func() { + store, err := store.NewFileTokenStore("data.db") + So(err, ShouldBeNil) + testToken(store) + }) +} - Convey("Test access token store", func() { - info := &models.Token{ - ClientID: "1", - UserID: "1_1", - RedirectURI: "http://localhost/", - Scope: "all", - Access: "1_1_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 5, - } - err := store.Create(info) - So(err, ShouldBeNil) +func testToken(store oauth2.TokenStore) { + Convey("Test access token store", func() { + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_1_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + } + err := store.Create(info) + So(err, ShouldBeNil) - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) - err = store.RemoveByAccess(info.GetAccess()) - So(err, ShouldBeNil) + err = store.RemoveByAccess(info.GetAccess()) + So(err, ShouldBeNil) - ainfo, err = store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - }) + ainfo, err = store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + }) - Convey("Test refresh token store", func() { - info := &models.Token{ - ClientID: "1", - UserID: "1_2", - RedirectURI: "http://localhost/", - Scope: "all", - Access: "1_2_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 5, - Refresh: "1_2_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Second * 15, - } - err := store.Create(info) - So(err, ShouldBeNil) + Convey("Test refresh token store", func() { + info := &models.Token{ + ClientID: "1", + UserID: "1_2", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_2_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + Refresh: "1_2_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Second * 15, + } + err := store.Create(info) + So(err, ShouldBeNil) - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) - err = store.RemoveByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) + err = store.RemoveByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) - rinfo, err = store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldBeNil) - }) + rinfo, err = store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldBeNil) + }) - Convey("Test gc", func() { - info := &models.Token{ - ClientID: "1", - UserID: "1_3", - RedirectURI: "http://localhost/", - Scope: "all", - Access: "1_3_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 1, - } - err := store.Create(info) - So(err, ShouldBeNil) + Convey("Test gc", func() { + info := &models.Token{ + ClientID: "1", + UserID: "1_3", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_3_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 1, + } + err := store.Create(info) + So(err, ShouldBeNil) - time.Sleep(time.Second * 1) - ainfo, err := store.GetByRefresh(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - }) + time.Sleep(time.Second * 1) + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) }) }