From 3b96de605ff4898aaf1f19c919c9574616b3dbed Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Wed, 7 Feb 2024 13:11:14 -0800 Subject: [PATCH] Implement remove_tokens_for_client() --- msal/application.py | 13 +++++++++++++ tests/test_application.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/msal/application.py b/msal/application.py index f82ea2e3..0f86be5b 100644 --- a/msal/application.py +++ b/msal/application.py @@ -2178,6 +2178,19 @@ def _acquire_token_for_client( telemetry_context.update_telemetry(response) return response + def remove_tokens_for_client(self): + """Remove all tokens that were previously acquired via + :func:`~acquire_token_for_client()` for the current client.""" + for env in [self.authority.instance] + self._get_authority_aliases( + self.authority.instance): + for at in self.token_cache.find(TokenCache.CredentialType.ACCESS_TOKEN, query={ + "client_id": self.client_id, + "environment": env, + "home_account_id": None, # These are mostly app-only tokens + }): + self.token_cache.remove_at(at) + # acquire_token_for_client() obtains no RTs, so we have no RT to remove + def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=None, **kwargs): """Acquires token using on-behalf-of (OBO) flow. diff --git a/tests/test_application.py b/tests/test_application.py index fc529f01..cebc7225 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -662,6 +662,35 @@ def test_organizations_authority_should_emit_warning(self): authority="https://login.microsoftonline.com/organizations") +class TestRemoveTokensForClient(unittest.TestCase): + def test_remove_tokens_for_client_should_remove_client_tokens_only(self): + at_for_user = "AT for user" + cca = msal.ConfidentialClientApplication( + "client_id", client_credential="secret", + authority="https://login.microsoftonline.com/microsoft.onmicrosoft.com") + self.assertEqual( + 0, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + cca.acquire_token_for_client( + ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps({"access_token": "AT for client"}))) + self.assertEqual( + 1, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + cca.acquire_token_by_username_password( + "johndoe", "password", ["scope"], + post=lambda url, **kwargs: MinimalResponse( + status_code=200, text=json.dumps(build_response( + access_token=at_for_user, expires_in=3600, + uid="uid", utid="utid", # This populates home_account_id + )))) + self.assertEqual( + 2, len(cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN))) + cca.remove_tokens_for_client() + remaining_tokens = cca.token_cache.find(msal.TokenCache.CredentialType.ACCESS_TOKEN) + self.assertEqual(1, len(remaining_tokens)) + self.assertEqual(at_for_user, remaining_tokens[0].get("secret")) + + class TestScopeDecoration(unittest.TestCase): def _test_client_id_should_be_a_valid_scope(self, client_id, other_scopes): # B2C needs this https://learn.microsoft.com/en-us/azure/active-directory-b2c/access-tokens#openid-connect-scopes