Skip to content

Commit

Permalink
add codec derivation and tests, convert all TypedEncoder wrap usage to remove indirection and provide enum support
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Feb 11, 2025
1 parent d4e3c18 commit 8ae0940
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 64 deletions.
71 changes: 64 additions & 7 deletions dataset/src/main/scala/frameless/InjectionEnum.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,65 @@
package frameless

import frameless.InjectionCodecs.codec
import org.apache.spark.sql.catalyst.encoders.Codec
import shapeless._

import scala.reflect.ClassTag

trait InjectionEnum {
object CodecEnums {
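/**
 * The implicits below deliberately reuse the names of the Injection-based implicits declared in
 * the enclosing InjectionEnum trait, so that once they are in scope they shadow those.
 */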

/**
 * Terminal case of the coproduct derivation: `CNil` has no values, so both directions throw.
 * Decoding ends up here when the name matched none of the data constructors tried before it.
 */
implicit val cnilInjectionEnum: TypedInjection[CNil, String] = {
  // $COVERAGE-OFF$No value of type CNil so impossible to test
  val encodeCNil: CNil => String = _ => throw new Exception("Impossible")
  // $COVERAGE-ON$
  val decodeCNil: String => CNil = name =>
    throw new IllegalArgumentException(
      s"Cannot construct a value of type CNil: $name did not match data constructor names"
    )

  TypedInjection[CNil, String](encodeCNil, decodeCNil)
}

/**
 * Inductive case of the coproduct derivation.
 *
 * The head `H` must be generically representable as `HNil` (no fields). It is encoded as the
 * text of `typeable.describe` up to the first '.', and a string equal to that text decodes back
 * to the single value of `H`. Anything else is delegated to the codec of the tail `T`.
 */
implicit def coproductInjectionEnum[H, T <: Coproduct](
    implicit
    typeable: Typeable[H],
    gen: Generic.Aux[H, HNil],
    tInjectionEnum: TypedInjection[T, String]
  ): TypedInjection[H :+: T, String] = {
  val headName = typeable.describe.takeWhile(_ != '.')
  val tailCodec = tInjectionEnum.codecProvider()

  // value => name: the head maps to its own name, the tail is handled by the tail codec
  val toName: (H :+: T) => String = {
    case Inl(_)    => headName
    case Inr(rest) => tailCodec.encode(rest)
  }

  // name => value: build the head from HNil on a match, otherwise try the tail
  val fromName: String => (H :+: T) = { name =>
    if (name == headName) Inl(gen.from(HNil))
    else Inr(tailCodec.decode(name))
  }

  TypedInjection[H :+: T, String](toName, fromName)
}

/**
 * Lifts a `TypedInjection` for the generic representation `R` of `A` to one for `A` itself:
 * encoding converts with `gen.to` first, decoding converts back with `gen.from` last.
 */
implicit def genericInjectionEnum[A: ClassTag, R](
    implicit
    gen: Generic.Aux[A, R],
    rInjectionEnum: TypedInjection[R, String]
  ): TypedInjection[A, String] = {
  val reprCodec = rInjectionEnum.codecProvider()

  val encodeA: A => String = a => reprCodec.encode(gen.to(a))
  val decodeA: String => A = text => gen.from(reprCodec.decode(text))

  TypedInjection[A, String](encodeA, decodeA)
}
}

implicit val cnilInjectionEnum: Injection[CNil, String] =
Injection(
// $COVERAGE-OFF$No value of type CNil so impossible to test
Expand All @@ -15,10 +72,10 @@ trait InjectionEnum {
)

implicit def coproductInjectionEnum[H, T <: Coproduct](
implicit
typeable: Typeable[H] ,
gen: Generic.Aux[H, HNil],
tInjectionEnum: Injection[T, String]
implicit
typeable: Typeable[H] ,
gen: Generic.Aux[H, HNil],
tInjectionEnum: Injection[T, String]
): Injection[H :+: T, String] = {
val dataConstructorName = typeable.describe.takeWhile(_ != '.')

Expand All @@ -37,9 +94,9 @@ trait InjectionEnum {
}

implicit def genericInjectionEnum[A, R](
implicit
gen: Generic.Aux[A, R],
rInjectionEnum: Injection[R, String]
implicit
gen: Generic.Aux[A, R],
rInjectionEnum: Injection[R, String]
): Injection[A, String] =
Injection(
value => rInjectionEnum(gen.to(value)),
Expand Down
143 changes: 86 additions & 57 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package frameless

import frameless.FramelessInternals.UserDefinedType
import frameless.InjectionCodecs.codec
import frameless.{reflection => ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
Expand Down Expand Up @@ -32,6 +33,47 @@ abstract class TypedEncoder[T](
def agnosticEncoder: AgnosticEncoder[T]
}

/**
 * A [[TypedEncoder]] for `A` that is also a `Codec[A, B]`.
 *
 * Created to allow encoders to be derived by manual invocation and for enums; encoding through
 * the underlying implicit `() => Codec[A, B]` remains possible. Nullability and the catalyst
 * representation are taken from the encoder of `B`, and the Spark encoder is a
 * `TransformingEncoder` built from that encoder and the codec provider.
 *
 * @param encoderName the text returned by `toString`
 * @tparam A the type being encoded
 * @tparam B the type `A` is converted to and from
 */
abstract class TypedInjection[A, B](encoderName: String)(
    implicit
    _classTag_ : ClassTag[A],
    val codecProvider: () => Codec[A, B],
    val trb: TypedEncoder[B])
    extends TypedEncoder[A]
    with Codec[A, B] {

  /** Codec obtained from `codecProvider` on first access; `encode` and `decode` delegate to it. */
  lazy val codec: Codec[A, B] = codecProvider()

  override def encode(in: A): B = codec.encode(in)

  override def decode(out: B): A = codec.decode(out)

  override def toString: String = encoderName

  override def jvmRepr: DataType = FramelessInternals.objectTypeFor[A](classTag)

  override def catalystRepr: DataType = trb.catalystRepr

  override def nullable: Boolean = trb.nullable

  override def agnosticEncoder: AgnosticEncoder[A] =
    TransformingEncoder[A, B](classTag, trb.agnosticEncoder, codecProvider)
}

object TypedInjection {

  /**
   * Creates a [[TypedInjection]] directly from a pair of conversion functions.
   *
   * @param from converts `A` to `B`
   * @param to converts `B` back to `A`
   */
  def apply[A: ClassTag, B: TypedEncoder](from: A => B, to: B => A): TypedInjection[A, B] =
    new TypedInjection[A, B]("DirectTypedInjection")(
      implicitly[ClassTag[A]],
      codec(from, to),
      implicitly[TypedEncoder[B]]
    ) {}

  /** Summons the implicit [[TypedInjection]] for `A` and `B`. */
  def of[A, B](implicit injection: TypedInjection[A, B]): TypedInjection[A, B] =
    injection
}

object InjectionCodecs {

def wrap[A, B](injection: Injection[A,B]): () => Codec[A, B] =
Expand All @@ -42,6 +84,13 @@ object InjectionCodecs {
override def decode(out: B): A = injection.invert(out)
}

/**
 * Builds a codec provider from a pair of conversion functions.
 *
 * Every call of the returned function creates a new serializable `Codec` whose `encode` applies
 * `from` and whose `decode` applies `to`.
 */
def codec[A, B](from: A => B, to: B => A):() => Codec[A, B] =
() =>
new Codec[A, B] with Serializable {
// A => B direction
override def encode(in: A): B = from(in)

// B => A direction
override def decode(out: B): A = to(out)
}
}

// Waiting on scala 2.12
Expand Down Expand Up @@ -87,15 +136,12 @@ object TypedEncoder {

implicit val charEncoder: TypedEncoder[Char] = new TypedEncoder[Char] {

val charAsString: Injection[java.lang.Character, String] =
new Injection[java.lang.Character, String] {
def apply(a: java.lang.Character): String = String.valueOf(a)

def invert(b: String): java.lang.Character = {
require(b.length == 1)
b.charAt(0)
}
}
private val charAsString =
codec[java.lang.Character, String](String.valueOf(_),
out => {
require(out.length == 1)
out.charAt(0)
})

override def jvmRepr: DataType =
FramelessInternals.objectTypeFor[java.lang.Character]
Expand All @@ -104,7 +150,7 @@ object TypedEncoder {
TransformingEncoder[java.lang.Character, String](
implicitly[ClassTag[java.lang.Character]],
StringEncoder,
InjectionCodecs.wrap(charAsString)).asInstanceOf[AgnosticEncoder[Char]] // same types but code gen needs exact
charAsString).asInstanceOf[AgnosticEncoder[Char]] // same types but code gen needs exact

override def toString: String = s"CharEncoder"
}
Expand Down Expand Up @@ -154,18 +200,11 @@ object TypedEncoder {
/**
 * Encoder for [[SQLDate]], transformed to and from the `Int` held in its `days` field.
 */
implicit val sqlDate: TypedEncoder[SQLDate] = new TypedEncoder[SQLDate] {
  override def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLDate]

  // The SQLDate <-> Int conversion is passed as a codec provider built from the two functions;
  // no intermediate Injection is declared or wrapped.
  override def agnosticEncoder: AgnosticEncoder[SQLDate] =
    TransformingEncoder[SQLDate, Int](
      classTag,
      PrimitiveIntEncoder,
      codec(_.days, SQLDate))

  override def toString: String = s"SQLDateEncoder"
}
Expand All @@ -182,18 +221,11 @@ object TypedEncoder {

override def jvmRepr: DataType = ScalaReflection.dataTypeFor[Date]

val dateAsInstant: Injection[Date, Instant] =
new Injection[Date, Instant] {
def apply(a: Date): Instant = a.toInstant

def invert(b: Instant): Date = Date.from(b)
}

override def agnosticEncoder: AgnosticEncoder[Date] =
TransformingEncoder[Date, Instant](
classTag,
STRICT_INSTANT_ENCODER,
InjectionCodecs.wrap(dateAsInstant))
codec(_.toInstant, Date.from))

override def toString: String = s"DateEncoder"
}
Expand All @@ -209,18 +241,15 @@ object TypedEncoder {
new TypedEncoder[SQLTimestamp] {
override def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLTimestamp]

val sqlTimestampAsLong: Injection[SQLTimestamp, java.sql.Timestamp] =
new Injection[SQLTimestamp, java.sql.Timestamp] with Serializable{
def apply(a: SQLTimestamp): java.sql.Timestamp = Timestamp.from(microsToInstant(a.us))

def invert(b: java.sql.Timestamp): SQLTimestamp = SQLTimestamp(instantToMicros(b.toInstant))
}
private val sqlTimestampAsLong =
codec[SQLTimestamp, java.sql.Timestamp](in => Timestamp.from(microsToInstant(in.us)),
out => SQLTimestamp(instantToMicros(out.toInstant)))

override def agnosticEncoder: AgnosticEncoder[SQLTimestamp] =
TransformingEncoder[SQLTimestamp, java.sql.Timestamp](
classTag,
TimestampEncoder(true),
InjectionCodecs.wrap(sqlTimestampAsLong)
sqlTimestampAsLong
)

override def toString: String = s"SQLTimestampEncoder"
Expand All @@ -241,18 +270,18 @@ object TypedEncoder {
* * https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala#L1075-L1087
*/
// DayTimeIntervalType
// Duration is carried as whole milliseconds (toMillis / ofMillis).
implicit val timeDurationInjection: () => Codec[Duration, Long] =
  codec(_.toMillis, Duration.ofMillis)

// YearMonthIntervalType
// NOTE(review): only the days component of the Period is carried (getDays / ofDays); the years
// and months components are not round-tripped - confirm this is intended.
implicit val timePeriodInjection: () => Codec[Period, Int] =
  codec(_.getDays, Period.ofDays)

// Both encoders are resolved through the implicit codec providers declared above.
implicit val timePeriodEncoder: TypedEncoder[Period] =
  TypedEncoder.usingCodec

implicit val timeDurationEncoder: TypedEncoder[Duration] =
  TypedEncoder.usingCodec

implicit def arrayEncoder[T: ClassTag](
implicit
Expand Down Expand Up @@ -373,12 +402,7 @@ object TypedEncoder {
encodeT.agnosticEncoder,
encodeT.nullable,
lenientSerialization = false),
() => new Codec[C[T], O[T]] {

override def decode(in: O[T]): C[T] = i2.convert(in)

override def encode(out: C[T]): O[T] = i2.reverse(out)
}
codec(i2.reverse, i2.convert)
)

override def toString: String = s"CollectionEncoder[$jvmRepr]"
Expand Down Expand Up @@ -434,20 +458,25 @@ object TypedEncoder {
implicit
inj: Injection[A, B],
trb: TypedEncoder[B]
): TypedEncoder[A] =
new TypedEncoder[A] {
override def nullable: Boolean = trb.nullable
override def jvmRepr: DataType = FramelessInternals.objectTypeFor[A](classTag)
override def catalystRepr: DataType = trb.catalystRepr
): TypedEncoder[A] = {
implicit val wrapped: () => Codec[A, B] = InjectionCodecs.wrap(inj)
usingConversion("InjectionEncoder")
}

override def agnosticEncoder: AgnosticEncoder[A] =
TransformingEncoder[A, B](
classTag,
trb.agnosticEncoder,
InjectionCodecs.wrap(inj))
/**
 * Builds a [[TypedInjection]] labelled `encoderName` from the implicitly available codec
 * provider and the encoder of the target type `B`.
 */
private[frameless] def usingConversion[A: ClassTag, B](encoderName: String)(
    implicit
    inj: () => Codec[A, B],
    trb: TypedEncoder[B]
  ): TypedInjection[A, B] = {
  val tagOfA = implicitly[ClassTag[A]]
  new TypedInjection[A, B](encoderName)(tagOfA, inj, trb) {}
}

override def toString: String = s"InjectionEncoder"
}
/**
 * Provides a [[TypedEncoder]] for `A` whenever an implicit `() => Codec[A, B]` and an encoder
 * for `B` are available; the resulting encoder reports itself as "CodecEncoder".
 */
implicit def usingCodec[A: ClassTag, B](
    implicit
    inj: () => Codec[A, B],
    trb: TypedEncoder[B]
  ): TypedEncoder[A] = {
  val viaCodec: TypedInjection[A, B] = usingConversion[A, B]("CodecEncoder")
  viaCodec
}

/** Encodes things as records if there is no Injection defined */
implicit def usingDerivation[F, G <: HList, H <: HList](
Expand Down
Loading

0 comments on commit 8ae0940

Please sign in to comment.