Skip to content

Commit

Permalink
chore: OpenAIPrompt bug fixes (#2334)
Browse files Browse the repository at this point in the history
* Fix bug where we tried to set params that don't exist on the underlying completion model

* Rename the error column and move it to the back of the DataFrame

* Fix style

* Revert some changes

* Change transformSchema to move the error column to the back
  • Loading branch information
sss04 authored Jan 7, 2025
1 parent 3ec2ccd commit ec2c1c1
Showing 1 changed file with 16 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, functions => F, types => T}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DataType, StructField, StructType}

import scala.collection.JavaConverters._

Expand Down Expand Up @@ -172,32 +172,30 @@ class OpenAIPrompt(override val uid: String) extends Transformer
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("message").getField("content")))
.drop(completionNamed.getOutputCol)

val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
if (getDropPrompt) {
results.drop(messageColName)
resultsFinal.drop(messageColName)
} else {
results
resultsFinal
}

case completion: OpenAICompletion =>
if (isSet(responseFormat)) {
throw new IllegalArgumentException("responseFormat is not supported for completion models")
}
val promptColName = df.withDerivativeCol("prompt")
val dfTemplated = df.withColumn(promptColName, promptCol)
val completionNamed = completion.setPromptCol(promptColName)
// run completion
val results = completionNamed
.transform(dfTemplated)
.withColumn(getOutputCol,
getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
.getField("text")))
.drop(completionNamed.getOutputCol)

val resultsFinal = results.select(results.columns.filter(_ != getErrorCol).map(col) :+ col(getErrorCol): _*)
if (getDropPrompt) {
results.drop(promptColName)
resultsFinal.drop(promptColName)
} else {
results
resultsFinal
}
}
}, dataset.columns.length)
Expand Down Expand Up @@ -238,7 +236,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
.filter(p => !localParamNames.contains(p.param.name) && completion.hasParam(p.param.name))
.foreach(p => completion.set(completion.getParam(p.param.name), p.value))

completion
Expand Down Expand Up @@ -267,7 +265,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

override def transformSchema(schema: StructType): StructType = {
openAICompletion match {
val transformedSchema = openAICompletion match {
case chatCompletion: OpenAIChatCompletion =>
chatCompletion
.transformSchema(schema.add(getMessagesCol, StructType(Seq())))
Expand All @@ -277,6 +275,12 @@ class OpenAIPrompt(override val uid: String) extends Transformer
.transformSchema(schema)
.add(getPostProcessing, getParser.outputSchema)
}

// Move error column to back
val errorFieldOpt: Option[StructField] = transformedSchema.fields.find(_.name == getErrorCol)
val fieldsWithoutError: Array[StructField] = transformedSchema.fields.filterNot(_.name == getErrorCol)
val reorderedFields = Array.concat(fieldsWithoutError, errorFieldOpt.toArray)
StructType(reorderedFields)
}
}

Expand Down

0 comments on commit ec2c1c1

Please sign in to comment.