safety only - lits pending interface update
chris-twiner committed Jan 29, 2025
1 parent a549807 commit 8c75d4f
Showing 11 changed files with 263 additions and 493 deletions.
4 changes: 2 additions & 2 deletions build.sbt
@@ -1,6 +1,6 @@
 val sparkVersion =
   // "3.5.1" //
-  "4.0.0-preview2" // must have the apache_snaps configured
+  "4.1.0-SNAPSHOT" // must have the apache_snaps configured 4.1.0-SNAPSHOT
 val spark34Version = "3.4.2"
 val spark33Version = "3.3.4"
 val catsCoreVersion = "2.10.0"
@@ -13,7 +13,7 @@ val scalacheck = "1.17.0"
 val scalacheckEffect = "1.0.4"
 val refinedVersion = "0.11.1"
 val nakedFSVersion = "0.1.0"
-val shimVersion = "0.0.1-RC5-SNAPSHOT"
+val shimVersion = "0.0.2-SNAPSHOT"
 
 val Scala212 = "2.12.19"
 val Scala213 = "2.13.13"
1 change: 1 addition & 0 deletions core/src/main/scala/frameless/Injection.scala
@@ -16,4 +16,5 @@ object Injection {
     def apply(a: A): B = f(a)
     def invert(b: B): A = g(b)
   }
+
 }
2 changes: 1 addition & 1 deletion dataset/src/main/scala/frameless/FramelessInternals.scala
@@ -67,7 +67,7 @@ object FramelessInternals {
       plan: LogicalPlan,
       encoder: Encoder[T]
     ): Dataset[T] =
-    new Dataset(sqlContext, plan, encoder)
+    new classic.Dataset(sqlContext, plan, encoder)
 
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
     shimUtils.ofRows(sparkSession, logicalPlan)
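For context: on Spark 4.x, org.apache.spark.sql.Dataset is the API shared with Spark Connect, while org.apache.spark.sql.classic.Dataset is the concrete JVM implementation, hence the explicit classic prefix above. A rough sketch of the companion conversion this commit imports in other files; queryExecutionString is a hypothetical helper and the exact castToImpl overloads are an assumption, not something this commit shows:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.execution.QueryExecution

// Hypothetical helper: queryExecution lives on the classic Dataset, so the
// shared-API DataFrame is cast back to the classic implementation first.
def queryExecutionString(df: DataFrame): String = {
  val qe: QueryExecution = castToImpl(df).queryExecution
  qe.toString
}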
76 changes: 44 additions & 32 deletions dataset/src/main/scala/frameless/RecordEncoder.scala
@@ -1,17 +1,11 @@
 package frameless
 
-import com.sparkutils.shim.expressions.{
-  CreateNamedStruct1 => CreateNamedStruct,
-  GetStructField3 => GetStructField,
-  UnwrapOption2 => UnwrapOption,
-  WrapOption2 => WrapOption
-}
-import com.sparkutils.shim.{ deriveUnitLiteral, ifIsNull }
-import org.apache.spark.sql.catalyst.expressions.{ Expression, Literal }
-import org.apache.spark.sql.shim.{
-  Invoke5 => Invoke,
-  NewInstance4 => NewInstance
-}
+import com.sparkutils.shim.expressions.{CreateNamedStruct1 => CreateNamedStruct, GetStructField3 => GetStructField, UnwrapOption2 => UnwrapOption, WrapOption2 => WrapOption}
+import com.sparkutils.shim.{deriveUnitLiteral, ifIsNull}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{EncoderField, ProductEncoder}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.shim.{Invoke5 => Invoke, NewInstance4 => NewInstance}
 import org.apache.spark.sql.types._
 import shapeless._
 import shapeless.labelled.FieldType
@@ -37,7 +31,7 @@ object RecordEncoderFields {
   implicit def deriveRecordLast[K <: Symbol, H](
       implicit
       key: Witness.Aux[K],
-      head: RecordFieldEncoder[H]
+      head: TypedEncoder[H]
     ): RecordEncoderFields[FieldType[K, H] :: HNil] =
     new RecordEncoderFields[FieldType[K, H] :: HNil] {
       def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil
@@ -46,7 +40,7 @@
   implicit def deriveRecordCons[K <: Symbol, H, T <: HList](
       implicit
       key: Witness.Aux[K],
-      head: RecordFieldEncoder[H],
+      head: TypedEncoder[H],
       tail: RecordEncoderFields[T]
     ): RecordEncoderFields[FieldType[K, H] :: T] =
     new RecordEncoderFields[FieldType[K, H] :: T] {
@@ -60,8 +54,8 @@ object RecordEncoderFields {
   private def fieldEncoder[K <: Symbol, H](
       implicit
       key: Witness.Aux[K],
-      e: RecordFieldEncoder[H]
-    ): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder)
+      e: TypedEncoder[H]
+    ): RecordEncoderField = RecordEncoderField(0, key.value.name, e)
 }
 
 /**
@@ -154,7 +148,19 @@ class RecordEncoder[F, G <: HList, H <: HList](
     newInstanceExprs: Lazy[NewInstanceExprs[G]],
     classTag: ClassTag[F])
     extends TypedEncoder[F] {
-  def nullable: Boolean = false
+
+  override def agnosticEncoder: AgnosticEncoder[F] =
+    ProductEncoder[F](
+      classTag,
+      fields.value.value.map(f => EncoderField(
+        f.name,
+        f.encoder.agnosticEncoder,
+        f.encoder.nullable,
+        Metadata.empty)),
+      None)
+
+  override def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
+  /*def nullable: Boolean = false
   def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
@@ -201,17 +207,11 @@ class RecordEncoder[F, G <: HList, H <: HList](
       NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
     ifIsNull(jvmRepr, path, newExpr)
   }
-}
+  }*/
 
-final class RecordFieldEncoder[T](
-    val encoder: TypedEncoder[T],
-    private[frameless] val jvmRepr: DataType,
-    private[frameless] val fromCatalyst: Expression => Expression,
-    private[frameless] val toCatalyst: Expression => Expression)
-    extends Serializable
+}
 
-object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
+object RecordFieldEncoder /*extends RecordFieldEncoderLowPriority */{
 
 /**
  * @tparam F the value class
@@ -235,7 +235,9 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
       i4: IsHCons.Aux[KS, K, HNil],
       i5: TypedEncoder[V],
       i6: ClassTag[F]
-    ): RecordFieldEncoder[Option[F]] = {
+    ): TypedEncoder[Option[F]] = {
+    TypedEncoder.optionEncoder(valueClass)
+    /*
     val fieldName = i4.head(i3()).name
     val innerJvmRepr = ObjectType(i6.runtimeClass)
@@ -278,7 +280,7 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
     jvmRepr = jvmr,
     fromCatalyst = fromCatalyst,
     toCatalyst = catalyst
-  )
+  ) */
   }
 
 /**
@@ -302,8 +304,17 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
       i4: IsHCons.Aux[KS, K, HNil],
       i5: TypedEncoder[V],
       i6: ClassTag[F]
-    ): RecordFieldEncoder[F] = {
-    val cls = i6.runtimeClass
+    ): TypedEncoder[F] = new TypedEncoder[F]() {
+    override def agnosticEncoder: AgnosticEncoder[F] = ProductEncoder[F](
+      i6,
+      Seq(EncoderField(
+        i4.head(i3()).name,
+        i5.agnosticEncoder,
+        i5.nullable,
+        Metadata.empty)),
+      None)
+
+    /* { val cls = i6.runtimeClass
     val jvmr = i5.jvmRepr
     val fieldName = i4.head(i3()).name
@@ -335,10 +346,10 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
     toCatalyst = { expr: Expression =>
       i5.toCatalyst(Invoke(expr, fieldName, jvmr))
     }
-  )
+  )*/
   }
 }
 
+/*
 private[frameless] sealed trait RecordFieldEncoderLowPriority {
   implicit def apply[T](
@@ -347,3 +358,4 @@ private[frameless] sealed trait RecordFieldEncoderLowPriority {
     ): RecordFieldEncoder[T] =
     new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst)
 }
+*/
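Aside: the RecordEncoder changes above swap the hand-built Catalyst serializer expressions for Spark's AgnosticEncoder machinery. A minimal sketch of that construction, assuming a hypothetical single-field case class Name; ProductEncoder and EncoderField are used just as in the added code:

import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{EncoderField, ProductEncoder}
import org.apache.spark.sql.types.Metadata

// Hypothetical single-field record, for illustration only.
case class Name(value: String)

// Describe Name to Spark as a one-field product, delegating the field to an
// existing AgnosticEncoder; RecordEncoder's new agnosticEncoder override
// builds exactly this shape from its RecordEncoderFields.
def nameEncoder(field: AgnosticEncoder[String]): AgnosticEncoder[Name] =
  ProductEncoder[Name](
    implicitly[ClassTag[Name]],
    Seq(EncoderField("value", field, field.nullable, Metadata.empty)),
    None)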
21 changes: 11 additions & 10 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -9,6 +9,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference,
 import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.ShimUtils.column
+import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.types.StructType
 import shapeless._
 import shapeless.labelled.FieldType
@@ -756,8 +757,8 @@ class TypedDataset[T] protected[frameless] (
       e: TypedEncoder[(T, U)]
     ): TypedDataset[(T, U)] =
     new TypedDataset(
-      ShimUtils.joinWith(dataset, other.dataset, column(Literal(true)), "cross")(TypedExpressionEncoder[(T, U)])
-      //self.dataset.joinWith(other.dataset, column(Literal(true)), "cross")
+      //ShimUtils.joinWith(dataset, other.dataset, column(Literal(true)), "cross")(TypedExpressionEncoder[(T, U)])
+      self.dataset.joinWith(other.dataset, column(Literal(true)), "cross")
     )
 
 /**
@@ -772,13 +773,13 @@ class TypedDataset[T] protected[frameless] (
       to: TypedEncoder[(T, U)]
     ): TypedDataset[(Option[T], Option[U])] =
     new TypedDataset(
-      ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "full")(TypedExpressionEncoder[(T, U)])
-        .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])])
-      /*self.dataset
+      //ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "full")(TypedExpressionEncoder[(T, U)])
+      //  .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])])
+      self.dataset
         .joinWith(other.dataset, condition.untyped, "full")
         .as[(Option[T], Option[U])](
           TypedExpressionEncoder[(Option[T], Option[U])]
-        )*/
+        )
     )
 
 /**
Expand Down Expand Up @@ -817,11 +818,11 @@ class TypedDataset[T] protected[frameless] (
to: TypedEncoder[(T, U)]
): TypedDataset[(T, Option[U])] =
new TypedDataset(
ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "left_outer")(TypedExpressionEncoder[(T, U)])
.as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])
/*self.dataset
//ShimUtils.joinWith(dataset, other.dataset, condition.untyped, "left_outer")(TypedExpressionEncoder[(T, U)])
// .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])
self.dataset
.joinWith(other.dataset, condition.untyped, "left_outer")
.as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])*/
.as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])
)

/**
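Aside: the joins above drop ShimUtils.joinWith in favour of the plain Dataset.joinWith call. A minimal sketch of that underlying Spark API, with hypothetical row types Person and Addr: a "full" join pairs matching rows and leaves the unmatched side null inside the pair, which the frameless wrapper then re-reads as (Option[T], Option[U]) via TypedExpressionEncoder:

import org.apache.spark.sql.Dataset

// Hypothetical row types, for illustration only.
case class Person(id: Int, name: String)
case class Addr(id: Int, city: String)

// joinWith keeps both sides as typed values; with "full", a missing side
// surfaces as null in the pair rather than as flattened columns.
def fullOuter(people: Dataset[Person], addrs: Dataset[Addr]): Dataset[(Person, Addr)] =
  people.joinWith(addrs, people("id") === addrs("id"), "full")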
@@ -1,8 +1,8 @@
 package frameless
 
 import java.util
-
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.classic.ClassicConversions.castToImpl
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.StructType
