Skip to content

Commit

Permalink
use tokenservice endpoint instead of resolve at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
mslhrotk committed Feb 7, 2024
1 parent dc1cffa commit 3c2dd56
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import spray.json.JsValue
import java.util.UUID
import scala.io.Source
import scala.util.{Failure, Success, Try}
import java.net.URL

object FabricClient extends RESTUtils {
private val PbiGlobalServiceEndpoints = Map(
Expand All @@ -29,6 +30,7 @@ object FabricClient extends RESTUtils {
var ArtifactID = "";
var PbiEnv = "";
var FabricContext: Map[String, String] = Map[String, String]();
var MLWorkloadHost = "";

private val WorkloadEndpointTypeML = "ML";
private val WorkloadEndpointTypeLLMPlugin = "LlmPlugin"
Expand All @@ -37,7 +39,7 @@ object FabricClient extends RESTUtils {
private val WorkloadEndpointTypeAdmin = "MLAdmin"

lazy val PbiSharedHost: String = getPbiSharedHost;
lazy val MLWorkloadHost: String = getMLWorkloadHost;

lazy val MLWorkloadEndpointML: String = getMLWorkloadEndpoint(WorkloadEndpointTypeML);
lazy val MLWorkloadEndpointLLMPlugin: String = getMLWorkloadEndpoint(WorkloadEndpointTypeLLMPlugin);
lazy val MLWorkloadEndpointAutomatic: String = getMLWorkloadEndpoint(WorkloadEndpointTypeAutomatic);
Expand All @@ -62,6 +64,22 @@ object FabricClient extends RESTUtils {
this.WorkspaceID = this.FabricContext.getOrElse("trident.artifact.workspace.id", "");
this.ArtifactID = this.FabricContext.getOrElse("trident.artifact.id", "");
this.PbiEnv = this.FabricContext.getOrElse("spark.trident.pbienv", "public").toLowerCase();
this.MLWorkloadHost = this.extractSchemeAndHost(
this.FabricContext.getOrElse("trident.lakehouse.tokenservice.endpoint", "https://")
).getOrElse("");
}

def extractSchemeAndHost(urlString: String): Option[String] = {
try {
val url = new URL(urlString)
val scheme = url.getProtocol
val host = url.getHost
Some(s"$scheme://$host")
} catch {
case _: Exception =>
// Handle MalformedURLException or other exceptions
None
}
}

def readFabricContextFile(): Unit = {
Expand Down Expand Up @@ -123,22 +141,6 @@ object FabricClient extends RESTUtils {
usageGet(clusterDetailUrl, headers).asJsObject.fields("clusterUrl").convertTo[String];
}

def getMLWorkloadHost: String = {
val payload =
s"""{
|"capacityObjectId": "$CapacityID",
|"workspaceObjectId": "$WorkspaceID",
|"workloadType": "ML"
|}""".stripMargin

val tokenUrl: String = s"$PbiSharedHost/metadata/v201606/generatemwctokenv2"

val targetHost: String = usagePost(tokenUrl, payload, getHeaders)
.asJsObject.fields("TargetUriHost").convertTo[String];

s"https://$targetHost"
}

def getMLWorkloadEndpoint(endpointType: String): String = {
s"$MLWorkloadHost/webapi/capacities/$CapacityID/workloads/ML/$endpointType/Automatic/workspaceid/$WorkspaceID/"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider {
def getAuthHeader: String = {
if (MLMWCToken != "" && !isTokenExpired(MLMWCToken)) {
logInfo("using cached openai mwc token")
MLMWCToken
"MwcToken " + MLMWCToken
}
else {
val artifactId = FabricClient.ArtifactID
Expand All @@ -30,7 +30,8 @@ object OpenAITokenLibrary extends SynapseMLLogging with AuthHeaderProvider {

try {
val token = FabricClient.usagePost(url, payload).asJsObject.fields("Token").convertTo[String];
logInfo("successfully fetch openai mwc token")
logInfo("successfully fetch openai mwc token");
MLMWCToken = token;
"MwcToken " + token
} catch {
case e: Throwable =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,26 @@

package com.microsoft.azure.synapse.ml.logging.fabric

import com.microsoft.azure.synapse.ml.fabric.{RESTUtils, TokenLibrary}
import org.apache.spark.sql.SparkSession
import com.microsoft.azure.synapse.ml.fabric.FabricClient
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails.runningOnFabric
import spray.json.DefaultJsonProtocol.{StringJsonFormat, _}
import spray.json._

import java.time.Instant
import java.util.UUID
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails.runningOnFabric

object CertifiedEventClient extends RESTUtils {

private val PbiGlobalServiceEndpoints = Map(
"public" -> "https://api.powerbi.com/",
"fairfax" -> "https://api.powerbigov.us",
"mooncake" -> "https://api.powerbi.cn",
"blackforest" -> "https://app.powerbi.de",
"msit" -> "https://api.powerbi.com/",
"prod" -> "https://api.powerbi.com/",
"int3" -> "https://biazure-int-edog-redirect.analysis-df.windows.net/",
"dxt" -> "https://powerbistagingapi.analysis.windows.net/",
"edog" -> "https://biazure-int-edog-redirect.analysis-df.windows.net/",
"dev" -> "https://onebox-redirect.analysis.windows-int.net/",
"console" -> "http://localhost:5001/",
"daily" -> "https://dailyapi.powerbi.com/")


object CertifiedEventClient {
private lazy val CertifiedEventUri = getCertifiedEventUri

private def getHeaders: Map[String, String] = {
Map(
"Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}",
"RequestId" -> UUID.randomUUID().toString,
"Content-Type" -> "application/json",
"x-ms-workload-resource-moniker" -> UUID.randomUUID().toString
)
}

private def getCertifiedEventUri: String = {
val sc = SparkSession.builder().getOrCreate().sparkContext
val workspaceId = sc.hadoopConfiguration.get("trident.artifact.workspace.id")
val capacityId = sc.hadoopConfiguration.get("trident.capacity.id")
val pbiEnv = sc.getConf.get("spark.trident.pbienv").toLowerCase()

val clusterDetailUrl = s"${PbiGlobalServiceEndpoints(pbiEnv)}powerbi/globalservice/v201606/clusterDetails"
val headers = getHeaders

val clusterUrl = usageGet(clusterDetailUrl, headers)
.asJsObject.fields("clusterUrl").convertTo[String]
val tokenUrl: String = s"$clusterUrl/metadata/v201606/generatemwctokenv2"

val payload =
s"""{
|"capacityObjectId": "$capacityId",
|"workspaceObjectId": "$workspaceId",
|"workloadType": "ML"
|}""".stripMargin


val host = usagePost(tokenUrl, payload, headers)
.asJsObject.fields("TargetUriHost").convertTo[String]

s"https://$host/webapi/Capacities/$capacityId/workloads/ML/MLAdmin/Automatic/workspaceid/$workspaceId/telemetry"
def getCertifiedEventUri: String = {
s"${FabricClient.MLWorkloadEndpointAdmin}/telemetry"
}


private[ml] def logToCertifiedEvents(featureName: String,
def logToCertifiedEvents(featureName: String,
activityName: String,
attributes: Map[String, String]): Unit = {

if (runningOnFabric) {
if (runningOnFabric) {
val payload =
s"""{
|"timestamp":${Instant.now().getEpochSecond},
Expand All @@ -81,7 +31,7 @@ object CertifiedEventClient extends RESTUtils {
|"attributes":${attributes.toJson.compactPrint}
|}""".stripMargin

usagePost(CertifiedEventUri, payload, getHeaders)
FabricClient.usagePost(CertifiedEventUri, payload)
}
}
}

0 comments on commit 3c2dd56

Please sign in to comment.