diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala index f81d586408..ceeb565f6f 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala @@ -8,6 +8,7 @@ import spray.json.JsValue import java.util.UUID import scala.io.Source import scala.util.{Failure, Success, Try} +import java.net.URL object FabricClient extends RESTUtils { private val PbiGlobalServiceEndpoints = Map( @@ -29,6 +30,7 @@ object FabricClient extends RESTUtils { var ArtifactID = ""; var PbiEnv = ""; var FabricContext: Map[String, String] = Map[String, String](); + var MLWorkloadHost = ""; private val WorkloadEndpointTypeML = "ML"; private val WorkloadEndpointTypeLLMPlugin = "LlmPlugin" @@ -37,7 +39,7 @@ object FabricClient extends RESTUtils { private val WorkloadEndpointTypeAdmin = "MLAdmin" lazy val PbiSharedHost: String = getPbiSharedHost; - lazy val MLWorkloadHost: String = getMLWorkloadHost; + lazy val MLWorkloadEndpointML: String = getMLWorkloadEndpoint(WorkloadEndpointTypeML); lazy val MLWorkloadEndpointLLMPlugin: String = getMLWorkloadEndpoint(WorkloadEndpointTypeLLMPlugin); lazy val MLWorkloadEndpointAutomatic: String = getMLWorkloadEndpoint(WorkloadEndpointTypeAutomatic); @@ -62,6 +64,22 @@ object FabricClient extends RESTUtils { this.WorkspaceID = this.FabricContext.getOrElse("trident.artifact.workspace.id", ""); this.ArtifactID = this.FabricContext.getOrElse("trident.artifact.id", ""); this.PbiEnv = this.FabricContext.getOrElse("spark.trident.pbienv", "public").toLowerCase(); + this.MLWorkloadHost = this.extractSchemeAndHost( + this.FabricContext.getOrElse("trident.lakehouse.tokenservice.endpoint", "https://") + ).getOrElse(""); + } + + def extractSchemeAndHost(urlString: String): Option[String] = { + try { + val url = new URL(urlString) + val scheme = url.getProtocol + val host = url.getHost + Some(s"$scheme://$host") + } catch { + case _: Exception => + // Handle MalformedURLException or other exceptions + None + } } def readFabricContextFile(): Unit = { @@ -123,22 +141,6 @@ object FabricClient extends RESTUtils { usageGet(clusterDetailUrl, headers).asJsObject.fields("clusterUrl").convertTo[String]; } - def getMLWorkloadHost: String = { - val payload = - s"""{ - |"capacityObjectId": "$CapacityID", - |"workspaceObjectId": "$WorkspaceID", - |"workloadType": "ML" - |}""".stripMargin - - val tokenUrl: String = s"$PbiSharedHost/metadata/v201606/generatemwctokenv2" - - val targetHost: String = usagePost(tokenUrl, payload, getHeaders) - .asJsObject.fields("TargetUriHost").convertTo[String]; - - s"https://$targetHost" - } - def getMLWorkloadEndpoint(endpointType: String): String = { s"$MLWorkloadHost/webapi/capacities/$CapacityID/workloads/ML/$endpointType/Automatic/workspaceid/$WorkspaceID/" } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala index d7a68f1095..d7b7dadd78 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala @@ -16,7 +16,7 @@ object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider { def getAuthHeader: String = { if (MLMWCToken != "" && !isTokenExpired(MLMWCToken)) { logInfo("using cached openai mwc token") - MLMWCToken + "MwcToken " + MLMWCToken } else { val artifactId = FabricClient.ArtifactID @@ -30,7 +30,8 @@ object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider { try { val token = FabricClient.usagePost(url, payload).asJsObject.fields("Token").convertTo[String]; - logInfo("successfully fetch openai mwc token") + logInfo("successfully fetch openai mwc token"); + MLMWCToken = token; "MwcToken " + token } catch { case e: Throwable => diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/fabric/CertifiedEventClient.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/fabric/CertifiedEventClient.scala index a136b34acf..eb60098040 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/fabric/CertifiedEventClient.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/logging/fabric/CertifiedEventClient.scala @@ -3,76 +3,26 @@ package com.microsoft.azure.synapse.ml.logging.fabric -import com.microsoft.azure.synapse.ml.fabric.{RESTUtils, TokenLibrary} -import org.apache.spark.sql.SparkSession +import com.microsoft.azure.synapse.ml.fabric.FabricClient +import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails.runningOnFabric import spray.json.DefaultJsonProtocol.{StringJsonFormat, _} import spray.json._ import java.time.Instant -import java.util.UUID -import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails.runningOnFabric - -object CertifiedEventClient extends RESTUtils { - - private val PbiGlobalServiceEndpoints = Map( - "public" -> "https://api.powerbi.com/", - "fairfax" -> "https://api.powerbigov.us", - "mooncake" -> "https://api.powerbi.cn", - "blackforest" -> "https://app.powerbi.de", - "msit" -> "https://api.powerbi.com/", - "prod" -> "https://api.powerbi.com/", - "int3" -> "https://biazure-int-edog-redirect.analysis-df.windows.net/", - "dxt" -> "https://powerbistagingapi.analysis.windows.net/", - "edog" -> "https://biazure-int-edog-redirect.analysis-df.windows.net/", - "dev" -> "https://onebox-redirect.analysis.windows-int.net/", - "console" -> "http://localhost:5001/", - "daily" -> "https://dailyapi.powerbi.com/") - +object CertifiedEventClient { private lazy val CertifiedEventUri = getCertifiedEventUri - private def getHeaders: Map[String, String] = { - Map( - "Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}", - "RequestId" -> UUID.randomUUID().toString, - "Content-Type" -> "application/json", - "x-ms-workload-resource-moniker" -> UUID.randomUUID().toString - ) - } - - private def getCertifiedEventUri: String = { - val sc = SparkSession.builder().getOrCreate().sparkContext - val workspaceId = sc.hadoopConfiguration.get("trident.artifact.workspace.id") - val capacityId = sc.hadoopConfiguration.get("trident.capacity.id") - val pbiEnv = sc.getConf.get("spark.trident.pbienv").toLowerCase() - - val clusterDetailUrl = s"${PbiGlobalServiceEndpoints(pbiEnv)}powerbi/globalservice/v201606/clusterDetails" - val headers = getHeaders - - val clusterUrl = usageGet(clusterDetailUrl, headers) - .asJsObject.fields("clusterUrl").convertTo[String] - val tokenUrl: String = s"$clusterUrl/metadata/v201606/generatemwctokenv2" - - val payload = - s"""{ - |"capacityObjectId": "$capacityId", - |"workspaceObjectId": "$workspaceId", - |"workloadType": "ML" - |}""".stripMargin - - - val host = usagePost(tokenUrl, payload, headers) - .asJsObject.fields("TargetUriHost").convertTo[String] - - s"https://$host/webapi/Capacities/$capacityId/workloads/ML/MLAdmin/Automatic/workspaceid/$workspaceId/telemetry" + def getCertifiedEventUri: String = { + s"${FabricClient.MLWorkloadEndpointAdmin}/telemetry" } - private[ml] def logToCertifiedEvents(featureName: String, + def logToCertifiedEvents(featureName: String, activityName: String, attributes: Map[String, String]): Unit = { - if (runningOnFabric) { + if (runningOnFabric) { val payload = s"""{ |"timestamp":${Instant.now().getEpochSecond}, @@ -81,7 +31,7 @@ object CertifiedEventClient extends RESTUtils { |"attributes":${attributes.toJson.compactPrint} |}""".stripMargin - usagePost(CertifiedEventUri, payload, getHeaders) + FabricClient.usagePost(CertifiedEventUri, payload) } } }