diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index ceceae25c5..967e79a2c3 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -31,6 +31,7 @@ GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`, cmd.AddCommand(newProfilesCommand()) cmd.AddCommand(newTokenCommand(&perisistentAuth)) cmd.AddCommand(newDescribeCommand()) + cmd.AddCommand(newLogoutCommand(&perisistentAuth)) return cmd } diff --git a/cmd/auth/logout.go b/cmd/auth/logout.go new file mode 100644 index 0000000000..2153cfd7d0 --- /dev/null +++ b/cmd/auth/logout.go @@ -0,0 +1,110 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "io/fs" + + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/auth/cache" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/databrickscfg/profile" + "github.com/databricks/databricks-sdk-go/config" + "github.com/spf13/cobra" +) + +type logoutSession struct { + profile string + file config.File + persistentAuth *auth.PersistentAuth +} + +func (l *logoutSession) load(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth) error { + l.profile = profileName + l.persistentAuth = persistentAuth + iniFile, err := profile.DefaultProfiler.Get(ctx) + if errors.Is(err, fs.ErrNotExist) { + return err + } else if err != nil { + return fmt.Errorf("cannot parse config file: %w", err) + } + l.file = *iniFile + if err := l.setHostAndAccountIdFromProfile(); err != nil { + return err + } + return nil +} + +func (l *logoutSession) setHostAndAccountIdFromProfile() error { + sectionMap, err := l.getConfigSectionMap() + if err != nil { + return err + } + if sectionMap["host"] == "" { + return fmt.Errorf("no host configured for profile %s", l.profile) + } + l.persistentAuth.Host = sectionMap["host"] + l.persistentAuth.AccountID = sectionMap["account_id"] + return nil +} + +func (l *logoutSession) getConfigSectionMap() (map[string]string, error) { + section, err := l.file.GetSection(l.profile) + if err != nil { + return map[string]string{}, fmt.Errorf("profile does not exist in config file: %w", err) + } + return section.KeysHash(), nil +} + +// clear token from ~/.databricks/token-cache.json +func (l *logoutSession) clearTokenCache(ctx context.Context) error { + return l.persistentAuth.ClearToken(ctx) +} + +func newLogoutCommand(persistentAuth *auth.PersistentAuth) *cobra.Command { + cmd := &cobra.Command{ + Use: "logout [PROFILE]", + Short: "Logout from specified profile", + Long: "Removes the OAuth token from the token-cache", + } + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + profileNameFromFlag := cmd.Flag("profile").Value.String() + // If both [PROFILE] and --profile are provided, return an error. + if len(args) > 0 && profileNameFromFlag != "" { + return fmt.Errorf("please only provide a profile as an argument or a flag, not both") + } + // Determine the profile name from either args or the flag. + profileName := profileNameFromFlag + if len(args) > 0 { + profileName = args[0] + } + // If the user has not specified a profile name, prompt for one. + if profileName == "" { + var err error + profileName, err = promptForProfile(ctx, persistentAuth.ProfileName()) + if err != nil { + return err + } + } + defer persistentAuth.Close() + logoutSession := &logoutSession{} + err := logoutSession.load(ctx, profileName, persistentAuth) + if err != nil { + return err + } + err = logoutSession.clearTokenCache(ctx) + if err != nil { + if errors.Is(err, cache.ErrNotConfigured) { + // It is OK to not have OAuth configured + } else { + return err + } + } + cmdio.LogString(ctx, fmt.Sprintf("Profile %s is logged out", profileName)) + return nil + } + return cmd +} diff --git a/cmd/auth/logout_test.go b/cmd/auth/logout_test.go new file mode 100644 index 0000000000..27549274da --- /dev/null +++ b/cmd/auth/logout_test.go @@ -0,0 +1,62 @@ +package auth + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/databrickscfg" + "github.com/databricks/databricks-sdk-go/config" +) + +func TestLogout_setHostAndAccountIdFromProfile(t *testing.T) { + ctx := context.Background() + path := filepath.Join(t.TempDir(), "databrickscfg") + + err := databrickscfg.SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Profile: "abc", + Host: "https://foo", + Token: "xyz", + }) + require.NoError(t, err) + iniFile, err := config.LoadFile(path) + require.NoError(t, err) + logout := &logoutSession{ + profile: "abc", + file: *iniFile, + persistentAuth: &auth.PersistentAuth{}, + } + err = logout.setHostAndAccountIdFromProfile() + assert.NoError(t, err) + assert.Equal(t, logout.persistentAuth.Host, "https://foo") + assert.Empty(t, logout.persistentAuth.AccountID) +} + +func TestLogout_getConfigSectionMap(t *testing.T) { + ctx := context.Background() + path := filepath.Join(t.TempDir(), "databrickscfg") + + err := databrickscfg.SaveToProfile(ctx, &config.Config{ + ConfigFile: path, + Profile: "abc", + Host: "https://foo", + Token: "xyz", + }) + require.NoError(t, err) + iniFile, err := config.LoadFile(path) + require.NoError(t, err) + logout := &logoutSession{ + profile: "abc", + file: *iniFile, + persistentAuth: &auth.PersistentAuth{}, + } + configSectionMap, err := logout.getConfigSectionMap() + assert.NoError(t, err) + assert.Equal(t, configSectionMap["host"], "https://foo") + assert.Equal(t, configSectionMap["token"], "xyz") +} diff --git a/libs/auth/cache/cache.go b/libs/auth/cache/cache.go index 097353e74c..2c88d093f8 100644 --- a/libs/auth/cache/cache.go +++ b/libs/auth/cache/cache.go @@ -9,6 +9,7 @@ import ( type TokenCache interface { Store(key string, t *oauth2.Token) error Lookup(key string) (*oauth2.Token, error) + Delete(key string) error } var tokenCache int diff --git a/libs/auth/cache/file.go b/libs/auth/cache/file.go index 38dfea9f2c..f1ac4c64d7 100644 --- a/libs/auth/cache/file.go +++ b/libs/auth/cache/file.go @@ -52,11 +52,7 @@ func (c *FileTokenCache) Store(key string, t *oauth2.Token) error { c.Tokens = map[string]*oauth2.Token{} } c.Tokens[key] = t - raw, err := json.MarshalIndent(c, "", " ") - if err != nil { - return fmt.Errorf("marshal: %w", err) - } - return os.WriteFile(c.fileLocation, raw, ownerReadWrite) + return c.write() } func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { @@ -73,6 +69,24 @@ func (c *FileTokenCache) Lookup(key string) (*oauth2.Token, error) { return t, nil } +func (c *FileTokenCache) Delete(key string) error { + err := c.load() + if errors.Is(err, fs.ErrNotExist) { + return ErrNotConfigured + } else if err != nil { + return fmt.Errorf("load: %w", err) + } + if c.Tokens == nil { + c.Tokens = map[string]*oauth2.Token{} + } + _, ok := c.Tokens[key] + if !ok { + return ErrNotConfigured + } + delete(c.Tokens, key) + return c.write() +} + func (c *FileTokenCache) location() (string, error) { home, err := os.UserHomeDir() if err != nil { @@ -105,4 +119,12 @@ func (c *FileTokenCache) load() error { return nil } +func (c *FileTokenCache) write() error { + raw, err := json.MarshalIndent(c, "", " ") + if err != nil { + return fmt.Errorf("marshal: %w", err) + } + return os.WriteFile(c.fileLocation, raw, ownerReadWrite) +} + var _ TokenCache = (*FileTokenCache)(nil) diff --git a/libs/auth/cache/file_test.go b/libs/auth/cache/file_test.go index 3e4aae36f4..48ab5c98af 100644 --- a/libs/auth/cache/file_test.go +++ b/libs/auth/cache/file_test.go @@ -1,6 +1,7 @@ package cache import ( + "encoding/json" "os" "path/filepath" "runtime" @@ -103,3 +104,64 @@ func TestStoreOnDev(t *testing.T) { // macOS: read-only file system assert.Error(t, err) } + +func TestStoreAndDeleteKey(t *testing.T) { + setup(t) + c := &FileTokenCache{} + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + l := &FileTokenCache{} + err = l.Delete("x") + require.NoError(t, err) + assert.Equal(t, 1, len(l.Tokens)) + + _, err = l.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) + + tok, err := l.Lookup("y") + require.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) +} + +func TestDeleteKeyNotExist(t *testing.T) { + c := &FileTokenCache{ + Tokens: map[string]*oauth2.Token{}, + } + err := c.Delete("x") + assert.Equal(t, ErrNotConfigured, err) + + _, err = c.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) +} + +func TestWrite(t *testing.T) { + tempFile := filepath.Join(t.TempDir(), "token-cache.json") + + tokenMap := map[string]*oauth2.Token{} + token := &oauth2.Token{ + AccessToken: "some-access-token", + } + tokenMap["test"] = token + + cache := &FileTokenCache{ + fileLocation: tempFile, + Tokens: tokenMap, + } + + err := cache.write() + assert.NoError(t, err) + + content, err := os.ReadFile(tempFile) + require.NoError(t, err) + + expected, _ := json.MarshalIndent(&cache, "", " ") + assert.Equal(t, content, expected) +} diff --git a/libs/auth/cache/in_memory.go b/libs/auth/cache/in_memory.go index 469d45575a..756002e08e 100644 --- a/libs/auth/cache/in_memory.go +++ b/libs/auth/cache/in_memory.go @@ -23,4 +23,14 @@ func (i *InMemoryTokenCache) Store(key string, t *oauth2.Token) error { return nil } +// Delete implements TokenCache. +func (i *InMemoryTokenCache) Delete(key string) error { + _, ok := i.Tokens[key] + if !ok { + return ErrNotConfigured + } + delete(i.Tokens, key) + return nil +} + var _ TokenCache = (*InMemoryTokenCache)(nil) diff --git a/libs/auth/cache/in_memory_test.go b/libs/auth/cache/in_memory_test.go index d8394d3b26..7399523179 100644 --- a/libs/auth/cache/in_memory_test.go +++ b/libs/auth/cache/in_memory_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/oauth2" ) @@ -42,3 +43,40 @@ func TestInMemoryCacheStore(t *testing.T) { assert.Equal(t, res, token) assert.NoError(t, err) } + +func TestInMemoryDeleteKey(t *testing.T) { + c := &InMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{}, + } + err := c.Store("x", &oauth2.Token{ + AccessToken: "abc", + }) + require.NoError(t, err) + + err = c.Store("y", &oauth2.Token{ + AccessToken: "bcd", + }) + require.NoError(t, err) + + err = c.Delete("x") + require.NoError(t, err) + assert.Equal(t, 1, len(c.Tokens)) + + _, err = c.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) + + tok, err := c.Lookup("y") + require.NoError(t, err) + assert.Equal(t, "bcd", tok.AccessToken) +} + +func TestInMemoryDeleteKeyNotExist(t *testing.T) { + c := &InMemoryTokenCache{ + Tokens: map[string]*oauth2.Token{}, + } + err := c.Delete("x") + assert.Equal(t, ErrNotConfigured, err) + + _, err = c.Lookup("x") + assert.Equal(t, ErrNotConfigured, err) +} diff --git a/libs/auth/oauth.go b/libs/auth/oauth.go index 7c1cb95768..ec615fe717 100644 --- a/libs/auth/oauth.go +++ b/libs/auth/oauth.go @@ -143,6 +143,18 @@ func (a *PersistentAuth) Challenge(ctx context.Context) error { return nil } +func (a *PersistentAuth) ClearToken(ctx context.Context) error { + if a.Host == "" && a.AccountID == "" { + return ErrFetchCredentials + } + if a.cache == nil { + a.cache = cache.GetTokenCache(ctx) + } + // lookup token identified by host (and possibly the account id) + key := a.key() + return a.cache.Delete(key) +} + func (a *PersistentAuth) init(ctx context.Context) error { if a.Host == "" && a.AccountID == "" { return ErrFetchCredentials diff --git a/libs/auth/oauth_test.go b/libs/auth/oauth_test.go index ea6a8061e6..a8a1da70ef 100644 --- a/libs/auth/oauth_test.go +++ b/libs/auth/oauth_test.go @@ -55,6 +55,7 @@ func TestOidcForWorkspace(t *testing.T) { type tokenCacheMock struct { store func(key string, t *oauth2.Token) error lookup func(key string) (*oauth2.Token, error) + delete func(key string) error } func (m *tokenCacheMock) Store(key string, t *oauth2.Token) error { @@ -71,6 +72,13 @@ func (m *tokenCacheMock) Lookup(key string) (*oauth2.Token, error) { return m.lookup(key) } +func (m *tokenCacheMock) Delete(key string) error { + if m.delete == nil { + panic("no deleteKey mock") + } + return m.delete(key) +} + func TestLoad(t *testing.T) { p := &PersistentAuth{ Host: "abc", @@ -228,3 +236,49 @@ func TestChallengeFailed(t *testing.T) { assert.EqualError(t, err, "authorize: access_denied: Policy evaluation failed for this request") }) } + +func TestClearToken(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{}, ErrNotConfigured + }, + delete: func(key string) error { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return nil + }, + }, + } + defer p.Close() + err := p.ClearToken(context.Background()) + assert.NoError(t, err) + key := p.key() + _, err = p.cache.Lookup(key) + assert.Equal(t, ErrNotConfigured, err) +} + +func TestClearTokenNotExist(t *testing.T) { + p := &PersistentAuth{ + Host: "abc", + AccountID: "xyz", + cache: &tokenCacheMock{ + lookup: func(key string) (*oauth2.Token, error) { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return &oauth2.Token{}, ErrNotConfigured + }, + delete: func(key string) error { + assert.Equal(t, "https://abc/oidc/accounts/xyz", key) + return ErrNotConfigured + }, + }, + } + defer p.Close() + err := p.ClearToken(context.Background()) + assert.Equal(t, ErrNotConfigured, err) + key := p.key() + _, err = p.cache.Lookup(key) + assert.Equal(t, ErrNotConfigured, err) +}