all tests passing on spark with TransformingEncoder fixes
chris-twiner committed Feb 11, 2025
1 parent 68f15b2 commit d4e3c18
Showing 1 changed file with 18 additions and 71 deletions.
89 changes: 18 additions & 71 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -1,26 +1,21 @@
package frameless

import java.math.BigInteger
import java.util.Date
import java.time.{Duration, Instant, LocalDate, Period}
import java.sql.Timestamp
import scala.reflect.ClassTag
import FramelessInternals.UserDefinedType
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import shapeless._
import shapeless.ops.hlist.IsHCons
import com.sparkutils.shim.expressions.{ExternalMapToCatalyst7 => ExternalMapToCatalyst, MapObjects5 => MapObjects, UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption}
import frameless.FramelessInternals.UserDefinedType
import frameless.{reflection => ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, DEFAULT_SCALA_DECIMAL_ENCODER, IterableEncoder, JavaBigIntEncoder, MapEncoder, NullEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_TIMESTAMP_ENCODER, ScalaBigIntEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, microsToInstant}
import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance, StaticInvoke4 => StaticInvoke}
import org.apache.spark.sql.types._
import shapeless._
import shapeless.ops.hlist.IsHCons

import java.math.BigInteger
import java.sql
import java.sql.Timestamp
import java.time.{Duration, Instant, Period}
import java.util.Date
import scala.collection.immutable.{ListSet, TreeSet}
import scala.reflect.ClassTag

abstract class TypedEncoder[T](
implicit
@@ -47,24 +42,6 @@ object InjectionCodecs {
override def decode(out: B): A = injection.invert(out)
}


val decode: Codec[_, _] => Function[_, _] = (codec: Codec[_, _]) => codec.decode _
val encode: Codec[_, _] => Function[_, _] = (codec: Codec[_, _]) => codec.encode _

def convertPossibleValueClass[A, B](recordFieldEncoder: RecordFieldEncoder[_], op: Codec[_,_] => _ => _): A => B = {
recordFieldEncoder.encoder.agnosticEncoder match {
case tEnc: TransformingEncoder[_, _] =>
val dec =
op(tEnc.
codecProvider().asInstanceOf[Codec[_, _]])

recordFieldEncoder.valueClassUnderlying.fold[_ => Any]((a: Any) => a)(_ => {
a => dec(a)
}).asInstanceOf[A => B]
case _ => a => a.asInstanceOf[B]
}
}

}

// Waiting on scala 2.12
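
For context, a minimal self-contained sketch of the idea behind the InjectionCodecs wrapper above. The Injection and Codec traits here are simplified stand-ins for the frameless and Spark types, not their real definitions:

// Simplified stand-ins, for illustration only.
trait Injection[A, B] { def apply(a: A): B; def invert(b: B): A }
trait Codec[I, O] { def encode(in: I): O; def decode(out: O): I }

// An Injection already carries both directions, so adapting it to a Codec is a
// direct delegation, as in the wrapper kept in InjectionCodecs above.
def injectionCodec[A, B](injection: Injection[A, B]): Codec[A, B] =
  new Codec[A, B] {
    override def encode(in: A): B = injection(in)
    override def decode(out: B): A = injection.invert(out)
  }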
@@ -175,7 +152,6 @@ object TypedEncoder {
}

implicit val sqlDate: TypedEncoder[SQLDate] = new TypedEncoder[SQLDate] {
// No direct equivalent of invoke <-> staticinvoke pairs but injection works
override def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLDate]

val sqlDateAsDate: Injection[SQLDate, Int] =
@@ -293,17 +269,8 @@ object TypedEncoder {
override def agnosticEncoder: AgnosticEncoder[Array[T]] =
encodeT.jvmRepr match {
case ByteType => BinaryEncoder.asInstanceOf[AgnosticEncoder[Array[T]]]
//case IntegerType | LongType | DoubleType | FloatType | ShortType |
// BooleanType =>
case _ =>
ArrayEncoder(encodeT.agnosticEncoder, encodeT.nullable)
/*case _ =>
IterableEncoder(
classTag,
encodeT.agnosticEncoder,
encodeT.nullable,
lenientSerialization = false).asInstanceOf[AgnosticEncoder[Array[T]]]*/
//collectionEncoder(encodeT.agnosticEncoder, containsNull = false)
}

override def toString: String = s"ArrayEncoder[$jvmRepr]"
@@ -402,7 +369,7 @@ object TypedEncoder {
TransformingEncoder(
classTag,
IterableEncoder(
i3, // we need the base supported type either Seq or Set, TreeSet or any other builders will fail
i3,
encodeT.agnosticEncoder,
encodeT.nullable,
lenientSerialization = false),
@@ -433,35 +400,13 @@ object TypedEncoder {
private lazy val encodeA = i0.value
private lazy val encodeB = i1.value

private lazy val convertA: Any => A = InjectionCodecs.convertPossibleValueClass(encodeA, InjectionCodecs.decode)
private lazy val convertB: Any => B = InjectionCodecs.convertPossibleValueClass(encodeB, InjectionCodecs.decode)

private lazy val revertA: Any => A = InjectionCodecs.convertPossibleValueClass(encodeA, InjectionCodecs.encode)
private lazy val revertB: Any => B = InjectionCodecs.convertPossibleValueClass(encodeB, InjectionCodecs.encode)

val provider = () => new Codec[Map[A,B], Map[_,_]] {

override def decode(in: Map[_, _]): Map[A, B] = in.map { p =>
(convertA(p._1), convertB(p._2))
}

override def encode(out: Map[A, B]): Map[_, _] = out.map { p =>
(revertA(p._1), revertB(p._2))
}
}

// MAP key / values with TransformingEncoder as top level do not seem to work
override def agnosticEncoder: AgnosticEncoder[Map[A, B]] = {
TransformingEncoder[Map[A,B],Map[_,_]](
override def agnosticEncoder: AgnosticEncoder[Map[A, B]] =
MapEncoder(
classTag,
MapEncoder(
classTag.asInstanceOf[ClassTag[Map[_,_]]],
encodeA.valueClassUnderlying.fold[AgnosticEncoder[A]](encodeA.encoder.agnosticEncoder)(_.agnosticEncoder.asInstanceOf[AgnosticEncoder[A]]),
encodeB.valueClassUnderlying.fold[AgnosticEncoder[B]](encodeB.encoder.agnosticEncoder)(_.agnosticEncoder.asInstanceOf[AgnosticEncoder[B]]),
valueContainsNull = encodeB.encoder.nullable),
provider
)
}
encodeA.encoder.agnosticEncoder,
encodeB.encoder.agnosticEncoder,
valueContainsNull = encodeB.encoder.nullable)

override def toString: String = s"MapEncoder[$jvmRepr]"
}
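
For orientation, an illustrative usage sketch of the plain MapEncoder composition the map case now relies on. Map[String, Int] and the concrete key/value encoders are picked for the example; only the four-argument call shape follows the code above:

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{MapEncoder, PrimitiveIntEncoder, StringEncoder}

// Compose an agnostic encoder for Map[String, Int] by passing the key and value
// encoders plus the value-nullability flag straight through.
val mapStringInt: AgnosticEncoder[Map[String, Int]] =
  MapEncoder(
    implicitly[ClassTag[Map[String, Int]]],
    StringEncoder,
    PrimitiveIntEncoder,
    valueContainsNull = false)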
@@ -491,7 +436,9 @@ object TypedEncoder {
trb: TypedEncoder[B]
): TypedEncoder[A] =
new TypedEncoder[A] {
override def nullable: Boolean = trb.nullable
override def jvmRepr: DataType = FramelessInternals.objectTypeFor[A](classTag)
override def catalystRepr: DataType = trb.catalystRepr

override def agnosticEncoder: AgnosticEncoder[A] =
TransformingEncoder[A, B](
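
Finally, a hedged, illustrative sketch of the TransformingEncoder composition this file builds on. The three-argument shape of class tag, underlying agnostic encoder, and codec provider follows the calls visible in the hunks above; UserId and its codec are invented for the example:

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{StringEncoder, TransformingEncoder}

// A wrapper type encoded through its underlying String representation. The codec
// provider supplies the two-way conversion used by TransformingEncoder.
final case class UserId(value: String)

val userIdEncoder: AgnosticEncoder[UserId] =
  TransformingEncoder(
    implicitly[ClassTag[UserId]],
    StringEncoder,
    () =>
      new Codec[UserId, String] {
        override def encode(in: UserId): String = in.value
        override def decode(out: String): UserId = UserId(out)
      })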