Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#803 - add eval udf implementation - includes / requires #804 #806

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d3ddaf1
#755 - correct version number in readme for non 3.5 build
chris-twiner Sep 29, 2023
a435adc
Merge branch 'master' of github.com:typelevel/frameless
chris-twiner Dec 27, 2023
24bde95
Merge branch 'master' of github.com:typelevel/frameless
chris-twiner Feb 20, 2024
2fa1bb0
(cherry picked from commit 955ba829779010d43b9f37ec438f0c8eaea76e0e)
chris-twiner Mar 20, 2024
ee38804
#804 - starter fix, set needed
chris-twiner Mar 20, 2024
fb1c109
#804 - encoding for Set derivatives as well - test build
chris-twiner Mar 20, 2024
ae8b69a
#804 - encoding for Set derivatives as well - test build
chris-twiner Mar 20, 2024
0435c3a
#804 - encoding for Set derivatives as well - test build
chris-twiner Mar 20, 2024
52034b2
#804 - encoding for Set derivatives as well - test build, hashtrieset…
chris-twiner Mar 20, 2024
9e45d92
#804 - encoding for Set derivatives as well - test build, hashtrieset…
chris-twiner Mar 20, 2024
e7881c0
#804 - encoding for Set derivatives as well - test build, 2.13 forced…
chris-twiner Mar 20, 2024
594fceb
#804 - encoding for Set derivatives as well - test build, 2.13 forced…
chris-twiner Mar 20, 2024
5a01976
#804 - encoding for Set derivatives as well - test build, 2.13 forced…
chris-twiner Mar 20, 2024
365b21f
#804 - encoding for Set derivatives as well - test build, 2.13 forced…
chris-twiner Mar 20, 2024
4395c16
#804 - encoding for Set derivatives as well - test build, 2.13 forced…
chris-twiner Mar 20, 2024
c792c05
Merge remote-tracking branch 'upstream/master' into temp/804_clean
chris-twiner Mar 20, 2024
f0d5f16
#804 - rebased
chris-twiner Mar 20, 2024
3bdb8ad
#803 - clean udf from #804, no shim start
chris-twiner Mar 21, 2024
c2f3492
#803 - clean udf eval needs #804
chris-twiner Mar 21, 2024
08d7c3d
#803 - clean udf eval needs #804
chris-twiner Mar 21, 2024
e157cdb
#803 - stream
chris-twiner Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,14 @@ lazy val datasetSettings =
mc("frameless.functions.FramelessLit"),
mc(f"frameless.functions.FramelessLit$$"),
dmm("frameless.functions.package.litAggr"),
dmm("org.apache.spark.sql.FramelessInternals.column")
dmm("org.apache.spark.sql.FramelessInternals.column"),
dmm("frameless.TypedEncoder.collectionEncoder"),
dmm("frameless.TypedEncoder.setEncoder"),
dmm("frameless.functions.FramelessUdf.evalCode"),
dmm("frameless.functions.FramelessUdf.copy"),
dmm("frameless.functions.FramelessUdf.this"),
dmm("frameless.functions.FramelessUdf.apply"),
imt("frameless.functions.FramelessUdf.apply")
)
},
coverageExcludedPackages := "org.apache.spark.sql.reflection",
Expand Down
67 changes: 67 additions & 0 deletions dataset/src/main/scala/frameless/CollectionCaster.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package frameless

import frameless.TypedEncoder.CollectionConversion
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{
CodegenContext,
CodegenFallback,
ExprCode
}
import org.apache.spark.sql.catalyst.expressions.{ Expression, UnaryExpression }
import org.apache.spark.sql.types.{ DataType, ObjectType }

/**
 * Expression that applies a `CollectionConversion` to the value produced by
 * `child`: when the evaluated value is an `F[Y]` collection it is converted
 * to a `C[Y]`, otherwise the value is passed through untouched.
 *
 * Codegen falls back to interpreted evaluation (`CodegenFallback`).
 */
case class CollectionCaster[F[_], C[_], Y](
    child: Expression,
    conversion: CollectionConversion[F, C, Y])
    extends UnaryExpression
    with CodegenFallback {

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  // Type arguments are erased at runtime, hence the @unchecked match; a null
  // child result does not match the type pattern and is returned as-is.
  override def eval(input: InternalRow): Any =
    child.eval(input).asInstanceOf[Object] match {
      case collection: F[Y] @unchecked => conversion.convert(collection)
      case passThrough                 => passThrough
    }

  // The cast is a value-level conversion only; the catalyst data type is unchanged.
  override def dataType: DataType = child.dataType
}

/**
 * Normalizes a collection-typed child to a `Seq` before element-wise
 * processing (e.g. by `MapObjects`): values that are a `Set` at runtime are
 * converted with `toSeq`, anything else passes through unchanged.
 */
case class SeqCaster[C[X] <: Iterable[X], Y](child: Expression)
    extends UnaryExpression {

  protected def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)

  // eval on interpreted works, fallback on codegen does not, e.g. with ColumnTests.asCol and Vectors, the code generated still has child of type Vector but child eval returns X2, which is not good
  override def eval(input: InternalRow): Any = {
    val evaluated = child.eval(input).asInstanceOf[Object]
    evaluated match {
      // element type is erased at runtime, hence @unchecked
      case set: Set[Y] @unchecked =>
        set.toSeq
      case _ => evaluated
    }
  }

  /**
   * Selects `isSet` when the child's static type is a `scala.collection.Set`
   * subclass, `or` otherwise. Both arguments are by-name so only the chosen
   * branch is evaluated.
   */
  def toSeqOr[T](isSet: => T, or: => T): T =
    child.dataType match {
      case ObjectType(cls)
          if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
        isSet
      case _ => or // fixed: binding was `case t =>` but `t` was never used
    }

  // A Set child is reported as a Seq (it will be converted), others keep their type.
  override def dataType: DataType =
    toSeqOr(ObjectType(classOf[scala.collection.Seq[_]]), child.dataType)

  // Generated Java calls toVector() on Set inputs to materialize a strict Seq.
  override protected def doGenCode(
      ctx: CodegenContext,
      ev: ExprCode
    ): ExprCode =
    defineCodeGen(ctx, ev, c => toSeqOr(s"$c.toVector()", s"$c"))

}
121 changes: 88 additions & 33 deletions dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
package frameless

import java.math.BigInteger

import java.util.Date

import java.time.{ Duration, Instant, Period, LocalDate }

import java.time.{ Duration, Instant, LocalDate, Period }
import java.sql.Timestamp

import scala.reflect.ClassTag

import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.sql.{ reflection => ScalaReflection }
Expand All @@ -22,10 +17,11 @@ import org.apache.spark.sql.catalyst.util.{
}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import shapeless._
import shapeless.ops.hlist.IsHCons

import scala.collection.immutable.{ ListSet, TreeSet }

abstract class TypedEncoder[T](
implicit
val classTag: ClassTag[T])
Expand Down Expand Up @@ -501,10 +497,76 @@ object TypedEncoder {
override def toString: String = s"arrayEncoder($jvmRepr)"
}

implicit def collectionEncoder[C[X] <: Seq[X], T](
/**
 * Per #804 - when MapObjects is used in interpreted mode the type returned is
 * Seq, not the derived type used in compilation.
 *
 * This type class offers extensible conversion from the runtime collection
 * shape `F[_]` to the statically expected collection `C[_]`. By default Seq,
 * List and Vector (for Seq) and Set, TreeSet and ListSet (for Set) are
 * supported.
 *
 * @tparam F the collection type constructor produced at runtime
 * @tparam C the collection type constructor expected by the encoder
 * @tparam Y the element type
 */
trait CollectionConversion[F[_], C[_], Y] extends Serializable {

  /** Converts the runtime collection into the expected target collection. */
  def convert(c: F[Y]): C[Y]
}

object CollectionConversion {

  // NOTE: all implicit defs carry explicit result types — public implicits
  // without them hamper implicit resolution and binary compatibility.

  implicit def seqToSeq[Y]: CollectionConversion[Seq, Seq, Y] =
    new CollectionConversion[Seq, Seq, Y] {

      override def convert(c: Seq[Y]): Seq[Y] =
        c match {
          // a lazy Stream may be produced; force it into a strict Vector
          // (Vector is already a Seq, no further conversion needed)
          case _: Stream[Y] @unchecked => c.toVector
          case _                       => c
        }
    }

  implicit def seqToVector[Y]: CollectionConversion[Seq, Vector, Y] =
    new CollectionConversion[Seq, Vector, Y] {
      override def convert(c: Seq[Y]): Vector[Y] = c.toVector
    }

  implicit def seqToList[Y]: CollectionConversion[Seq, List, Y] =
    new CollectionConversion[Seq, List, Y] {
      override def convert(c: Seq[Y]): List[Y] = c.toList
    }

  implicit def setToSet[Y]: CollectionConversion[Set, Set, Y] =
    new CollectionConversion[Set, Set, Y] {
      override def convert(c: Set[Y]): Set[Y] = c
    }

  implicit def setToTreeSet[Y](
      implicit
      ordering: Ordering[Y]
    ): CollectionConversion[Set, TreeSet, Y] =
    new CollectionConversion[Set, TreeSet, Y] {

      override def convert(c: Set[Y]): TreeSet[Y] =
        TreeSet.newBuilder.++=(c).result()
    }

  implicit def setToListSet[Y]: CollectionConversion[Set, ListSet, Y] =
    new CollectionConversion[Set, ListSet, Y] {

      override def convert(c: Set[Y]): ListSet[Y] =
        ListSet.newBuilder.++=(c).result()
    }
}

/**
 * Encoder for `Seq`-derived collections (`Seq`, `List`, `Vector`, ...),
 * delegating to `collectionEncoder` with `Seq` as the runtime shape.
 * Explicit result type added: public implicit defs without one hamper
 * implicit resolution and binary compatibility.
 */
implicit def seqEncoder[C[X] <: Seq[X], T](
    implicit
    i0: Lazy[RecordFieldEncoder[T]],
    i1: ClassTag[C[T]],
    i2: CollectionConversion[Seq, C, T]
  ): TypedEncoder[C[T]] = collectionEncoder[Seq, C, T]

/**
 * Encoder for `Set`-derived collections (`Set`, `TreeSet`, `ListSet`, ...),
 * delegating to `collectionEncoder` with `Set` as the runtime shape.
 * Explicit result type added: public implicit defs without one hamper
 * implicit resolution and binary compatibility.
 */
implicit def setEncoder[C[X] <: Set[X], T](
    implicit
    i0: Lazy[RecordFieldEncoder[T]],
    i1: ClassTag[C[T]],
    i2: CollectionConversion[Set, C, T]
  ): TypedEncoder[C[T]] = collectionEncoder[Set, C, T]

def collectionEncoder[O[_], C[X], T](
implicit
i0: Lazy[RecordFieldEncoder[T]],
i1: ClassTag[C[T]]
i1: ClassTag[C[T]],
i2: CollectionConversion[O, C, T]
): TypedEncoder[C[T]] = new TypedEncoder[C[T]] {
private lazy val encodeT = i0.value.encoder

Expand All @@ -521,38 +583,31 @@ object TypedEncoder {
if (ScalaReflection.isNativeType(enc.jvmRepr)) {
NewInstance(classOf[GenericArrayData], path :: Nil, catalystRepr)
} else {
MapObjects(enc.toCatalyst, path, enc.jvmRepr, encodeT.nullable)
// converts to Seq, both Set and Seq handling must convert to Seq first
MapObjects(
enc.toCatalyst,
SeqCaster(path),
enc.jvmRepr,
encodeT.nullable
)
}
}

def fromCatalyst(path: Expression): Expression =
MapObjects(
i0.value.fromCatalyst,
path,
encodeT.catalystRepr,
encodeT.nullable,
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly
)
CollectionCaster[O, C, T](
MapObjects(
i0.value.fromCatalyst,
path,
encodeT.catalystRepr,
encodeT.nullable,
Some(i1.runtimeClass) // This will cause MapObjects to build a collection of type C[_] directly when compiling
),
implicitly[CollectionConversion[O, C, T]]
) // This will convert Seq to the appropriate C[_] when eval'ing.

override def toString: String = s"collectionEncoder($jvmRepr)"
}

/**
* @param i1 implicit lazy `RecordFieldEncoder[T]` to encode individual elements of the set.
* @param i2 implicit `ClassTag[Set[T]]` to provide runtime information about the set type.
* @tparam T the element type of the set.
* @return a `TypedEncoder` instance for `Set[T]`.
*/
implicit def setEncoder[T](
implicit
i1: shapeless.Lazy[RecordFieldEncoder[T]],
i2: ClassTag[Set[T]]
): TypedEncoder[Set[T]] = {
implicit val inj: Injection[Set[T], Seq[T]] = Injection(_.toSeq, _.toSet)

TypedEncoder.usingInjection
}

/**
* @tparam A the key type
* @tparam B the value type
Expand Down
Loading
Loading