diff --git a/build.sbt b/build.sbt index 9ef6c7a8..df888f9b 100644 --- a/build.sbt +++ b/build.sbt @@ -1,6 +1,6 @@ val sparkVersion = // "3.5.1" // - "4.0.0-preview2" // must have the apache_snaps configured + "4.1.0-SNAPSHOT" // must have the apache_snaps configured 4.1.0-SNAPSHOT val spark34Version = "3.4.2" val spark33Version = "3.3.4" val catsCoreVersion = "2.10.0" @@ -13,7 +13,7 @@ val scalacheck = "1.17.0" val scalacheckEffect = "1.0.4" val refinedVersion = "0.11.1" val nakedFSVersion = "0.1.0" -val shimVersion = "0.0.1-RC5-SNAPSHOT" +val shimVersion = "0.0.2-SNAPSHOT" val Scala212 = "2.12.19" val Scala213 = "2.13.13" diff --git a/core/src/main/scala/frameless/Injection.scala b/core/src/main/scala/frameless/Injection.scala index cf9ecb33..3cee7ae7 100644 --- a/core/src/main/scala/frameless/Injection.scala +++ b/core/src/main/scala/frameless/Injection.scala @@ -16,4 +16,5 @@ object Injection { def apply(a: A): B = f(a) def invert(b: B): A = g(b) } + } diff --git a/dataset/src/main/scala/frameless/FramelessInternals.scala b/dataset/src/main/scala/frameless/FramelessInternals.scala index ec4c1bd6..38088b7d 100644 --- a/dataset/src/main/scala/frameless/FramelessInternals.scala +++ b/dataset/src/main/scala/frameless/FramelessInternals.scala @@ -67,7 +67,7 @@ object FramelessInternals { plan: LogicalPlan, encoder: Encoder[T] ): Dataset[T] = - new Dataset(sqlContext, plan, encoder) + new classic.Dataset(sqlContext, plan, encoder) def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = shimUtils.ofRows(sparkSession, logicalPlan) diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 574ce427..9febbf6a 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -1,17 +1,11 @@ package frameless -import com.sparkutils.shim.expressions.{ - CreateNamedStruct1 => CreateNamedStruct, - GetStructField3 => GetStructField, - UnwrapOption2 => UnwrapOption, - WrapOption2 => WrapOption -} -import com.sparkutils.shim.{ deriveUnitLiteral, ifIsNull } -import org.apache.spark.sql.catalyst.expressions.{ Expression, Literal } -import org.apache.spark.sql.shim.{ - Invoke5 => Invoke, - NewInstance4 => NewInstance -} +import com.sparkutils.shim.expressions.{CreateNamedStruct1 => CreateNamedStruct, GetStructField3 => GetStructField, UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption} +import com.sparkutils.shim.{deriveUnitLiteral, ifIsNull} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{EncoderField, ProductEncoder} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance} import org.apache.spark.sql.types._ import shapeless._ import shapeless.labelled.FieldType @@ -37,7 +31,7 @@ object RecordEncoderFields { implicit def deriveRecordLast[K <: Symbol, H]( implicit key: Witness.Aux[K], - head: RecordFieldEncoder[H] + head: TypedEncoder[H] ): RecordEncoderFields[FieldType[K, H] :: HNil] = new RecordEncoderFields[FieldType[K, H] :: HNil] { def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil @@ -46,7 +40,7 @@ object RecordEncoderFields { implicit def deriveRecordCons[K <: Symbol, H, T <: HList]( implicit key: Witness.Aux[K], - head: RecordFieldEncoder[H], + head: TypedEncoder[H], tail: RecordEncoderFields[T] ): RecordEncoderFields[FieldType[K, H] :: T] = new RecordEncoderFields[FieldType[K, H] :: T] { @@ -60,8 +54,8 @@ object RecordEncoderFields { private def fieldEncoder[K <: Symbol, H]( implicit key: Witness.Aux[K], - e: RecordFieldEncoder[H] - ): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) + e: TypedEncoder[H] + ): RecordEncoderField = RecordEncoderField(0, key.value.name, e) } /** @@ -154,7 +148,19 @@ class RecordEncoder[F, G <: HList, H <: HList]( newInstanceExprs: Lazy[NewInstanceExprs[G]], classTag: ClassTag[F]) extends TypedEncoder[F] { - def nullable: Boolean = false + + override def agnosticEncoder: AgnosticEncoder[F] = + ProductEncoder[F]( + classTag, + fields.value.value.map(f => EncoderField( + f.name, + f.encoder.agnosticEncoder, + f.encoder.nullable, + Metadata.empty) ), + None) + + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] + /*def nullable: Boolean = false def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] @@ -201,17 +207,11 @@ class RecordEncoder[F, G <: HList, H <: HList]( NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) ifIsNull(jvmRepr, path, newExpr) - } -} + }*/ -final class RecordFieldEncoder[T]( - val encoder: TypedEncoder[T], - private[frameless] val jvmRepr: DataType, - private[frameless] val fromCatalyst: Expression => Expression, - private[frameless] val toCatalyst: Expression => Expression) - extends Serializable +} -object RecordFieldEncoder extends RecordFieldEncoderLowPriority { +object RecordFieldEncoder /*extends RecordFieldEncoderLowPriority */{ /** * @tparam F the value class @@ -235,7 +235,9 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i4: IsHCons.Aux[KS, K, HNil], i5: TypedEncoder[V], i6: ClassTag[F] - ): RecordFieldEncoder[Option[F]] = { + ): TypedEncoder[Option[F]] = { + TypedEncoder.optionEncoder(valueClass) + /* val fieldName = i4.head(i3()).name val innerJvmRepr = ObjectType(i6.runtimeClass) @@ -278,7 +280,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { jvmRepr = jvmr, fromCatalyst = fromCatalyst, toCatalyst = catalyst - ) + ) */ } /** @@ -302,8 +304,17 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i4: IsHCons.Aux[KS, K, HNil], i5: TypedEncoder[V], i6: ClassTag[F] - ): RecordFieldEncoder[F] = { - val cls = i6.runtimeClass + ): TypedEncoder[F] = new TypedEncoder[F]() { + override def agnosticEncoder: AgnosticEncoder[F] = ProductEncoder[F]( + i6, + Seq(EncoderField( + i4.head(i3()).name, + i5.agnosticEncoder, + i5.nullable, + Metadata.empty)), + None) + + /* { val cls = i6.runtimeClass val jvmr = i5.jvmRepr val fieldName = i4.head(i3()).name @@ -335,10 +346,10 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { toCatalyst = { expr: Expression => i5.toCatalyst(Invoke(expr, fieldName, jvmr)) } - ) + )*/ } } - +/* private[frameless] sealed trait RecordFieldEncoderLowPriority { implicit def apply[T]( @@ -347,3 +358,4 @@ private[frameless] sealed trait RecordFieldEncoderLowPriority { ): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst) } +*/ \ No newline at end of file diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index a4267fc5..95c4864d 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -9,6 +9,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.ShimUtils.column +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.types.StructType import shapeless._ import shapeless.labelled.FieldType @@ -756,8 +757,8 @@ class TypedDataset[T] protected[frameless] ( e: TypedEncoder[(T, U)] ): TypedDataset[(T, U)] = new TypedDataset( - ShimUtils.joinWith(dataset, other.dataset, column(Literal(true)), "cross")(TypedExpressionEncoder[(T, U)]) - //self.dataset.joinWith(other.dataset, column(Literal(true)), "cross") + //ShimUtils.joinWith(dataset, other.dataset, column(Literal(true)), "cross")(TypedExpressionEncoder[(T, U)]) + self.dataset.joinWith(other.dataset, column(Literal(true)), "cross") ) /** @@ -772,13 +773,13 @@ class TypedDataset[T] protected[frameless] ( to: TypedEncoder[(T, U)] ): TypedDataset[(Option[T], Option[U])] = new TypedDataset( - ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "full")(TypedExpressionEncoder[(T, U)]) - .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])]) - /*self.dataset + //ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "full")(TypedExpressionEncoder[(T, U)]) + // .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])]) + self.dataset .joinWith(other.dataset, condition.untyped, "full") .as[(Option[T], Option[U])]( TypedExpressionEncoder[(Option[T], Option[U])] - )*/ + ) ) /** @@ -817,11 +818,11 @@ class TypedDataset[T] protected[frameless] ( to: TypedEncoder[(T, U)] ): TypedDataset[(T, Option[U])] = new TypedDataset( - ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "left_outer")(TypedExpressionEncoder[(T, U)]) - .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) - /*self.dataset + //ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "left_outer")(TypedExpressionEncoder[(T, U)]) + // .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) + self.dataset .joinWith(other.dataset, condition.untyped, "left_outer") - .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])*/ + .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) ) /** diff --git a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala index d417caf8..8124b119 100644 --- a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala +++ b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala @@ -1,8 +1,8 @@ package frameless import java.util - import org.apache.spark.rdd.RDD +import org.apache.spark.sql.classic.ClassicConversions.castToImpl import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.StructType diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala index 235060cf..78d93e97 100644 --- a/dataset/src/main/scala/frameless/TypedEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedEncoder.scala @@ -2,58 +2,49 @@ package frameless import java.math.BigInteger import java.util.Date -import java.time.{ Duration, Instant, LocalDate, Period } +import java.time.{Duration, Instant, LocalDate, Period} import java.sql.Timestamp import scala.reflect.ClassTag import FramelessInternals.UserDefinedType -import org.apache.spark.sql.catalyst.expressions.{ - Expression, - UnsafeArrayData, - Literal -} - -import org.apache.spark.sql.catalyst.util.{ - ArrayBasedMapData, - DateTimeUtils, - GenericArrayData -} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import shapeless._ import shapeless.ops.hlist.IsHCons -import com.sparkutils.shim.expressions.{ - UnwrapOption2 => UnwrapOption, - WrapOption2 => WrapOption, - MapObjects5 => MapObjects, - ExternalMapToCatalyst7 => ExternalMapToCatalyst -} -import frameless.{ reflection => ScalaReflection } -import org.apache.spark.sql.shim.{ - StaticInvoke4 => StaticInvoke, - NewInstance4 => NewInstance, - Invoke5 => Invoke -} +import com.sparkutils.shim.expressions.{ExternalMapToCatalyst7 => ExternalMapToCatalyst, MapObjects5 => MapObjects, UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption} +import frameless.{reflection => ScalaReflection} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, DEFAULT_SCALA_DECIMAL_ENCODER, IterableEncoder, JavaBigIntEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_TIMESTAMP_ENCODER, ScalaBigIntEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder} +import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance, StaticInvoke4 => StaticInvoke} -import scala.collection.immutable.{ ListSet, TreeSet } +import java.sql +import scala.collection.immutable.{ListSet, TreeSet} abstract class TypedEncoder[T]( implicit val classTag: ClassTag[T]) extends Serializable { - def nullable: Boolean + def nullable: Boolean = agnosticEncoder.nullable - def jvmRepr: DataType - def catalystRepr: DataType + def jvmRepr: DataType = agnosticEncoder.dataType + def catalystRepr: DataType = agnosticEncoder.dataType /** - * From Catalyst representation to T + * Create the underlying AgnosticEncoder */ - def fromCatalyst(path: Expression): Expression + def agnosticEncoder: AgnosticEncoder[T] +} - /** - * T to Catalyst representation - */ - def toCatalyst(path: Expression): Expression +object InjectionCodecs { + + def wrap[A, B](injection: Injection[A,B]): () => Codec[A, B] = + () => + new Codec[A, B] { + override def encode(in: A): B = injection.apply(in) + + override def decode(out: B): A = injection.invert(out) + } } // Waiting on scala 2.12 @@ -70,358 +61,155 @@ object TypedEncoder { def apply[T: TypedEncoder]: TypedEncoder[T] = implicitly[TypedEncoder[T]] implicit val stringEncoder: TypedEncoder[String] = new TypedEncoder[String] { - def nullable: Boolean = false - - def jvmRepr: DataType = FramelessInternals.objectTypeFor[String] - def catalystRepr: DataType = StringType - - def toCatalyst(path: Expression): Expression = - StaticInvoke(classOf[UTF8String], catalystRepr, "fromString", path :: Nil) - - def fromCatalyst(path: Expression): Expression = - Invoke(path, "toString", jvmRepr) - - override val toString = "stringEncoder" + override def agnosticEncoder: AgnosticEncoder[String] = StringEncoder + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[String] } implicit val booleanEncoder: TypedEncoder[Boolean] = new TypedEncoder[Boolean] { - def nullable: Boolean = false - - def jvmRepr: DataType = BooleanType - def catalystRepr: DataType = BooleanType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Boolean] = PrimitiveBooleanEncoder } implicit val intEncoder: TypedEncoder[Int] = new TypedEncoder[Int] { - def nullable: Boolean = false - - def jvmRepr: DataType = IntegerType - def catalystRepr: DataType = IntegerType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path - - override def toString = "intEncoder" + override def agnosticEncoder: AgnosticEncoder[Int] = PrimitiveIntEncoder } implicit val longEncoder: TypedEncoder[Long] = new TypedEncoder[Long] { - def nullable: Boolean = false - - def jvmRepr: DataType = LongType - def catalystRepr: DataType = LongType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Long] = PrimitiveLongEncoder } implicit val shortEncoder: TypedEncoder[Short] = new TypedEncoder[Short] { - def nullable: Boolean = false - - def jvmRepr: DataType = ShortType - def catalystRepr: DataType = ShortType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Short] = PrimitiveShortEncoder } implicit val charEncoder: TypedEncoder[Char] = new TypedEncoder[Char] { - // tricky because while Char is primitive type, Spark doesn't support it - implicit val charAsString: Injection[java.lang.Character, String] = - new Injection[java.lang.Character, String] { - def apply(a: java.lang.Character): String = String.valueOf(a) + val charAsString: Injection[Char, String] = + new Injection[Char, String] { + def apply(a: Char): String = String.valueOf(a) - def invert(b: String): java.lang.Character = { + def invert(b: String): Char = { require(b.length == 1) b.charAt(0) } } - val underlying = usingInjection[java.lang.Character, String] - - def nullable: Boolean = false - - // this line fixes underlying encoder - def jvmRepr: DataType = + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[java.lang.Character] - def catalystRepr: DataType = StringType - - def toCatalyst(path: Expression): Expression = underlying.toCatalyst(path) - - def fromCatalyst(path: Expression): Expression = - underlying.fromCatalyst(path) + override def agnosticEncoder: AgnosticEncoder[Char] = + TransformingEncoder[Char, String]( + classTag, + StringEncoder, + InjectionCodecs.wrap(charAsString)) } implicit val byteEncoder: TypedEncoder[Byte] = new TypedEncoder[Byte] { - def nullable: Boolean = false - - def jvmRepr: DataType = ByteType - def catalystRepr: DataType = ByteType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Byte] = PrimitiveByteEncoder } implicit val floatEncoder: TypedEncoder[Float] = new TypedEncoder[Float] { - def nullable: Boolean = false - - def jvmRepr: DataType = FloatType - def catalystRepr: DataType = FloatType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Float] = PrimitiveFloatEncoder } implicit val doubleEncoder: TypedEncoder[Double] = new TypedEncoder[Double] { - def nullable: Boolean = false - - def jvmRepr: DataType = DoubleType - def catalystRepr: DataType = DoubleType - - def toCatalyst(path: Expression): Expression = path - def fromCatalyst(path: Expression): Expression = path + override def agnosticEncoder: AgnosticEncoder[Double] = PrimitiveDoubleEncoder } implicit val bigDecimalEncoder: TypedEncoder[BigDecimal] = new TypedEncoder[BigDecimal] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigDecimal] - def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - path :: Nil - ) - - def fromCatalyst(path: Expression): Expression = - Invoke(path, "toBigDecimal", jvmRepr) - - override def toString: String = "bigDecimalEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigDecimal] + override def agnosticEncoder: AgnosticEncoder[BigDecimal] = DEFAULT_SCALA_DECIMAL_ENCODER } implicit val javaBigDecimalEncoder: TypedEncoder[java.math.BigDecimal] = new TypedEncoder[java.math.BigDecimal] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.math.BigDecimal] - def catalystRepr: DataType = DecimalType.SYSTEM_DEFAULT - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - path :: Nil - ) - - def fromCatalyst(path: Expression): Expression = - Invoke(path, "toJavaBigDecimal", jvmRepr) - - override def toString: String = "javaBigDecimalEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.math.BigDecimal] + override def agnosticEncoder: AgnosticEncoder[java.math.BigDecimal] = DEFAULT_JAVA_DECIMAL_ENCODER } implicit val bigIntEncoder: TypedEncoder[BigInt] = new TypedEncoder[BigInt] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigInt] - def catalystRepr: DataType = DecimalType(DecimalType.MAX_PRECISION, 0) - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - Decimal.getClass, - catalystRepr, - "apply", - path :: Nil - ) - - def fromCatalyst(path: Expression): Expression = - Invoke(path, "toScalaBigInt", jvmRepr) - - override def toString: String = "bigIntEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigInt] + override def agnosticEncoder: AgnosticEncoder[BigInt] = ScalaBigIntEncoder } implicit val javaBigIntEncoder: TypedEncoder[BigInteger] = new TypedEncoder[BigInteger] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigInteger] - def catalystRepr: DataType = DecimalType(DecimalType.MAX_PRECISION, 0) - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - Decimal.getClass, - catalystRepr, - "apply", - path :: Nil - ) - - def fromCatalyst(path: Expression): Expression = - Invoke(path, "toJavaBigInteger", jvmRepr) - - override def toString: String = "javaBigIntEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[BigInteger] + override def agnosticEncoder: AgnosticEncoder[BigInteger] = JavaBigIntEncoder } implicit val sqlDate: TypedEncoder[SQLDate] = new TypedEncoder[SQLDate] { - def nullable: Boolean = false + // No direct equivalent of invoke <-> staticinvoke pairs but injection works + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLDate] - def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLDate] - def catalystRepr: DataType = DateType + val sqlDateAsDate: Injection[SQLDate, Int] = + new Injection[SQLDate, Int] { + def apply(a: SQLDate): Int = a.days - def toCatalyst(path: Expression): Expression = - Invoke(path, "days", DateType) + def invert(b: Int): SQLDate = SQLDate(b) + } - def fromCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = SQLDate.getClass, - dataType = jvmRepr, - functionName = "apply", - arguments = path :: Nil, - propagateNull = true - ) + override def agnosticEncoder: AgnosticEncoder[SQLDate] = + TransformingEncoder[SQLDate, Int]( + classTag, + PrimitiveIntEncoder, + InjectionCodecs.wrap(sqlDateAsDate)) } implicit val timestampEncoder: TypedEncoder[Timestamp] = new TypedEncoder[Timestamp] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[Timestamp] - def catalystRepr: DataType = TimestampType - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - path :: Nil, - returnNullable = false - ) - - def fromCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = DateTimeUtils.getClass, - dataType = jvmRepr, - functionName = "toJavaTimestamp", - arguments = path :: Nil, - propagateNull = true - ) - - override def toString: String = "timestampEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[Timestamp] + override def agnosticEncoder: AgnosticEncoder[Timestamp] = STRICT_TIMESTAMP_ENCODER } implicit val dateEncoder: TypedEncoder[Date] = new TypedEncoder[Date] { - def nullable: Boolean = false + // No direct equivalent of invoke <-> staticinvoke pairs but injection works - def jvmRepr: DataType = ScalaReflection.dataTypeFor[Date] - def catalystRepr: DataType = TimestampType + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[Date] - private val instantRepr = ScalaReflection.dataTypeFor[Instant] + val dateAsInstant: Injection[Date, Instant] = + new Injection[Date, Instant] { + def apply(a: Date): Instant = a.toInstant - def toCatalyst(path: Expression): Expression = - timeInstant.toCatalyst(Invoke(path, "toInstant", instantRepr)) - - def fromCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = classOf[Date], - dataType = jvmRepr, - functionName = "from", - arguments = timeInstant.fromCatalyst(path) :: Nil, - propagateNull = true - ) + def invert(b: Instant): Date = Date.from(b) + } - override def toString: String = "dateEncoder" + override def agnosticEncoder: AgnosticEncoder[Date] = + TransformingEncoder[Date, Instant]( + classTag, + STRICT_INSTANT_ENCODER, + InjectionCodecs.wrap(dateAsInstant)) } implicit val sqlDateEncoder: TypedEncoder[java.sql.Date] = new TypedEncoder[java.sql.Date] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.sql.Date] - def catalystRepr: DataType = DateType - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = DateTimeUtils.getClass, - dataType = catalystRepr, - functionName = "fromJavaDate", - arguments = path :: Nil, - propagateNull = true - ) - - private val localDateRepr = ScalaReflection.dataTypeFor[LocalDate] - - def fromCatalyst(path: Expression): Expression = { - val toLocalDate = StaticInvoke( - staticObject = DateTimeUtils.getClass, - dataType = localDateRepr, - functionName = "daysToLocalDate", - arguments = path :: Nil, - propagateNull = true - ) - - StaticInvoke( - staticObject = classOf[java.sql.Date], - dataType = jvmRepr, - functionName = "valueOf", - arguments = toLocalDate :: Nil, - propagateNull = true - ) - } - - override def toString: String = "sqlDateEncoder" + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.sql.Date] + override def agnosticEncoder: AgnosticEncoder[sql.Date] = STRICT_DATE_ENCODER } implicit val sqlTimestamp: TypedEncoder[SQLTimestamp] = new TypedEncoder[SQLTimestamp] { - def nullable: Boolean = false + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLTimestamp] - def jvmRepr: DataType = ScalaReflection.dataTypeFor[SQLTimestamp] - def catalystRepr: DataType = TimestampType + val sqlTimestampAsLong: Injection[SQLTimestamp, Long] = + new Injection[SQLTimestamp, Long] { + def apply(a: SQLTimestamp): Long = a.us - def toCatalyst(path: Expression): Expression = - Invoke(path, "us", TimestampType) + def invert(b: Long): SQLTimestamp = SQLTimestamp(b) + } - def fromCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = SQLTimestamp.getClass, - dataType = jvmRepr, - functionName = "apply", - arguments = path :: Nil, - propagateNull = true - ) + override def agnosticEncoder: AgnosticEncoder[SQLTimestamp] = + TransformingEncoder[SQLTimestamp, Long]( + classTag, + PrimitiveLongEncoder, + InjectionCodecs.wrap(sqlTimestampAsLong)) } /** java.time Encoders, Spark uses https://github.com/apache/spark/blob/v3.2.0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala for encoding / decoding. */ implicit val timeInstant: TypedEncoder[Instant] = new TypedEncoder[Instant] { - def nullable: Boolean = false - - def jvmRepr: DataType = ScalaReflection.dataTypeFor[Instant] - def catalystRepr: DataType = TimestampType - - def toCatalyst(path: Expression): Expression = - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "instantToMicros", - path :: Nil, - returnNullable = false - ) - - def fromCatalyst(path: Expression): Expression = - StaticInvoke( - staticObject = DateTimeUtils.getClass, - dataType = jvmRepr, - functionName = "microsToInstant", - arguments = path :: Nil, - propagateNull = true - ) + override def jvmRepr: DataType = ScalaReflection.dataTypeFor[Instant] + override def agnosticEncoder: AgnosticEncoder[Instant] = STRICT_INSTANT_ENCODER } /** @@ -447,68 +235,18 @@ object TypedEncoder { implicit def arrayEncoder[T: ClassTag]( implicit - i0: Lazy[RecordFieldEncoder[T]] + i0: Lazy[TypedEncoder[T]] ): TypedEncoder[Array[T]] = new TypedEncoder[Array[T]] { - private lazy val encodeT = i0.value.encoder + private lazy val encodeT = i0.value - def nullable: Boolean = false - - lazy val jvmRepr: DataType = i0.value.jvmRepr match { + override def jvmRepr: DataType = encodeT.jvmRepr match { case ByteType => BinaryType case _ => FramelessInternals.objectTypeFor[Array[T]] } - lazy val catalystRepr: DataType = i0.value.jvmRepr match { - case ByteType => BinaryType - case _ => ArrayType(encodeT.catalystRepr, encodeT.nullable) - } - - def toCatalyst(path: Expression): Expression = { - val enc = i0.value - - enc.jvmRepr match { - case IntegerType | LongType | DoubleType | FloatType | ShortType | - BooleanType => - StaticInvoke( - classOf[UnsafeArrayData], - catalystRepr, - "fromPrimitiveArray", - path :: Nil - ) - - case ByteType => path - - case _ => - MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable) - } - } - - def fromCatalyst(path: Expression): Expression = - encodeT.jvmRepr match { - case IntegerType => Invoke(path, "toIntArray", jvmRepr) - case LongType => Invoke(path, "toLongArray", jvmRepr) - case DoubleType => Invoke(path, "toDoubleArray", jvmRepr) - case FloatType => Invoke(path, "toFloatArray", jvmRepr) - case ShortType => Invoke(path, "toShortArray", jvmRepr) - case BooleanType => Invoke(path, "toBooleanArray", jvmRepr) - - case ByteType => path - - case _ => - Invoke( - MapObjects( - i0.value.fromCatalyst, - path, - encodeT.catalystRepr, - encodeT.nullable - ), - "array", - jvmRepr - ) - } - - override def toString: String = s"arrayEncoder($jvmRepr)" + override def agnosticEncoder: AgnosticEncoder[Array[T]] = + ArrayEncoder(encodeT.agnosticEncoder, containsNull = false) } /** @@ -564,30 +302,29 @@ object TypedEncoder { implicit def seqEncoder[C[X] <: Seq[X], T]( implicit - i0: Lazy[RecordFieldEncoder[T]], + i0: Lazy[TypedEncoder[T]], i1: ClassTag[C[T]], i2: CollectionConversion[Seq, C, T] ) = collectionEncoder[Seq, C, T] implicit def setEncoder[C[X] <: Set[X], T]( implicit - i0: Lazy[RecordFieldEncoder[T]], + i0: Lazy[TypedEncoder[T]], i1: ClassTag[C[T]], i2: CollectionConversion[Set, C, T] ) = collectionEncoder[Set, C, T] def collectionEncoder[O[_], C[X], T]( implicit - i0: Lazy[RecordFieldEncoder[T]], + i0: Lazy[TypedEncoder[T]], i1: ClassTag[C[T]], i2: CollectionConversion[O, C, T] ): TypedEncoder[C[T]] = new TypedEncoder[C[T]] { - private lazy val encodeT = i0.value.encoder - - def nullable: Boolean = false + private lazy val encodeT = i0.value - def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](i1) + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[C[T]](i1) +/* def catalystRepr: DataType = ArrayType(encodeT.catalystRepr, encodeT.nullable) @@ -619,7 +356,17 @@ object TypedEncoder { implicitly[CollectionConversion[O, C, T]] ) // This will convert Seq to the appropriate C[_] when eval'ing. - override def toString: String = s"collectionEncoder($jvmRepr)" + override def toString: String = s"collectionEncoder($jvmRepr)" */ + + /** + * Create the underlying AgnosticEncoder + */ + override def agnosticEncoder: AgnosticEncoder[C[T]] = + IterableEncoder( + ClassTag(i1.runtimeClass), + encodeT.agnosticEncoder, + containsNull = false, + lenientSerialization = false).asInstanceOf[AgnosticEncoder[C[T]]] // only C is provided } /** @@ -630,16 +377,21 @@ object TypedEncoder { */ implicit def mapEncoder[A: NotCatalystNullable, B]( implicit - i0: Lazy[RecordFieldEncoder[A]], - i1: Lazy[RecordFieldEncoder[B]] + i0: Lazy[TypedEncoder[A]], + i1: Lazy[TypedEncoder[B]] ): TypedEncoder[Map[A, B]] = new TypedEncoder[Map[A, B]] { - def nullable: Boolean = false - - def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]] - - private lazy val encodeA = i0.value.encoder - private lazy val encodeB = i1.value.encoder - + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[Map[A, B]] + + private lazy val encodeA = i0.value + private lazy val encodeB = i1.value + + override def agnosticEncoder: AgnosticEncoder[Map[A, B]] = + MapEncoder( + classTag, + encodeA.agnosticEncoder, + encodeB.agnosticEncoder, + valueContainsNull = false) +/* lazy val catalystRepr: DataType = MapType(encodeA.catalystRepr, encodeB.catalystRepr, encodeB.nullable) @@ -691,7 +443,9 @@ object TypedEncoder { ) } - override def toString = s"mapEncoder($jvmRepr)" + override def toString = s"mapEncoder($jvmRepr)"*/ + + } implicit def optionEncoder[A]( @@ -699,11 +453,10 @@ object TypedEncoder { underlying: TypedEncoder[A] ): TypedEncoder[Option[A]] = new TypedEncoder[Option[A]] { - def nullable: Boolean = true - def jvmRepr: DataType = + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[Option[A]](classTag) - +/* def catalystRepr: DataType = underlying.catalystRepr def toCatalyst(path: Expression): Expression = { @@ -770,7 +523,12 @@ object TypedEncoder { } def fromCatalyst(path: Expression): Expression = - WrapOption(underlying.fromCatalyst(path), underlying.jvmRepr) + WrapOption(underlying.fromCatalyst(path), underlying.jvmRepr)*/ + + /** + * Create the underlying AgnosticEncoder + */ + override def agnosticEncoder: AgnosticEncoder[Option[A]] = OptionEncoder(underlying.agnosticEncoder) } /** Encodes things using injection if there is one defined */ @@ -780,19 +538,13 @@ object TypedEncoder { trb: TypedEncoder[B] ): TypedEncoder[A] = new TypedEncoder[A] { - def nullable: Boolean = trb.nullable - def jvmRepr: DataType = FramelessInternals.objectTypeFor[A](classTag) - def catalystRepr: DataType = trb.catalystRepr - - def fromCatalyst(path: Expression): Expression = { - val bexpr = trb.fromCatalyst(path) - Invoke(Literal.fromObject(inj), "invert", jvmRepr, Seq(bexpr)) - } + override def jvmRepr: DataType = FramelessInternals.objectTypeFor[A](classTag) - def toCatalyst(path: Expression): Expression = - trb.toCatalyst( - Invoke(Literal.fromObject(inj), "apply", trb.jvmRepr, Seq(path)) - ) + override def agnosticEncoder: AgnosticEncoder[A] = + TransformingEncoder[A, B]( + classTag, + trb.agnosticEncoder, + InjectionCodecs.wrap(inj)) } /** Encodes things as records if there is no Injection defined */ @@ -811,19 +563,15 @@ object TypedEncoder { A >: Null: UserDefinedType: ClassTag ]: TypedEncoder[A] = { val udt = implicitly[UserDefinedType[A]] - val udtInstance = - NewInstance(udt.getClass, Nil, dataType = ObjectType(udt.getClass)) new TypedEncoder[A] { - def nullable: Boolean = false - def jvmRepr: DataType = ObjectType(udt.userClass) - def catalystRepr: DataType = udt + override def jvmRepr: DataType = ObjectType(udt.userClass) - def toCatalyst(path: Expression): Expression = - Invoke(udtInstance, "serialize", udt, Seq(path)) - - def fromCatalyst(path: Expression): Expression = - Invoke(udtInstance, "deserialize", ObjectType(udt.userClass), Seq(path)) + /** + * Create the underlying AgnosticEncoder + */ + override def agnosticEncoder: AgnosticEncoder[A] = + UDTEncoder[A](udt, udt.getClass) } } diff --git a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala index 62d06a80..26b29701 100644 --- a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala @@ -1,6 +1,7 @@ package frameless import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.types.StructType object TypedExpressionEncoder { @@ -18,15 +19,6 @@ object TypedExpressionEncoder { def apply[T]( implicit encoder: TypedEncoder[T] - ): Encoder[T] = { - import encoder._ - org.apache.spark.sql.ShimUtils.expressionEncoder[T]( - jvmRepr, - nullable, - toCatalyst, - catalystRepr, - fromCatalyst - ) - } + ): AgnosticEncoder[T] = encoder.agnosticEncoder } diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala index 396b7ff4..3ee42f45 100644 --- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala @@ -346,7 +346,7 @@ trait NonAggregateFunctions { r.typed( sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped) ) - +/* def atan2[B, T]( l: Double, r: TypedColumn[T, B] @@ -378,7 +378,7 @@ trait NonAggregateFunctions { i0: CatalystCast[A, Double] ): TypedAggregate[T, Double] = atan2(l, l.lit(r)) - +*/ /** * Non-Aggregate function: returns the square root value of a numeric column. * diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index c34e8561..380d058d 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -1,14 +1,11 @@ package frameless package functions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{ - Expression, - LeafExpression, - NonSQLExpression -} +import org.apache.spark.sql.catalyst.{InternalRow, SerializerBuildHelper} +import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters.{createToCatalystConverter, isPrimitive} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.types.DataType import shapeless.syntax.std.tuple._ @@ -155,7 +152,7 @@ case class FramelessUdf[T, R]( override def toString: String = s"FramelessUdf(${children.mkString(", ")})" lazy val typedEnc = - TypedExpressionEncoder[R](rencoder).asInstanceOf[ExpressionEncoder[R]] + ExpressionEncoder(TypedExpressionEncoder[R](rencoder)) lazy val isSerializedAsStructForTopLevel = typedEnc.isSerializedAsStructForTopLevel @@ -177,9 +174,23 @@ case class FramelessUdf[T, R]( retval } + private def catalystConverter: Any => Any = { + val toRow = typedEnc.createSerializer().asInstanceOf[Any => Any] + if (isSerializedAsStructForTopLevel) { + value: Any => + if (value == null) null else toRow(value).asInstanceOf[InternalRow] + } else { + value: Any => + if (value == null) null else toRow(value).asInstanceOf[InternalRow].get(0, dataType) + } + } + def dataType: DataType = rencoder.catalystRepr override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val retConverter = catalystConverter + val retConverterTerm = ctx.addReferenceObj("retConverter", retConverter, classOf[Any => Any].getName) + ctx.references += this // save reference to `function` field from `FramelessUdf` to call it later @@ -211,9 +222,21 @@ case class FramelessUdf[T, R]( val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr) val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal")) - val internalNullTerm = - ctx.addMutableState("boolean", ctx.freshName("internalNull")) - // CTw - can't inject the term, may have to duplicate old code for parity + + val actualFuncCall = s"($internalTpe)$funcTerm.apply(${funcArguments.mkString(", ")})" + + // invocation logic taken from Spark4 ScalaUDF + val funcInvocation = + if (rencoder.agnosticEncoder.isPrimitive + // If the output is nullable, the returned value must be unwrapped from the Option + && !nullable) { + s"$internalTerm = $actualFuncCall;" + } else { + s"""$internalTerm = ($internalTpe)$retConverterTerm.apply( + $funcTerm.apply(${funcArguments.mkString(", ")}) + );""" + } + /*// CTw - can't inject the term, may have to duplicate old code for parity val internalExpr = Spark2_4_LambdaVariable( internalTerm, internalNullTerm, @@ -221,20 +244,19 @@ case class FramelessUdf[T, R]( true ) - val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx) - + val resultEval = typedEnc.createSerializer(). + .toCatalyst(internalExpr).genCode(ctx) +*/ ev.copy( code = code""" ${argsCode.mkString("\n")} + $funcInvocation - $internalTerm = - ($internalTpe)$funcTerm.apply(${funcArguments.mkString(", ")}); - $internalNullTerm = $internalTerm == null; - - ${resultEval.code} - """, - value = resultEval.value, - isNull = resultEval.isNull + boolean ${ev.isNull} = $internalTerm == null; + if (!${ev.isNull}) { + ${ev.value} = $internalTerm; + } + """ ) } @@ -291,7 +313,7 @@ object FramelessUdf { ): FramelessUdf[T, R] = FramelessUdf( function = function, encoders = cols.map(_.uencoder).toList, - children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList, + children = cols.map(_.expr).toList, rencoder = rencoder, evalFunction = evalFunction ) diff --git a/dataset/src/main/scala/frameless/ops/GroupByOps.scala b/dataset/src/main/scala/frameless/ops/GroupByOps.scala index 8ffaa563..bc20b5f6 100644 --- a/dataset/src/main/scala/frameless/ops/GroupByOps.scala +++ b/dataset/src/main/scala/frameless/ops/GroupByOps.scala @@ -3,19 +3,13 @@ package ops import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.{ Column, Dataset, RelationalGroupedDataset } +import org.apache.spark.sql.{Column, Dataset, RelationalGroupedDataset} import shapeless._ -import shapeless.ops.hlist.{ - Length, - Mapped, - Prepend, - ToList, - ToTraversable, - Tupler -} -import com.sparkutils.shim.expressions.{ MapGroups4 => MapGroups } +import shapeless.ops.hlist.{Length, Mapped, Prepend, ToList, ToTraversable, Tupler} +import com.sparkutils.shim.expressions.{MapGroups4 => MapGroups} import frameless.FramelessInternals import org.apache.spark.sql.ShimUtils.column +import org.apache.spark.sql.classic.ClassicConversions.castToImpl class GroupedByManyOps[T, TK <: HList, K <: HList, KT]( self: TypedDataset[T],