Skip to content

Commit

Permalink
handle top level injection wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Feb 11, 2025
1 parent ebf4679 commit 68f15b2
Showing 1 changed file with 30 additions and 3 deletions.
33 changes: 30 additions & 3 deletions dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package frameless

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, JavaSerializationCodec, KryoSerializationCodecImpl}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{EncoderField, ProductEncoder, TransformingEncoder}
import org.apache.spark.sql.types.{Metadata, StructType}

import scala.reflect.ClassTag

object TypedExpressionEncoder {

Expand All @@ -18,6 +21,30 @@ object TypedExpressionEncoder {
// Produces the AgnosticEncoder for T from the implicit TypedEncoder. When T is a nullable
// struct at the top level, the encoder is wrapped in a single-field product (see below).
def apply[T](
implicit
encoder: TypedEncoder[T]
): AgnosticEncoder[T] = encoder.agnosticEncoder
): AgnosticEncoder[T] = {

// Puts the encoder's classTag member in scope for the implicitly[ClassTag[...]] lookups below
// (presumably it is the implicit ClassTag[T] -- confirm against TypedEncoder).
import encoder.classTag

// Spark special-cases Option when it is the top-level return value, and that handling
// cannot be cascaded up through agnostic encoders from nested encoders.
// A simple way to detect that the wrapping is needed: the top-level encoder is itself
// nullable and its catalyst representation is a struct.
// An injection of type [Int, I[Option[Int]]] will not be a struct, so it uses the plain encoder.
if (encoder.nullable && encoder.catalystRepr.isInstanceOf[StructType]) {
// Encode T as a one-field product (field "a", nullable) and convert T <-> SparkValueClass[T]
// through the codec on the way in and out.
TransformingEncoder(
implicitly[ClassTag[T]],
ProductEncoder(
implicitly[ClassTag[SparkValueClass[T]]],
Seq(EncoderField("a", encoder.agnosticEncoder, nullable = true, Metadata.empty)),
None),
codecProvider = () => new Codec[T, SparkValueClass[T]] {
// Wraps the value into the single-field product.
override def encode(in: T): SparkValueClass[T] = SparkValueClass(in)
// Unwraps the value from the single-field product.
override def decode(out: SparkValueClass[T]): T = out.a
}
)
} else
// No wrapping required: use the encoder's own agnostic encoder unchanged.
encoder.agnosticEncoder
}

}

/**
 * Single-field wrapper used by `TypedExpressionEncoder.apply` to present a top-level
 * nullable struct value as a one-column product (field name `a`).
 *
 * Declared `final`: it is a private implementation detail and is not meant to be extended.
 *
 * @param a the wrapped value
 * @tparam A type of the wrapped value
 */
private final case class SparkValueClass[A](a: A)

0 comments on commit 68f15b2

Please sign in to comment.