Commit
4 tests failing - encoding derivation issue
chris-twiner committed Feb 7, 2025
1 parent 91098b4 commit 9c2dd30
Showing 6 changed files with 86 additions and 10 deletions.
4 changes: 1 addition & 3 deletions core/src/main/scala/frameless/CatalystSummable.scala
@@ -13,9 +13,7 @@ import scala.annotation.implicitNotFound
* - Short -> Long
*/
@implicitNotFound("Cannot compute sum of type ${In}.")
trait CatalystSummable[In, Out] {
def zero: In
}
trait CatalystSummable[In, Out] extends CatalystZero[In]

object CatalystSummable {
def apply[In, Out](zero: In): CatalystSummable[In, Out] = {
53 changes: 53 additions & 0 deletions core/src/main/scala/frameless/CatalystZero.scala
@@ -0,0 +1,53 @@
package frameless

import shapeless.{Generic, HList, Lazy, Poly1}
import shapeless.ops.hlist.{LiftAll, Mapper}

import java.time.{Duration, Instant, Period}
import scala.annotation.implicitNotFound

/** Types for which Catalyst can provide a zero value - no zeros were harmed in the process.
* Used by min/first/etc. to provide a generalised coalesce. The zero value exists only to
* handle nulls and should never be observed; if you do want to see it, you
* must break free of TypedDataset's guardrails.
*/
@implicitNotFound("Cannot provide zero value columns of type ${A}.")
abstract class CatalystZero[A](implicit ev: NotCatalystNullable[A]) {
def zero: A
}

object CatalystZero {
def apply[A: NotCatalystNullable](zero: A): CatalystZero[A] = {
val _zero = zero
new CatalystZero[A] { val zero: A = _zero }
}

implicit val framelessZeroLong : CatalystZero[Long] = CatalystZero(zero = 0L)
implicit val framelessZeroBigDecimal: CatalystZero[BigDecimal] = CatalystZero(zero = BigDecimal(0))
implicit val framelessZeroDouble : CatalystZero[Double] = CatalystZero(zero = 0.0)
implicit val framelessZeroInt : CatalystZero[Int] = CatalystZero(zero = 0)
implicit val framelessZeroShort : CatalystZero[Short] = CatalystZero(zero = 0)
implicit val framelessZeroBooleanOrdered : CatalystZero[Boolean] = CatalystZero(zero = false)
implicit val framelessZeroByte : CatalystZero[Byte] = CatalystZero(zero = 0)
implicit val framelessZeroFloat : CatalystZero[Float] = CatalystZero(zero = 0.0f)
implicit val framelessZeroSQLDate : CatalystZero[SQLDate] = CatalystZero(zero = SQLDate(0))
implicit val framelessZeroSQLTimestamp: CatalystZero[SQLTimestamp] = CatalystZero(zero = SQLTimestamp(0))
implicit val framelessZeroString : CatalystZero[String] = CatalystZero(zero = "")
implicit val framelessZeroInstant : CatalystZero[Instant] = CatalystZero(zero = Instant.now())
implicit val framelessZeroDuration : CatalystZero[Duration] = CatalystZero(zero = Duration.ZERO)
implicit val framelessZeroPeriod : CatalystZero[Period] = CatalystZero(zero = Period.ZERO)

object ZeroPoly extends Poly1 {

implicit def caseZero[U: CatalystZero] =
at[CatalystZero[U]](c => c.zero)

}

implicit def deriveGeneric[G, H <: HList, O <: HList]
(implicit
i0: Generic.Aux[G, H],
i1: Lazy[LiftAll.Aux[CatalystZero, H, O]],
i2: Mapper.Aux[ZeroPoly.type, O, H],
): CatalystZero[G] = CatalystZero(i0.from(i1.value.instances.map(ZeroPoly)))
}
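
To make the intent of the generic derivation concrete, here is a minimal sketch of summoning a derived instance for a product type (the Totals case class is purely illustrative, and per the commit message the derivation does not yet resolve in every case):

import frameless.CatalystZero

// Hypothetical product type, used only to illustrate the shapeless derivation.
case class Totals(count: Long, amount: Double)

// LiftAll gathers a CatalystZero for every field, ZeroPoly maps each instance
// to its zero, and Generic.from rebuilds the product: Totals(0L, 0.0).
val zeroTotals: Totals = implicitly[CatalystZero[Totals]].zero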
13 changes: 10 additions & 3 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -131,14 +131,21 @@ class TypedDataset[T] protected[frameless] (
(c, i) <- underlyingColumns.zipWithIndex
if !c.uencoder.nullable
} yield s"_${i + 1} is not null"
).mkString(" or ")
).mkString(" and ")

val selected = dataset
.toDF()
.agg(cols.head, cols.tail: _*)
.as[Out](TypedExpressionEncoder[Out])

// Spark 4 is stricter about result types: apply the typed select only after filtering out the null rows
val filtered =
if (filterStr.isEmpty)
selected
else
selected.filter(filterStr)

TypedDataset.create[Out](
if (filterStr.isEmpty) selected else selected.filter(filterStr)
filtered.as[Out](TypedExpressionEncoder[Out])
)
}
}
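
For context on the predicate change above, a standalone plain-Spark sketch of what the joined filter string now expresses (data and column names are illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Two aggregation outputs; only "_2" is present in the second row.
val aggregated = Seq((Option(3L), Option(4L)), (Option.empty[Long], Option(1L))).toDF("_1", "_2")

// Joining the predicates with "and" keeps a row only when every non-nullable
// result is present; the previous "or" also kept rows where just one was.
aggregated.filter("_1 is not null and _2 is not null").show()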
9 changes: 6 additions & 3 deletions dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
@@ -654,9 +654,12 @@ trait NonAggregateFunctions {
def factorial[T](
column: AbstractTypedColumn[T, Long]
)(implicit
i0: TypedEncoder[Long]
): column.ThisType[T, Long] =
column.typed(sparkFunctions.factorial(column.untyped))
i0: TypedEncoder[Long],
i1: CatalystZero[Long]
): column.ThisType[T, Long] = {
val factOrZero = nullToZero(sparkFunctions.factorial(column.untyped))
column.typed(factOrZero)
}

/**
* Non-Aggregate function: Computes bitwise NOT.
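For context on why factorial now goes through a coalesce: Spark's factorial returns null for inputs outside 0..20 (and for null inputs), which would otherwise leak nulls into the typed Long column. A minimal plain-Spark sketch of the fallback, with a hypothetical column name:

import org.apache.spark.sql.functions.{coalesce, col, factorial, lit}

// Equivalent of nullToZero[Long](factorial(col("n"))): a null factorial result
// falls back to the Long zero, 0L.
val factOrZero = coalesce(factorial(col("n")), lit(0L))
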
13 changes: 13 additions & 0 deletions dataset/src/main/scala/frameless/functions/package.scala
@@ -1,6 +1,7 @@
package frameless

import frameless.{reflection => ScalaReflection}
import org.apache.spark.sql.{Column, functions => sparkFunctions}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

import scala.reflect.ClassTag
@@ -12,6 +13,18 @@ import org.apache.spark.sql.catalyst.expressions.Literal

package object functions extends Udf with UnaryFunctions {

/**
* Provides a non-null expression from a nullable one by coalescing with the type's zero.
* @param nullable the expression which may produce nulls
* @tparam T the type whose CatalystZero instance supplies the replacement value
* @return a column in which any nulls are replaced by CatalystZero[T].zero
*/
def nullToZero[T: CatalystZero](nullable: Column): Column = {
val zeroExpr = sparkFunctions.lit(implicitly[CatalystZero[T]].zero)
val orZero = sparkFunctions.coalesce(nullable, zeroExpr)
orZero
}

object aggregate extends AggregateFunctions
object nonAggregate extends NonAggregateFunctions

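A minimal usage sketch of the new helper on a raw Spark Column of Longs (the column name is hypothetical; CatalystZero[Long] resolves from its companion object):

import frameless.functions.nullToZero
import org.apache.spark.sql.functions.col

// Coalesces any nulls in a Long-typed column to CatalystZero[Long].zero, i.e. 0L.
val nonNull = nullToZero[Long](col("maybeNull"))
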
4 changes: 3 additions & 1 deletion dataset/src/test/scala/frameless/GroupByTests.scala
@@ -120,7 +120,9 @@ class GroupByTests extends TypedDatasetSuite {
val listMinC = if(data.isEmpty) implicitly[Numeric[C]].fromInt(0) else data.map(_.c).min
val listMaxD = if(data.isEmpty) implicitly[Numeric[D]].fromInt(0) else data.map(_.d).max

datasetSum ?= Vector(if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD))
datasetSum ?= {
if (data.isEmpty) Vector.empty else Vector((listSumA, listSumB, listMinC, listMaxD))
}
}

check(forAll(prop[Long, Long, Long, Int] _))
