Skip to content

Commit

Permalink
Merge pull request dexidp#602 from ericchiang/dev-add-garbage-collect…
Browse files Browse the repository at this point in the history
…-method-to-storage

dev branch: add garbage collect method to storage
  • Loading branch information
ericchiang authored Oct 13, 2016
2 parents 13554ee + 449f34e commit 5bec61d
Show file tree
Hide file tree
Showing 20 changed files with 265 additions and 357 deletions.
3 changes: 2 additions & 1 deletion cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http"

"github.com/spf13/cobra"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
yaml "gopkg.in/yaml.v2"
Expand Down Expand Up @@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
EnablePasswordDB: c.EnablePasswordDB,
}

serv, err := server.NewServer(serverConfig)
serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil {
return fmt.Errorf("initializing server: %v", err)
}
Expand Down
5 changes: 3 additions & 2 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
return
}
authReq.Expiry = s.now().Add(time.Minute * 30)
if err := s.storage.CreateAuthRequest(authReq); err != nil {
log.Printf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
Expand Down Expand Up @@ -342,7 +343,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
if authReq.Expiry.After(s.now()) {
if s.now().After(authReq.Expiry) {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
return
}
Expand Down Expand Up @@ -373,7 +374,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Nonce: authReq.Nonce,
Scopes: authReq.Scopes,
Claims: authReq.Claims,
Expiry: s.now().Add(time.Minute * 5),
Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI,
}
if err := s.storage.CreateAuthCode(code); err != nil {
Expand Down
7 changes: 6 additions & 1 deletion server/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ import (
"net/http"
"net/http/httptest"
"testing"

"golang.org/x/net/context"
)

func TestHandleHealth(t *testing.T) {
httpServer, server := newTestServer(t, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, server := newTestServer(t, ctx, nil)
defer httpServer.Close()

rr := httptest.NewRecorder()
Expand Down
40 changes: 17 additions & 23 deletions server/rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,40 +56,34 @@ type keyRotater struct {
storage.Storage

strategy rotationStrategy
cancel context.CancelFunc

now func() time.Time
now func() time.Time
}

func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
if now == nil {
now = time.Now
}
ctx, cancel := context.WithCancel(context.Background())
rotater := keyRotater{s, strategy, cancel, now}
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
//
// The method blocks until after the first attempt to rotate keys has completed. That way
// healthy storages will return from this call with valid keys.
func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
rotater := keyRotater{s, strategy, now}

// Try to rotate immediately so properly configured storages will return a
// storage with keys.
// Try to rotate immediately so properly configured storages will have keys.
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}

go func() {
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
for {
select {
case <-ctx.Done():
return
case <-time.After(strategy.period):
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
}
}
}()
return rotater
}

func (k keyRotater) Close() error {
k.cancel()
return k.Storage.Close()
return
}

func (k keyRotater) rotate() error {
Expand Down
41 changes: 30 additions & 11 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"golang.org/x/crypto/bcrypt"
"golang.org/x/net/context"

"github.com/gorilla/mux"

Expand Down Expand Up @@ -48,6 +49,8 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours

GCFrequency time.Duration // Defaults to 5 minutes

// If specified, the server will use this function for determining time.
Now func() time.Time

Expand Down Expand Up @@ -87,14 +90,14 @@ type Server struct {
}

// NewServer constructs a server from the provided config.
func NewServer(c Config) (*Server, error) {
return newServer(c, defaultRotationStrategy(
func NewServer(ctx context.Context, c Config) (*Server, error) {
return newServer(ctx, c, defaultRotationStrategy(
value(c.RotateKeysAfter, 6*time.Hour),
value(c.IDTokensValidFor, 24*time.Hour),
))
}

func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer)
if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL")
Expand Down Expand Up @@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
}

s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
skipApproval: c.SkipApprovalScreen,
Expand Down Expand Up @@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
handleFunc("/healthz", s.handleHealth)
s.mux = r

startKeyRotation(ctx, c.Storage, rotationStrategy, now)
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)

return s, nil
}

Expand Down Expand Up @@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
}
return storageKeys, nil
}

func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.GarbageCollect(now()); err != nil {
log.Printf("garbage collection failed: %v", err)
} else {
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
}
}
}
}()
return
}
16 changes: 9 additions & 7 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`)

func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
Expand All @@ -91,22 +91,24 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
s.URL = config.Issuer

var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server
}

func TestNewTestServer(t *testing.T) {
newTestServer(t, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
newTestServer(t, ctx, nil)
}

func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, _ := newTestServer(t, func(c *Config) {
httpServer, _ := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()
Expand Down Expand Up @@ -255,7 +257,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()
Expand Down Expand Up @@ -368,7 +370,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
// Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token"}
})
Expand Down Expand Up @@ -498,7 +500,7 @@ func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()
Expand Down
104 changes: 97 additions & 7 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ import (
// ensure that values being tested on never expire.
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)

// StorageFactory is a method for creating a new storage. The returned storage sould be initialized
// but shouldn't have any existing data in it.
type StorageFactory func() storage.Storage

// RunTestSuite runs a set of conformance tests against a storage.
func RunTestSuite(t *testing.T, sf StorageFactory) {
// RunTests runs a set of conformance tests against a storage. newStorage should
// return an initialized but empty storage. The storage will be closed at the
// end of each test run.
func RunTests(t *testing.T, newStorage func() storage.Storage) {
tests := []struct {
name string
run func(t *testing.T, s storage.Storage)
Expand All @@ -33,10 +31,13 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
{"ClientCRUD", testClientCRUD},
{"RefreshTokenCRUD", testRefreshTokenCRUD},
{"PasswordCRUD", testPasswordCRUD},
{"GarbageCollection", testGC},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.run(t, sf())
s := newStorage()
test.run(t, s)
s.Close()
})
}
}
Expand Down Expand Up @@ -276,3 +277,92 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
}
}

func testGC(t *testing.T, s storage.Storage) {
n := time.Now()
c := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: n.Add(time.Second),
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "[email protected]",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}

if err := s.CreateAuthCode(c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}

if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthCode(c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}

if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthCodes != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
}

if _, err := s.GetAuthCode(c.ID); err == nil {
t.Errorf("expected auth code to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}

a := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "foobar",
ResponseTypes: []string{"code"},
Scopes: []string{"openid", "email"},
RedirectURI: "https://localhost:80/callback",
Nonce: "foo",
State: "bar",
ForceApprovalPrompt: true,
LoggedIn: true,
Expiry: n,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "[email protected]",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}

if err := s.CreateAuthRequest(a); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}

if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthRequest(a.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}

if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
}

if _, err := s.GetAuthRequest(a.ID); err == nil {
t.Errorf("expected auth code to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}
Loading

0 comments on commit 5bec61d

Please sign in to comment.