Skip to content

Commit

Permalink
update openai setTimeout
Browse files Browse the repository at this point in the history
  • Loading branch information
JessicaXYWang committed Dec 5, 2023
1 parent a86c03f commit 9d2a813
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
package com.microsoft.azure.synapse.ml.services.openai

import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
import com.microsoft.azure.synapse.ml.services.{HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.services.{CognitiveServicesBase, HasAPIVersion, HasServiceParams}
import com.microsoft.azure.synapse.ml.param.ServiceParam
import org.apache.spark.sql.Row
import spray.json.DefaultJsonProtocol._
Expand Down Expand Up @@ -244,3 +244,6 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
}
}

abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) {
setDefault(timeout -> 360.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ import scala.language.existentials

object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]

class OpenAIChatCompletion(override val uid: String) extends CognitiveServicesBase(uid)
class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

setDefault(timeout -> 360.0)

val messagesCol: Param[String] = new Param[String](
this, "messagesCol", "The column messages to generate chat completions for," +
" in the chat format. This column should have type Array(Struct(role: String, content: String)).")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@ import scala.language.existentials

object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]

class OpenAICompletion(override val uid: String) extends CognitiveServicesBase(uid)
class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
with HasInternalJsonOutputParser with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

setDefault(timeout -> 360.0)

def this() = this(Identifiable.randomUID("OpenAICompletion"))

def urlPath: String = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ import scala.language.existentials

object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]

class OpenAIEmbedding (override val uid: String) extends CognitiveServicesBase(uid)
class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
with HasOpenAISharedParams with HasCognitiveServiceInput with SynapseMLLogging {
logClass(FeatureNames.AiServices.OpenAI)

setDefault(timeout -> 360.0)

def this() = this(Identifiable.randomUID("OpenAIEmbedding"))

def urlPath: String = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class OpenAIPrompt(override val uid: String) extends Transformer
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
with ComplexParamsWritable with SynapseMLLogging {

setDefault(timeout -> 360.0)

logClass(FeatureNames.AiServices.OpenAI)

def this() = this(Identifiable.randomUID("OpenAIPrompt"))
Expand Down Expand Up @@ -75,7 +73,8 @@ class OpenAIPrompt(override val uid: String) extends Transformer
postProcessingOptions -> Map.empty,
outputCol -> (this.uid + "_output"),
errorCol -> (this.uid + "_error"),
dropPrompt -> true
dropPrompt -> true,
timeout -> 360.0
)

override def setCustomServiceName(v: String): this.type = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ trait ConcurrencyParams extends Wrappable {
case Some(v) => setConcurrentTimeout(v)
case None => clear(concurrentTimeout)
}

setDefault(concurrency -> 1, timeout -> 60.0)

}

trait HasURL extends Params {
Expand Down

0 comments on commit 9d2a813

Please sign in to comment.