Skip to content

Commit

Permalink
Ensure token is refreshed on Unauthenticated
Browse files Browse the repository at this point in the history
Signed-off-by: Haytham Abuelfutuh <[email protected]>
  • Loading branch information
EngHabu committed Apr 27, 2024
1 parent 95e9ac8 commit 0c40d53
Show file tree
Hide file tree
Showing 76 changed files with 316 additions and 261 deletions.
3 changes: 3 additions & 0 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T

wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)
// Clear the token cache so that subsequent calls will get a fresh token
tokenCache.Purge()

return nil
}

Expand Down
10 changes: 7 additions & 3 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,13 @@ func Test_newAuthInterceptor(t *testing.T) {

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
c := &mocks.TokenCache{}
c.On("Purge").Return()
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f, p)
}, c, f, p)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Unauthenticated, "").Err()
}
Expand Down Expand Up @@ -240,6 +242,8 @@ func Test_newAuthInterceptor(t *testing.T) {

func TestMaterializeCredentials(t *testing.T) {
port := rand.IntnRange(10000, 60000)
c := &mocks.TokenCache{}
c.On("Purge").Return()
t.Run("No oauth2 metadata endpoint or Public client config lookup", func(t *testing.T) {
m := &adminMocks.AuthMetadataServiceServer{}
m.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(nil, errors.New("unexpected call to get oauth2 metadata"))
Expand All @@ -263,7 +267,7 @@ func TestMaterializeCredentials(t *testing.T) {
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, &mocks.TokenCache{}, f, p)
}, c, f, p)
assert.NoError(t, err)
})
t.Run("Failed to fetch client metadata", func(t *testing.T) {
Expand All @@ -288,7 +292,7 @@ func TestMaterializeCredentials(t *testing.T) {
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port),
Scopes: []string{"all"},
}, &mocks.TokenCache{}, f, p)
}, c, f, p)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}
Expand Down
11 changes: 10 additions & 1 deletion flyteidl/clients/go/admin/cache/token_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ type TokenCache interface {
// SaveToken saves the token securely to cache.
SaveToken(token *oauth2.Token) error

// Retrieves the token from the cache.
// GetToken retrieves the token from the cache.
GetToken() (*oauth2.Token, error)

// Purge the token from the cache.
Purge()

// Lock the cache.
Lock()

// Unlock the cache.
Unlock()
}
33 changes: 28 additions & 5 deletions flyteidl/clients/go/admin/cache/token_cache_inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,46 @@ package cache

import (
"fmt"
"sync"
"sync/atomic"

"golang.org/x/oauth2"
)

type TokenCacheInMemoryProvider struct {
token *oauth2.Token
token atomic.Value
mu *sync.Mutex
}

func (t *TokenCacheInMemoryProvider) SaveToken(token *oauth2.Token) error {
t.token = token
t.token.Store(token)
return nil
}

func (t TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) {
if t.token == nil {
func (t *TokenCacheInMemoryProvider) GetToken() (*oauth2.Token, error) {
tkn := t.token.Load()
if tkn == nil {
return nil, fmt.Errorf("cannot find token in cache")
}

return t.token, nil
return tkn.(*oauth2.Token), nil
}

func (t *TokenCacheInMemoryProvider) Purge() {
t.token.Store(nil)
}

func (t *TokenCacheInMemoryProvider) Lock() {
t.mu.Lock()
}

func (t *TokenCacheInMemoryProvider) Unlock() {
t.mu.Unlock()
}

func NewTokenCacheInMemoryProvider() *TokenCacheInMemoryProvider {
return &TokenCacheInMemoryProvider{
mu: &sync.Mutex{},
token: atomic.Value{},
}
}
2 changes: 1 addition & 1 deletion flyteidl/clients/go/admin/client_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (cb *ClientsetBuilder) WithDialOptions(opts ...grpc.DialOption) *ClientsetB
// Build the clientset using the current state of the ClientsetBuilder
func (cb *ClientsetBuilder) Build(ctx context.Context) (*Clientset, error) {
if cb.tokenCache == nil {
cb.tokenCache = &cache.TokenCacheInMemoryProvider{}
cb.tokenCache = cache.NewTokenCacheInMemoryProvider()
}

if cb.config == nil {
Expand Down
4 changes: 2 additions & 2 deletions flyteidl/clients/go/admin/client_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ func TestClientsetBuilder_Build(t *testing.T) {
cb := NewClientsetBuilder().WithConfig(&Config{
UseInsecureConnection: true,
Endpoint: config.URL{URL: *u},
}).WithTokenCache(&cache.TokenCacheInMemoryProvider{})
}).WithTokenCache(cache.NewTokenCacheInMemoryProvider())
ctx := context.Background()
_, err := cb.Build(ctx)
assert.NoError(t, err)
assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(&cache.TokenCacheInMemoryProvider{}))
assert.True(t, reflect.TypeOf(cb.tokenCache) == reflect.TypeOf(cache.NewTokenCacheInMemoryProvider()))
}
4 changes: 3 additions & 1 deletion flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
tokenSourceProvider, err := NewTokenSourceProvider(ctx, adminServiceConfig, mockTokenCache, mockAuthClient)
Expand Down Expand Up @@ -288,7 +290,7 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
assert.NoError(t, err)

// populate the cache
tokenCache := &cache.TokenCacheInMemoryProvider{}
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
func TestFetchFromAuthFlow(t *testing.T) {
ctx := context.Background()
t.Run("fetch from auth flow", func(t *testing.T) {
tokenCache := &cache.TokenCacheInMemoryProvider{}
tokenCache := cache.NewTokenCacheInMemoryProvider()
orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestFetchFromAuthFlow(t *testing.T) {
}))
defer fakeServer.Close()

tokenCache := &cache.TokenCacheInMemoryProvider{}
tokenCache := cache.NewTokenCacheInMemoryProvider()
orchestrator, err := NewDeviceFlowTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func TestFetchFromAuthFlow(t *testing.T) {
ctx := context.Background()
t.Run("fetch from auth flow", func(t *testing.T) {
tokenCache := &cache.TokenCacheInMemoryProvider{}
tokenCache := cache.NewTokenCacheInMemoryProvider()
orchestrator, err := NewTokenOrchestrator(tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Config: &oauth2.Config{
Expand Down
14 changes: 11 additions & 3 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func NewClientCredentialsTokenSourceProvider(ctx context.Context, cfg *Config, s
}
secret = strings.TrimSpace(secret)
if tokenCache == nil {
tokenCache = &cache.TokenCacheInMemoryProvider{}
tokenCache = cache.NewTokenCacheInMemoryProvider()
}
return ClientCredentialsTokenSourceProvider{
ccConfig: clientcredentials.Config{
Expand Down Expand Up @@ -225,16 +225,24 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
return token, nil
}

s.tokenCache.Lock()
defer s.tokenCache.Unlock()

// Check again here in case another goroutine has already updated the token
if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
}

token, err := s.new.Token()
if err != nil {
logger.Warnf(s.ctx, "failed to get token: %w", err)
logger.Warnf(s.ctx, "failed to get token: %v", err)
return nil, fmt.Errorf("failed to get token: %w", err)
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

err = s.tokenCache.SaveToken(token)
if err != nil {
logger.Warnf(s.ctx, "failed to cache token: %w", err)
logger.Warnf(s.ctx, "failed to cache token: %v", err)
}

return token, nil
Expand Down
3 changes: 3 additions & 0 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
tokenCache := &tokenCacheMocks.TokenCache{}
tokenCache.OnGetToken().Return(test.token, nil).Once()
tokenCache.OnGetToken().Return(test.token, nil).Maybe()
tokenCache.On("Lock").Return().Maybe()
tokenCache.On("Unlock").Return().Maybe()
provider, err := NewClientCredentialsTokenSourceProvider(ctx, cfg, []string{}, "", tokenCache, "")
assert.NoError(t, err)
source, err := provider.GetTokenSource(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package tokenorchestrator
import (
"context"
"fmt"
"time"

"golang.org/x/oauth2"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
Expand Down Expand Up @@ -53,16 +51,17 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if !token.Valid() {
return nil, fmt.Errorf("token from cache is invalid")
if token.Valid() {
return token, nil
}

// If token doesn't need to be refreshed, return it.
if time.Now().Before(token.Expiry.Add(-tokenRefreshGracePeriod.Duration)) {
logger.Infof(ctx, "found the token in the cache")
return token, nil
t.TokenCache.Lock()
defer t.TokenCache.Unlock()

token, err = t.TokenCache.GetToken()
if err != nil {
return nil, err
}
token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration)

token, err = t.RefreshToken(ctx, token)
if err != nil {
Expand All @@ -73,6 +72,7 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, fmt.Errorf("refreshed token is invalid")
}

token.Expiry = token.Expiry.Add(-tokenRefreshGracePeriod.Duration)
err = t.TokenCache.SaveToken(token)
if err != nil {
return nil, fmt.Errorf("failed to save token in the token cache. Error: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestRefreshTheToken(t *testing.T) {
ClientID: "dummyClient",
},
}
tokenCacheProvider := &cache.TokenCacheInMemoryProvider{}
tokenCacheProvider := cache.NewTokenCacheInMemoryProvider()
orchestrator := BaseTokenOrchestrator{
ClientConfig: clientConf,
TokenCache: tokenCacheProvider,
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestFetchFromCache(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("no token in cache", func(t *testing.T) {
tokenCacheProvider := &cache.TokenCacheInMemoryProvider{}
tokenCacheProvider := cache.NewTokenCacheInMemoryProvider()

orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient)

Expand All @@ -69,7 +69,7 @@ func TestFetchFromCache(t *testing.T) {
})

t.Run("token in cache", func(t *testing.T) {
tokenCacheProvider := &cache.TokenCacheInMemoryProvider{}
tokenCacheProvider := cache.NewTokenCacheInMemoryProvider()
orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient)
assert.NoError(t, err)
fileData, _ := os.ReadFile("testdata/token.json")
Expand All @@ -86,7 +86,7 @@ func TestFetchFromCache(t *testing.T) {
})

t.Run("expired token in cache", func(t *testing.T) {
tokenCacheProvider := &cache.TokenCacheInMemoryProvider{}
tokenCacheProvider := cache.NewTokenCacheInMemoryProvider()
orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient)
assert.NoError(t, err)
fileData, _ := os.ReadFile("testdata/token.json")
Expand Down
4 changes: 2 additions & 2 deletions flyteidl/gen/pb-go/gateway/flyteidl/admin/agent.swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0c40d53

Please sign in to comment.