Skip to content

Commit

Permalink
Add ML MSI Source (#5053)
Browse files Browse the repository at this point in the history
* ese
nit

Update ManagedIdentity environment variables and add MachineLearning source#	src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt

* pr comments

* Update src/client/Microsoft.Identity.Client/ManagedIdentity/MachineLearningManagedIdentitySource.cs

Co-authored-by: Neha Bhargava <[email protected]>

* tests

* Metadata

* improve tests

---------

Co-authored-by: Gladwin Johnson <[email protected]>
Co-authored-by: Neha Bhargava <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent 1dc9597 commit 03f4b6d
Show file tree
Hide file tree
Showing 14 changed files with 367 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ internal class EnvironmentVariables
public static string PodIdentityEndpoint => Environment.GetEnvironmentVariable("AZURE_POD_IDENTITY_AUTHORITY_HOST");
public static string ImdsEndpoint => Environment.GetEnvironmentVariable("IMDS_ENDPOINT");
public static string MsiEndpoint => Environment.GetEnvironmentVariable("MSI_ENDPOINT");
public static string MsiSecret => Environment.GetEnvironmentVariable("MSI_SECRET");
public static string IdentityServerThumbprint => Environment.GetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Globalization;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class MachineLearningManagedIdentitySource : AbstractManagedIdentity
{
private const string MachineLearningMsiApiVersion = "2017-09-01";
private const string SecretHeaderName = "secret";

private readonly Uri _endpoint;
private readonly string _secret;

public static AbstractManagedIdentity Create(RequestContext requestContext)
{
requestContext.Logger.Info(() => "[Managed Identity] Machine learning managed identity is available.");

return TryValidateEnvVars(EnvironmentVariables.MsiEndpoint, requestContext.Logger, out Uri endpointUri)
? new MachineLearningManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.MsiSecret)
: null;
}

private MachineLearningManagedIdentitySource(RequestContext requestContext, Uri endpoint, string secret)
: base(requestContext, ManagedIdentitySource.MachineLearning)
{
_endpoint = endpoint;
_secret = secret;
}

private static bool TryValidateEnvVars(string msiEndpoint, ILoggerAdapter logger, out Uri endpointUri)
{
endpointUri = null;

try
{
endpointUri = new Uri(msiEndpoint);
}
catch (FormatException ex)
{
string errorMessage = string.Format(
CultureInfo.InvariantCulture,
MsalErrorMessage.ManagedIdentityEndpointInvalidUriError,
"MSI_ENDPOINT", msiEndpoint, "Machine learning");

// Use the factory to create and throw the exception
var exception = MsalServiceExceptionFactory.CreateManagedIdentityException(
MsalError.InvalidManagedIdentityEndpoint,
errorMessage,
ex,
ManagedIdentitySource.MachineLearning,
null); // statusCode is null in this case

throw exception;
}

logger.Info($"[Managed Identity] Environment variables validation passed for machine learning managed identity. Endpoint URI: {endpointUri}. Creating machine learning managed identity.");
return true;
}

protected override ManagedIdentityRequest CreateRequest(string resource)
{
ManagedIdentityRequest request = new(System.Net.Http.HttpMethod.Get, _endpoint);

request.Headers.Add("Metadata", "true");
request.Headers.Add(SecretHeaderName, _secret);
request.QueryParameters["api-version"] = MachineLearningMsiApiVersion;
request.QueryParameters["resource"] = resource;

switch (_requestContext.ServiceBundle.Config.ManagedIdentityId.IdType)
{
case AppConfig.ManagedIdentityIdType.ClientId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned client id to the request.");
request.QueryParameters[Constants.ManagedIdentityClientId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;

case AppConfig.ManagedIdentityIdType.ResourceId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned resource id to the request.");
request.QueryParameters[Constants.ManagedIdentityResourceId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;

case AppConfig.ManagedIdentityIdType.ObjectId:
_requestContext.Logger.Info("[Managed Identity] Adding user assigned object id to the request.");
request.QueryParameters[Constants.ManagedIdentityObjectId] = _requestContext.ServiceBundle.Config.ManagedIdentityId.UserAssignedId;
break;
}

return request;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContex
{
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
_ => new ImdsManagedIdentitySource(requestContext)
Expand All @@ -57,11 +58,15 @@ internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter lo
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint;
string msiSecret = EnvironmentVariables.IdentityHeader;
string msiEndpoint = EnvironmentVariables.MsiEndpoint;
string msiSecretMachineLearning = EnvironmentVariables.MsiSecret;
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint;
string podIdentityEndpoint = EnvironmentVariables.PodIdentityEndpoint;


if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader))
if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint))
{
return ManagedIdentitySource.MachineLearning;
}
else if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader))
{
if (!string.IsNullOrEmpty(identityServerThumbprint))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public enum ManagedIdentitySource
/// Indicates that the source is defaulted to IMDS since no environment variables are set.
/// This is used to detect the managed identity source.
/// </summary>
DefaultToImds
DefaultToImds,

/// <summary>
/// The source to acquire token for managed identity is Machine Learning Service.
/// </summary>
MachineLearning
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.MachineLearning = 7 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ public static void SetEnvironmentVariables(ManagedIdentitySource managedIdentity
Environment.SetEnvironmentVariable("IDENTITY_HEADER", secret);
Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", thumbprint);
break;
case ManagedIdentitySource.MachineLearning:
Environment.SetEnvironmentVariable("MSI_ENDPOINT", endpoint);
Environment.SetEnvironmentVariable("MSI_SECRET", secret);
break;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,6 @@ public static void AddManagedIdentityMockHandler(

httpManager.AddMockHandler(httpMessageHandler);
}


private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(ManagedIdentitySource managedIdentitySourceType, string resource)
{
Expand Down Expand Up @@ -433,6 +432,13 @@ private static MockHttpMessageHandler BuildMockHandlerForManagedIdentitySource(M
expectedQueryParams.Add("api-version", "2019-07-01-preview");
expectedQueryParams.Add("resource", resource);
break;
case ManagedIdentitySource.MachineLearning:
httpMessageHandler.ExpectedMethod = HttpMethod.Get;
expectedRequestHeaders.Add("secret", "secret");
expectedRequestHeaders.Add("Metadata", "true");
expectedQueryParams.Add("api-version", "2017-09-01");
expectedQueryParams.Add("resource", resource);
break;
}

if (managedIdentitySourceType != ManagedIdentitySource.CloudShell)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Globalization;
using System.Net;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Test.Common;
using Microsoft.Identity.Test.Common.Core.Helpers;
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil;

namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests
{
[TestClass]
public class MachineLearningTests : TestBase
{
private const string MachineLearning = "Machine learning";

[TestMethod]
public async Task MachineLearningTestsInvalidEndpointAsync()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager(isManagedIdentity: true))
{
SetEnvironmentVariables(ManagedIdentitySource.MachineLearning, "127.0.0.1:41564/msi/token");

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithHttpManager(httpManager);

// Disabling shared cache options to avoid cross test pollution.
miBuilder.Config.AccessorOptions = null;

var mi = miBuilder.Build();

MsalServiceException ex = await Assert.ThrowsExceptionAsync<MsalServiceException>(async () =>
await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.ExecuteAsync().ConfigureAwait(false)).ConfigureAwait(false);

Assert.IsNotNull(ex);
Assert.AreEqual(MsalError.InvalidManagedIdentityEndpoint, ex.ErrorCode);
Assert.AreEqual(ManagedIdentitySource.MachineLearning.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]);
Assert.AreEqual(string.Format(CultureInfo.InvariantCulture, MsalErrorMessage.ManagedIdentityEndpointInvalidUriError, "MSI_ENDPOINT", "127.0.0.1:41564/msi/token", MachineLearning), ex.Message);
}
}
}
}
Loading

0 comments on commit 03f4b6d

Please sign in to comment.