From 03f4b6dda24a9bc9a1b08677625923b45601928c Mon Sep 17 00:00:00 2001 From: Gladwin Johnson <90415114+gladjohn@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:55:48 -0800 Subject: [PATCH] Add ML MSI Source (#5053) * ese nit Update ManagedIdentity environment variables and add MachineLearning source# src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt * pr comments * Update src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs Co-authored-by: Neha Bhargava <61847233+neha-bhargava@users.noreply.github.com> * tests * Metadata * improve tests --------- Co-authored-by: Gladwin Johnson Co-authored-by: Neha Bhargava <61847233+neha-bhargava@users.noreply.github.com> --- .../ManagedIdentity/EnvironmentVariables.cs | 1 + .../MachineLearningManagedIdentitySource.cs | 95 +++++++++ .../ManagedIdentity/ManagedIdentityClient.cs | 9 +- .../ManagedIdentity/ManagedIdentitySource.cs | 7 +- .../PublicApi/net462/PublicAPI.Unshipped.txt | 1 + .../PublicApi/net472/PublicAPI.Unshipped.txt | 1 + .../net8.0-android/PublicAPI.Unshipped.txt | 1 + .../net8.0-ios/PublicAPI.Unshipped.txt | 1 + .../PublicApi/net8.0/PublicAPI.Unshipped.txt | 1 + .../netstandard2.0/PublicAPI.Unshipped.txt | 1 + .../Core/Helpers/ManagedIdentityTestUtil.cs | 4 + .../Core/Mocks/MockHttpManagerExtensions.cs | 8 +- .../MachineLearningTests.cs | 51 +++++ .../ManagedIdentityTests.cs | 190 ++++++++++++++++++ 14 files changed, 367 insertions(+), 4 deletions(-) create mode 100644 src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs create mode 100644 tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs index d4d2603619..aed4821dae 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/EnvironmentVariables.cs @@ -12,6 +12,7 @@ internal class EnvironmentVariables public static string PodIdentityEndpoint => Environment.GetEnvironmentVariable("AZURE_POD_IDENTITY_AUTHORITY_HOST"); public static string ImdsEndpoint => Environment.GetEnvironmentVariable("IMDS_ENDPOINT"); public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT"); + public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET"); public static string IdentityServerThumbprint => Environment.GetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT"); } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs new file mode 100644 index 0000000000..8eaf33749c --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Globalization; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Internal; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class MachineLearningManagedIdentitySource : AbstractManagedIdentity + { + private const string MachineLearningMsiApiVersion = "2017-09-01"; + private const string SecretHeaderName = "secret"; + + private readonly Uri _endpoint; + private readonly string _secret; + + public static AbstractManagedIdentity Create(RequestContext requestContext) + { + requestContext.Logger.Info(() => "[Managed Identity] Machine learning managed identity is available."); + + return TryValidateEnvVars(EnvironmentVariables.MsiEndpoint, requestContext.Logger, out Uri endpointUri) + ? new MachineLearningManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.MsiSecret) + : null; + } + + private MachineLearningManagedIdentitySource(RequestContext requestContext, Uri endpoint, string secret) + : base(requestContext, ManagedIdentitySource.MachineLearning) + { + _endpoint = endpoint; + _secret = secret; + } + + private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger, out Uri endpointUri) + { + endpointUri = null; + + try + { + endpointUri = new Uri(msiEndpoint); + } + catch (FormatException ex) + { + string errorMessage = string.Format( + CultureInfo.InvariantCulture, + MsalErrorMessage.ManagedIdentityEndpointInvalidUriError, + "MSI_ENDPOINT", msiEndpoint, "Machine learning"); + + // Use the factory to create and throw the exception + var exception = MsalServiceExceptionFactory.CreateManagedIdentityException( + MsalError.InvalidManagedIdentityEndpoint, + errorMessage, + ex, + ManagedIdentitySource.MachineLearning, + null); // statusCode is null in this case + + throw exception; + } + + logger.Info($"[Managed Identity] Environment variables validation passed for machine learning managed identity. Endpoint URI: {endpointUri}. Creating machine learning managed identity."); + return true; + } + + protected override ManagedIdentityRequest CreateRequest(string resource) + { + ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint); + + request.Headers.Add("Metadata", "true"); + request.Headers.Add(SecretHeaderName, _secret); + request.QueryParameters["api-version"] = MachineLearningMsiApiVersion; + request.QueryParameters["resource"] = resource; + + switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType) + { + case AppConfig.ManagedIdentityIdType.ClientId: + _requestContext.Logger.Info("[Managed Identity] Adding user assigned client id to the request."); + request.QueryParameters[Constants.ManagedIdentityClientId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; + break; + + case AppConfig.ManagedIdentityIdType.ResourceId: + _requestContext.Logger.Info("[Managed Identity] Adding user assigned resource id to the request."); + request.QueryParameters[Constants.ManagedIdentityResourceId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; + break; + + case AppConfig.ManagedIdentityIdType.ObjectId: + _requestContext.Logger.Info("[Managed Identity] Adding user assigned object id to the request."); + request.QueryParameters[Constants.ManagedIdentityObjectId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId; + break; + } + + return request; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 65c4a8ba3f..566a820616 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -41,6 +41,7 @@ private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContex { ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext), ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), _ => new ImdsManagedIdentitySource(requestContext) @@ -57,11 +58,15 @@ internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter lo string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint; string msiSecret = EnvironmentVariables.IdentityHeader; string msiEndpoint = EnvironmentVariables.MsiEndpoint; + string msiSecretMachineLearning = EnvironmentVariables.MsiSecret; string imdsEndpoint = EnvironmentVariables.ImdsEndpoint; string podIdentityEndpoint = EnvironmentVariables.PodIdentityEndpoint; - - if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader)) + if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint)) + { + return ManagedIdentitySource.MachineLearning; + } + else if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader)) { if (!string.IsNullOrEmpty(identityServerThumbprint)) { diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs index 8ae1181539..69e3471bdf 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs @@ -48,6 +48,11 @@ public enum ManagedIdentitySource /// Indicates that the source is defaulted to IMDS since no environment variables are set. /// This is used to detect the managed identity source. /// - DefaultToImds + DefaultToImds, + + /// + /// The source to acquire token for managed identity is Machine Learning Service. + /// + MachineLearning } } diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt index e69de29bb2..d1db9319dc 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource \ No newline at end of file diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt index e69de29bb2..6b01d55560 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt index e69de29bb2..6b01d55560 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt index e69de29bb2..6b01d55560 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt index e69de29bb2..6b01d55560 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt index e69de29bb2..6b01d55560 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -0,0 +1 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource diff --git a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs index 225085029e..5ae110e4da 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Helpers/ManagedIdentityTestUtil.cs @@ -59,6 +59,10 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret); Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint); break; + case ManagedIdentitySource.MachineLearning: + Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint); + Environment.SetEnvironmentVariable("MSI_SECRET", secret); + break; } } diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs index 4678652e7e..0f76875fa1 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHttpManagerExtensions.cs @@ -393,7 +393,6 @@ public static void AddManagedIdentityMockHandler( httpManager.AddMockHandler(httpMessageHandler); } - private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource) { @@ -433,6 +432,13 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(M expectedQueryParams.Add("api-version", "2019-07-01-preview"); expectedQueryParams.Add("resource", resource); break; + case ManagedIdentitySource.MachineLearning: + httpMessageHandler.ExpectedMethod = HttpMethod.Get; + expectedRequestHeaders.Add("secret", "secret"); + expectedRequestHeaders.Add("Metadata", "true"); + expectedQueryParams.Add("api-version", "2017-09-01"); + expectedQueryParams.Add("resource", resource); + break; } if (managedIdentitySourceType != ManagedIdentitySource.CloudShell) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs new file mode 100644 index 0000000000..bc9ea5e844 --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/MachineLearningTests.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Globalization; +using System.Net; +using System.Threading.Tasks; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Test.Common; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + [TestClass] + public class MachineLearningTests : TestBase + { + private const string MachineLearning = "Machine learning"; + + [TestMethod] + public async Task MachineLearningTestsInvalidEndpointAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(isManagedIdentity: true)) + { + SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, "127.0.0.1:41564/msi/token"); + + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource) + .ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false); + + Assert.IsNotNull(ex); + Assert.AreEqual(MsalError.InvalidManagedIdentityEndpoint, ex.ErrorCode); + Assert.AreEqual(ManagedIdentitySource.MachineLearning.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]); + Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, MsalErrorMessage.ManagedIdentityEndpointInvalidUriError, "MSI_ENDPOINT", "127.0.0.1:41564/msi/token", MachineLearning), ex.Message); + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index f2dfe4449b..9ae8c9ce50 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -29,6 +29,7 @@ public class ManagedIdentityTests : TestBase internal const string Resource = "https://management.azure.com"; internal const string ResourceDefaultSuffix = "https://management.azure.com/.default"; internal const string AppServiceEndpoint = "http://127.0.0.1:41564/msi/token"; + internal const string MachineLearningEndpoint = "http://localhost:7071/msi/token"; internal const string ImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; internal const string AzureArcEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; internal const string CloudShellEndpoint = "http://localhost:40342/metadata/identity/oauth2/token"; @@ -45,6 +46,7 @@ public class ManagedIdentityTests : TestBase [DataRow(AzureArcEndpoint, ManagedIdentitySource.AzureArc, ManagedIdentitySource.AzureArc)] [DataRow(CloudShellEndpoint, ManagedIdentitySource.CloudShell, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, ManagedIdentitySource.MachineLearning)] public void GetManagedIdentityTests( string endpoint, ManagedIdentitySource managedIdentitySource, @@ -70,6 +72,8 @@ public void GetManagedIdentityTests( [DataRow(CloudShellEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)] [DataRow(ServiceFabricEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] + [DataRow(MachineLearningEndpoint, ResourceDefaultSuffix, ManagedIdentitySource.MachineLearning)] public async Task ManagedIdentityHappyPathAsync( string endpoint, string scope, @@ -119,6 +123,9 @@ public async Task ManagedIdentityHappyPathAsync( [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId .ResourceId)] [DataRow(ServiceFabricEndpoint, ManagedIdentitySource.ServiceFabric, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] + [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.ClientId, UserAssignedIdentityId.ClientId)] + [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ResourceId)] + [DataRow(MachineLearningEndpoint, ManagedIdentitySource.MachineLearning, TestConstants.MiResourceId, UserAssignedIdentityId.ObjectId)] public async Task ManagedIdentityUserAssignedHappyPathAsync( string endpoint, ManagedIdentitySource managedIdentitySource, @@ -165,6 +172,7 @@ public async Task ManagedIdentityUserAssignedHappyPathAsync( [DataRow(AzureArcEndpoint, Resource, "https://graph.microsoft.com", ManagedIdentitySource.AzureArc)] [DataRow(CloudShellEndpoint, Resource, "https://graph.microsoft.com", ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, Resource, "https://graph.microsoft.com", ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, "https://graph.microsoft.com", ManagedIdentitySource.MachineLearning)] public async Task ManagedIdentityDifferentScopesTestAsync( string endpoint, string scope, @@ -225,6 +233,7 @@ public async Task ManagedIdentityDifferentScopesTestAsync( [DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)] [DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] public async Task ManagedIdentityForceRefreshTestAsync( string endpoint, string scope, @@ -285,6 +294,7 @@ public async Task ManagedIdentityForceRefreshTestAsync( [DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)] [DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( string endpoint, string scope, @@ -346,6 +356,7 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( [DataRow(AzureArcEndpoint, Resource, ManagedIdentitySource.AzureArc)] [DataRow(CloudShellEndpoint, Resource, ManagedIdentitySource.CloudShell)] [DataRow(ServiceFabricEndpoint, Resource, ManagedIdentitySource.ServiceFabric)] + [DataRow(MachineLearningEndpoint, Resource, ManagedIdentitySource.MachineLearning)] public async Task ManagedIdentityWithClaimsTestAsync( string endpoint, string scope, @@ -417,6 +428,9 @@ public async Task ManagedIdentityWithClaimsTestAsync( [DataRow("user.read", ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] [DataRow("https://management.core.windows.net//user_impersonation", ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] [DataRow("s", ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] + [DataRow("user.read", ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] + [DataRow("https://management.core.windows.net//user_impersonation", ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] + [DataRow("s", ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] public async Task ManagedIdentityTestWrongScopeAsync(string resource, ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) @@ -524,6 +538,7 @@ await mi.AcquireTokenForManagedIdentity(resource) [DataRow(ManagedIdentitySource.AzureArc, AzureArcEndpoint)] [DataRow(ManagedIdentitySource.CloudShell, CloudShellEndpoint)] [DataRow(ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) @@ -565,6 +580,7 @@ await mi.AcquireTokenForManagedIdentity("scope") [DataRow(ManagedIdentitySource.AzureArc, AzureArcEndpoint)] [DataRow(ManagedIdentitySource.CloudShell, CloudShellEndpoint)] [DataRow(ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) @@ -604,6 +620,7 @@ await mi.AcquireTokenForManagedIdentity(Resource) [DataRow(ManagedIdentitySource.AzureArc, AzureArcEndpoint)] [DataRow(ManagedIdentitySource.CloudShell, CloudShellEndpoint)] [DataRow(ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] + [DataRow(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource managedIdentitySource, string endpoint) { using (new EnvVariableContext()) @@ -643,6 +660,7 @@ await mi.AcquireTokenForManagedIdentity(Resource) [DataRow(ManagedIdentitySource.AzureArc, AzureArcEndpoint, HttpStatusCode.NotFound)] [DataRow(ManagedIdentitySource.CloudShell, CloudShellEndpoint, HttpStatusCode.NotFound)] [DataRow(ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint, HttpStatusCode.NotFound)] + [DataRow(ManagedIdentitySource.MachineLearning, MachineLearningEndpoint, HttpStatusCode.NotFound)] public async Task ManagedIdentityTestRetryAsync(ManagedIdentitySource managedIdentitySource, string endpoint, HttpStatusCode statusCode) { using (new EnvVariableContext()) @@ -1111,5 +1129,177 @@ await mi.AcquireTokenForManagedIdentity("scope") Assert.AreEqual(MsalErrorMessage.ManagedIdentityJsonParseFailure, ex.Message); } } + + [DataTestMethod] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.AppService, AppServiceEndpoint)] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.Imds, ImdsEndpoint)] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.AzureArc, AzureArcEndpoint)] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.CloudShell, CloudShellEndpoint)] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.ServiceFabric, ServiceFabricEndpoint)] + [DataRow(Resource, "https://graph.microsoft.com", ManagedIdentitySource.MachineLearning, MachineLearningEndpoint)] + public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( + string initialResource, + string newResource, + ManagedIdentitySource source, + string endpoint) + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(isManagedIdentity: true)) + { + SetEnvironmentVariables(source, endpoint); + + ManagedIdentityApplicationBuilder miBuilder = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + IManagedIdentityApplication mi = miBuilder.Build(); + + // Mock handler for the initial resource request + httpManager.AddManagedIdentityMockHandler(endpoint, initialResource, + MockHelpers.GetMsiSuccessfulResponse(), source); + + // Request token for initial resource + AuthenticationResult result = await mi.AcquireTokenForManagedIdentity(initialResource).ExecuteAsync().ConfigureAwait(false); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // Mock handler for the new resource request + httpManager.AddManagedIdentityMockHandler(endpoint, newResource, + MockHelpers.GetMsiSuccessfulResponse(), source); + + // Request token for new resource + result = await mi.AcquireTokenForManagedIdentity(newResource).ExecuteAsync().ConfigureAwait(false); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + // Request token again for the same initial resource to check cache usage + result = await mi.AcquireTokenForManagedIdentity(initialResource).ExecuteAsync().ConfigureAwait(false); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + + // Request token again for the new resource to check cache usage + result = await mi.AcquireTokenForManagedIdentity(newResource).ExecuteAsync().ConfigureAwait(false); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + + [DataTestMethod] + [DataRow(ManagedIdentitySource.AppService)] + [DataRow(ManagedIdentitySource.Imds)] + [DataRow(ManagedIdentitySource.AzureArc)] + [DataRow(ManagedIdentitySource.CloudShell)] + [DataRow(ManagedIdentitySource.ServiceFabric)] + [DataRow(ManagedIdentitySource.MachineLearning)] + public async Task UnsupportedManagedIdentitySource_ThrowsExceptionDuringTokenAcquisitionAsync( + ManagedIdentitySource managedIdentitySource) + { + string UnsupportedEndpoint = "unsupported://endpoint"; + + using (new EnvVariableContext()) + { + // Set unsupported environment variable + SetEnvironmentVariables(managedIdentitySource, UnsupportedEndpoint); + + // Create the Managed Identity Application + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned); + + // Build the application + var mi = miBuilder.Build(); + + // Attempt to acquire a token and verify an exception is thrown + MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => + await mi.AcquireTokenForManagedIdentity("https://management.azure.com") + .ExecuteAsync() + .ConfigureAwait(false)).ConfigureAwait(false); + + // Verify the exception details + Assert.IsNotNull(ex); + Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); + } + } + + [TestMethod] + public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(isManagedIdentity: true)) + { + SetEnvironmentVariables(ManagedIdentitySource.AppService, AppServiceEndpoint); + + // User-assigned identity client ID + string UserAssignedClientId = "d3adb33f-c0de-ed0c-c0de-deadb33fc0d3"; + string SystemAssignedClientId = "system_assigned_managed_identity"; + + // Create a builder for user-assigned identity + var userAssignedBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.WithUserAssignedClientId(UserAssignedClientId)) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + userAssignedBuilder.Config.AccessorOptions = null; + + var userAssignedMI = userAssignedBuilder.BuildConcrete(); + + // Record token cache access for user-assigned identity + var userAssignedCacheRecorder = userAssignedMI.AppTokenCacheInternal.RecordAccess(); + + // Mock handler for user-assigned token + httpManager.AddManagedIdentityMockHandler( + AppServiceEndpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.AppService, + userAssignedId: UserAssignedClientId, + userAssignedIdentityId: UserAssignedIdentityId.ClientId); + + var userAssignedResult = await userAssignedMI.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(userAssignedResult); + Assert.AreEqual(TokenSource.IdentityProvider, userAssignedResult.AuthenticationResultMetadata.TokenSource); + + // Verify user-assigned cache entries + userAssignedCacheRecorder.AssertAccessCounts(1, 1); + + // Create a builder for system-assigned identity + var systemAssignedBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + systemAssignedBuilder.Config.AccessorOptions = null; + + var systemAssignedMI = systemAssignedBuilder.BuildConcrete(); + + // Record token cache access for system-assigned identity + var systemAssignedCacheRecorder = systemAssignedMI.AppTokenCacheInternal.RecordAccess(); + + // Mock handler for system-assigned token + httpManager.AddManagedIdentityMockHandler( + AppServiceEndpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.AppService); + + var systemAssignedResult = await systemAssignedMI.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(systemAssignedResult); + Assert.AreEqual(TokenSource.IdentityProvider, systemAssignedResult.AuthenticationResultMetadata.TokenSource); + + // Verify system-assigned cache entries + systemAssignedCacheRecorder.AssertAccessCounts(1, 1); + + // Ensure the cache contains correct entries for both identities + var userAssignedTokens = userAssignedMI.AppTokenCacheInternal.Accessor.GetAllAccessTokens(); + var systemAssignedTokens = systemAssignedMI.AppTokenCacheInternal.Accessor.GetAllAccessTokens(); + + Assert.AreEqual(1, userAssignedTokens.Count, "User-assigned cache entry missing."); + Assert.AreEqual(1, systemAssignedTokens.Count, "System-assigned cache entry missing."); + + // Verify the ClientId for each cached entry + Assert.AreEqual(UserAssignedClientId, userAssignedTokens[0].ClientId, "User-assigned ClientId mismatch in cache."); + Assert.AreEqual(SystemAssignedClientId, systemAssignedTokens[0].ClientId, "System-assigned ClientId mismatch in cache."); + } + } } }