Skip to content

Commit

Permalink
Adding support for Custom MultiLabel Classification and Single Label …
Browse files Browse the repository at this point in the history
…classification. Unit tests are added to validate that requests and response are correct. Also added tiemout for AbstractiveSummary requests.
  • Loading branch information
FMasudMsft committed Jan 31, 2025
1 parent 2227cd6 commit 214a578
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,28 @@ case class ExtractedSummarySentence(text: String,
offset: Int,
length: Int)

object ExtractedSummarySentence extends SparkBindings[ExtractedSummarySentence]

case class ExtractedSummaryDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
sentences: Seq[ExtractedSummarySentence])

object ExtractedSummaryDocumentResult extends SparkBindings[ExtractedSummaryDocumentResult]

case class ExtractiveSummarizationResult(errors: Seq[ATError],
statistics: Option[RequestStatistics],
modelVersion: String,
documents: Seq[ExtractedSummaryDocumentResult])

object ExtractiveSummarizationResult extends SparkBindings[ExtractiveSummarizationResult]

case class ExtractiveSummarizationLROResult(results: ExtractiveSummarizationResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)

object ExtractiveSummarizationLROResult extends SparkBindings[ExtractiveSummarizationLROResult]

case class ExtractiveSummarizationTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[ExtractiveSummarizationLROResult]])

object ExtractiveSummarizationTaskResult extends SparkBindings[ExtractiveSummarizationTaskResult]

case class ExtractiveSummarizationJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -112,8 +102,6 @@ case class AbstractiveSummarizationJobsInput(displayName: Option[String],
case class AbstractiveSummary(text: String,
contexts: Option[Seq[SummaryContext]])

object AbstractiveSummary extends SparkBindings[AbstractiveSummary]

case class AbstractiveSummaryDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
Expand All @@ -126,25 +114,18 @@ case class AbstractiveSummarizationResult(errors: Seq[ATError],
modelVersion: String,
documents: Seq[AbstractiveSummaryDocumentResult])

object AbstractiveSummarizationResult extends SparkBindings[AbstractiveSummarizationResult]

case class AbstractiveSummarizationLROResult(results: AbstractiveSummarizationResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)

object AbstractiveSummarizationLROResult extends SparkBindings[AbstractiveSummarizationLROResult]


case class AbstractiveSummarizationTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[AbstractiveSummarizationLROResult]])

object AbstractiveSummarizationTaskResult extends SparkBindings[AbstractiveSummarizationTaskResult]

case class AbstractiveSummarizationJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -179,17 +160,13 @@ case class HealthcareAssertion(conditionality: Option[String],
association: Option[String],
temporality: Option[String])

object HealthcareAssertion extends SparkBindings[HealthcareAssertion]

case class HealthcareEntitiesDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
entities: Seq[HealthcareEntity],
relations: Seq[HealthcareRelation],
fhirBundle: Option[String])

object HealthcareEntitiesDocumentResult extends SparkBindings[HealthcareEntitiesDocumentResult]

case class HealthcareEntity(text: String,
category: String,
subcategory: Option[String],
Expand All @@ -200,48 +177,33 @@ case class HealthcareEntity(text: String,
name: Option[String],
links: Option[Seq[HealthcareEntityLink]])

object HealthcareEntity extends SparkBindings[HealthcareEntity]

case class HealthcareEntityLink(dataSource: String,
id: String)

object HealthcareEntityLink extends SparkBindings[HealthcareEntityLink]

case class HealthcareLROResult(results: HealthcareResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)

object HealthcareLROResult extends SparkBindings[HealthcareLROResult]


case class HealthcareRelation(relationType: String,
entities: Seq[HealthcareRelationEntity],
confidenceScore: Option[Double])

object HealthcareRelation extends SparkBindings[HealthcareRelation]

case class HealthcareRelationEntity(ref: String,
role: String)

object HealthcareRelationEntity extends SparkBindings[HealthcareRelationEntity]

case class HealthcareResult(errors: Seq[DocumentError],
statistics: Option[RequestStatistics],
modelVersion: String,
documents: Seq[HealthcareEntitiesDocumentResult])

object HealthcareResult extends SparkBindings[HealthcareResult]

case class HealthcareTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[HealthcareLROResult]])

object HealthcareTaskResult extends SparkBindings[HealthcareTaskResult]

case class HealthcareJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -274,16 +236,12 @@ case class SentimentAnalysisLROResult(results: SentimentResult,
taskName: Option[String],
kind: String)

object SentimentAnalysisLROResult extends SparkBindings[SentimentAnalysisLROResult]

case class SentimentAnalysisTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[SentimentAnalysisLROResult]])

object SentimentAnalysisTaskResult extends SparkBindings[SentimentAnalysisTaskResult]

case class SentimentAnalysisJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -315,16 +273,12 @@ case class KeyPhraseExtractionLROResult(results: KeyPhraseExtractionResult,
taskName: Option[String],
kind: String)

object KeyPhraseExtractionLROResult extends SparkBindings[KeyPhraseExtractionLROResult]

case class KeyPhraseExtractionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[KeyPhraseExtractionLROResult]])

object KeyPhraseExtractionTaskResult extends SparkBindings[KeyPhraseExtractionTaskResult]

case class KeyPhraseExtractionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -355,23 +309,18 @@ case class PiiEntityRecognitionJobsInput(displayName: Option[String],
analysisInput: MultiLanguageAnalysisInput,
tasks: Seq[PiiEntityRecognitionLROTask])


case class PiiEntityRecognitionLROResult(results: PIIResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)

object PiiEntityRecognitionLROResult extends SparkBindings[PiiEntityRecognitionLROResult]

case class PiiEntityRecognitionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[PiiEntityRecognitionLROResult]])

object PiiEntityRecognitionTaskResult extends SparkBindings[PiiEntityRecognitionTaskResult]

case class PiiEntityRecognitionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -405,16 +354,12 @@ case class EntityLinkingLROResult(results: EntityLinkingResult,
taskName: Option[String],
kind: String)

object EntityLinkingLROResult extends SparkBindings[EntityLinkingLROResult]

case class EntityLinkingTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[EntityLinkingLROResult]])

object EntityLinkingTaskResult extends SparkBindings[EntityLinkingTaskResult]

case class EntityLinkingJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -459,16 +404,12 @@ case class EntityRecognitionLROResult(results: EntityRecognitionResult,
taskName: Option[String],
kind: String)

object EntityRecognitionLROResult extends SparkBindings[EntityRecognitionLROResult]

case class EntityRecognitionTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[EntityRecognitionLROResult]])

object EntityRecognitionTaskResult extends SparkBindings[EntityRecognitionTaskResult]

case class EntityRecognitionJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down Expand Up @@ -518,9 +459,9 @@ case class CustomLabelJobsInput(displayName: Option[String],
case class ClassificationDocumentResult(id: String,
warnings: Seq[DocumentWarning],
statistics: Option[RequestStatistics],
classes: Seq[ClassificationResult])
classifications: Seq[ClassificationResult])

object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult]
//object ClassificationDocumentResult extends SparkBindings[ClassificationDocumentResult]

case class ClassificationResult(category: String,
confidenceScore: Double)
Expand All @@ -532,24 +473,18 @@ case class CustomLabelResult(errors: Seq[DocumentError],
modelVersion: String,
documents: Seq[ClassificationDocumentResult])

object CustomLabelResult extends SparkBindings[CustomLabelResult]

case class CustomLabelLROResult(results: CustomLabelResult,
lastUpdateDateTime: String,
status: String,
taskName: Option[String],
kind: String)

object CustomLabelLROResult extends SparkBindings[CustomLabelLROResult]

case class CustomLabelTaskResult(completed: Int,
failed: Int,
inProgress: Int,
total: Int,
items: Option[Seq[CustomLabelLROResult]])

object CustomLabelTaskResult extends SparkBindings[CustomLabelTaskResult]

case class CustomLabelJobState(displayName: Option[String],
createdDateTime: String,
expirationDateTime: Option[String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@

package com.microsoft.azure.synapse.ml.services.language

import com.microsoft.azure.synapse.ml.io.http.{ EntityData, HTTPResponseData }
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services.HasServiceParams
import com.microsoft.azure.synapse.ml.services.language.ATLROJSONFormat._
import com.microsoft.azure.synapse.ml.services.language.PiiDomain.PiiDomain
import com.microsoft.azure.synapse.ml.services.language.SummaryLength.SummaryLength
import com.microsoft.azure.synapse.ml.services.vision.BasicAsyncReply
import org.apache.commons.io.IOUtils
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.ml.param.ParamValidators
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
import spray.json.enrichAny

import java.net.URI

object AnalysisTaskKind extends Enumeration {
type AnalysisTaskKind = Value
val SentimentAnalysis,
Expand Down Expand Up @@ -511,8 +518,101 @@ trait HandleCustomEntityRecognition extends HasServiceParams
}
}

/**
* Trait `ModifiableAsyncReply` extends `BasicAsyncReply` and provides a mechanism to modify the HTTP response
* received from an asynchronous service call. This trait is designed to be mixed into classes that require
* custom handling of the response data.
*
* The primary purpose of this trait is to allow modification of the response before it is processed further.
* This is particularly useful in scenarios where the response needs to be transformed or certain fields need
* to be renamed to comply with specific requirements or constraints.
*
* In this implementation, the `queryForResult` method is overridden and marked as `final` to prevent further
* overriding. This ensures that the response modification logic is consistently applied across all subclasses.
*
* @note This trait is designed to be used with the `SynapseMLLogging` trait for consistent logging.
*/
trait ModifiableAsyncReply extends BasicAsyncReply {
self: SynapseMLLogging =>

protected def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = response

/**
* Queries for the result of an asynchronous service call and applies the response modification logic.
*/
override final protected def queryForResult(key: Option[String],
client: CloseableHttpClient,
location: URI): Option[HTTPResponseData] = {
val originalResponse = super.queryForResult(key, client, location)
logDebug(s"Original response: ${ originalResponse }")
modifyResponse(originalResponse)
}
}


/**
* Trait `HandleCustomLabelClassification` extends `HasServiceParams` and `HasCustomLanguageModelParam` to handle
* custom label classification tasks. This trait provides the necessary methods to create requests for custom
* multi-label classification and to modify the response to comply with specific requirements.
*
* The primary purpose of this trait is to address the limitation in Spark where fields named "class" cannot be
* directly bound. To work around this limitation, the response is modified to rename the "class" field to
* "classifications".
*
* This trait is designed to be mixed into classes that require custom label classification functionality and
* response modification logic.
*
* @note This trait is designed to be used with the `ModifiableAsyncReply` and `SynapseMLLogging` traits for
* consistent response handling and logging.
*/
trait HandleCustomLabelClassification extends HasServiceParams
with HasCustomLanguageModelParam {
self: ModifiableAsyncReply
with SynapseMLLogging =>

private def isCustomLabelClassification: Boolean = {
val kind = getKind
kind == AnalysisTaskKind.CustomSingleLabelClassification.toString ||
kind == AnalysisTaskKind.CustomMultiLabelClassification.toString
}

/**
* Modifies the entity in the HTTP response to rename the "class" field to "classifications".
*
* @param response The original HTTP response.
* @return The modified HTTP response with the "class" field renamed to "classifications".
*/
private def modifyEntity(response: HTTPResponseData): HTTPResponseData = {
val modifiedEntity = response.entity.flatMap { entity =>
val strEntity = IOUtils.toString(entity.content, "UTF-8")
val modifiedEntity = strEntity.replace("\"class\":", "\"classifications\":")
logDebug(s"Original entity: $strEntity\t Modified entity: $modifiedEntity")
Some(new EntityData(
content = modifiedEntity.getBytes,
contentEncoding = entity.contentEncoding,
contentLength = Some(strEntity.length),
contentType = entity.contentType,
isChunked = entity.isChunked,
isRepeatable = entity.isRepeatable,
isStreaming = entity.isStreaming
))
}
new HTTPResponseData(response.headers, modifiedEntity, response.statusLine, response.locale)
}

/**
* Modifies the HTTP response if the task kind is custom label classification.
*/
override def modifyResponse(response: Option[HTTPResponseData]): Option[HTTPResponseData] = {
if (!isCustomLabelClassification) {
logDebug(s"Kind is not CustomSingleLabelClassification or CustomMultiLabelClassification. Kind: $getKind")
response
} else {
response.map(modifyEntity)
}
}


def getKind: String

def createCustomMultiLabelRequest(row: Row,
Expand Down
Loading

0 comments on commit 214a578

Please sign in to comment.