From 9c2dd30d0484f1088f92b4693f59ab82cd5d3527 Mon Sep 17 00:00:00 2001
From: Chris Twiner
Date: Fri, 7 Feb 2025 21:19:09 +0100
Subject: [PATCH] 4 tests failing - encoding derivation issue

---
 .../scala/frameless/CatalystSummable.scala    |  4 +-
 .../main/scala/frameless/CatalystZero.scala   | 53 +++++++++++++++++++
 .../main/scala/frameless/TypedDataset.scala   | 13 +++--
 .../functions/NonAggregateFunctions.scala     |  9 ++--
 .../scala/frameless/functions/package.scala   | 13 +++++
 .../test/scala/frameless/GroupByTests.scala   |  4 +-
 6 files changed, 86 insertions(+), 10 deletions(-)
 create mode 100644 core/src/main/scala/frameless/CatalystZero.scala

diff --git a/core/src/main/scala/frameless/CatalystSummable.scala b/core/src/main/scala/frameless/CatalystSummable.scala
index 94010505..64443910 100644
--- a/core/src/main/scala/frameless/CatalystSummable.scala
+++ b/core/src/main/scala/frameless/CatalystSummable.scala
@@ -13,9 +13,7 @@ import scala.annotation.implicitNotFound
  * - Short -> Long
  */
 @implicitNotFound("Cannot compute sum of type ${In}.")
-trait CatalystSummable[In, Out] {
-  def zero: In
-}
+trait CatalystSummable[In, Out] extends CatalystZero[In]
 
 object CatalystSummable {
   def apply[In, Out](zero: In): CatalystSummable[In, Out] = {
diff --git a/core/src/main/scala/frameless/CatalystZero.scala b/core/src/main/scala/frameless/CatalystZero.scala
new file mode 100644
index 00000000..b7a7ee56
--- /dev/null
+++ b/core/src/main/scala/frameless/CatalystZero.scala
@@ -0,0 +1,53 @@
+package frameless
+
+import shapeless.{Generic, HList, Lazy, Poly1}
+import shapeless.ops.hlist.{LiftAll, Mapper}
+
+import java.time.{Duration, Instant, Period}
+import scala.annotation.implicitNotFound
+
+/** Types that Catalyst can supply a zero value for - no zeros were harmed in the making of this.
+ * Used by min/first/etc. to build a generalised coalesce: the zero only stands in for null
+ * handling and should never be observed in results; if you really want it, you must break
+ * free of TypedDataset's guardrails.
+ */
+@implicitNotFound("Cannot provide zero value columns of type ${A}.")
+abstract class CatalystZero[A](implicit ev: NotCatalystNullable[A]) {
+  def zero: A
+}
+
+object CatalystZero {
+  def apply[A: NotCatalystNullable](zero: A): CatalystZero[A] = {
+    val _zero = zero
+    new CatalystZero[A] { val zero: A = _zero }
+  }
+
+  implicit val framelessZeroLong: CatalystZero[Long] = CatalystZero(zero = 0L)
+  implicit val framelessZeroBigDecimal: CatalystZero[BigDecimal] = CatalystZero(zero = BigDecimal(0))
+  implicit val framelessZeroDouble: CatalystZero[Double] = CatalystZero(zero = 0.0)
+  implicit val framelessZeroInt: CatalystZero[Int] = CatalystZero(zero = 0)
+  implicit val framelessZeroShort: CatalystZero[Short] = CatalystZero(zero = 0)
+  implicit val framelessZeroBooleanOrdered: CatalystZero[Boolean] = CatalystZero(zero = false)
+  implicit val framelessZeroByte: CatalystZero[Byte] = CatalystZero(zero = 0)
+  implicit val framelessZeroFloat: CatalystZero[Float] = CatalystZero(zero = 0.0f)
+  implicit val framelessZeroSQLDate: CatalystZero[SQLDate] = CatalystZero(zero = SQLDate(0))
+  implicit val framelessZeroSQLTimestamp: CatalystZero[SQLTimestamp] = CatalystZero(zero = SQLTimestamp(0))
+  implicit val framelessZeroString: CatalystZero[String] = CatalystZero(zero = "")
+  implicit val framelessZeroInstant: CatalystZero[Instant] = CatalystZero(zero = Instant.now())
+  implicit val framelessZeroDuration: CatalystZero[Duration] = CatalystZero(zero = Duration.ZERO)
+  implicit val framelessZeroPeriod: CatalystZero[Period] = CatalystZero(zero = Period.ZERO)
+
+  object ZeroPoly extends Poly1 {
+
+    implicit def caseZero[U: CatalystZero] =
+      at[CatalystZero[U]](c => c.zero)
+
+  }
+
+  implicit def deriveGeneric[G, H <: HList, O <: HList]
+    (implicit
+      i0: Generic.Aux[G, H],
+      i1: Lazy[LiftAll.Aux[CatalystZero, H, O]],
+      i2: Mapper.Aux[ZeroPoly.type, O, H]
+    ): CatalystZero[G] = CatalystZero(i0.from(i1.value.instances.map(ZeroPoly)))
+}
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index f14e274f..9ba6972c 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -131,14 +131,21 @@ class TypedDataset[T] protected[frameless] (
         (c, i) <- underlyingColumns.zipWithIndex
         if !c.uencoder.nullable
       } yield s"_${i + 1} is not null"
-    ).mkString(" or ")
+    ).mkString(" and ")
 
     val selected = dataset
       .toDF()
      .agg(cols.head, cols.tail: _*)
-      .as[Out](TypedExpressionEncoder[Out])
+
+    // Spark 4 is strict about row types: apply the Out encoder only after the null rows have been filtered out
+    val filtered =
+      if (filterStr.isEmpty)
+        selected
+      else
+        selected.filter(filterStr)
+
     TypedDataset.create[Out](
-      if (filterStr.isEmpty) selected else selected.filter(filterStr)
+      filtered.as[Out](TypedExpressionEncoder[Out])
     )
   }
 }
diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
index 3ee42f45..9c638fc3 100644
--- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
@@ -654,9 +654,12 @@ trait NonAggregateFunctions {
   def factorial[T](
       column: AbstractTypedColumn[T, Long]
     )(implicit
-      i0: TypedEncoder[Long]
-    ): column.ThisType[T, Long] =
-      column.typed(sparkFunctions.factorial(column.untyped))
+      i0: TypedEncoder[Long],
+      i1: CatalystZero[Long]
+    ): column.ThisType[T, Long] = {
+      val factOrZero = nullToZero[Long](sparkFunctions.factorial(column.untyped))
+      column.typed(factOrZero)
+    }
 
   /**
    * Non-Aggregate function: Computes bitwise NOT.
diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala
index c567aeff..f66cf044 100644
--- a/dataset/src/main/scala/frameless/functions/package.scala
+++ b/dataset/src/main/scala/frameless/functions/package.scala
@@ -1,6 +1,7 @@
 package frameless
 
 import frameless.{reflection => ScalaReflection}
+import org.apache.spark.sql.{Column, functions => sparkFunctions}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 
 import scala.reflect.ClassTag
@@ -12,6 +13,18 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 
 package object functions extends Udf with UnaryFunctions {
 
+  /**
+   * Provides a non-null expression from a nullable one.
+   * @param nullable the expression which may produce nulls
+   * @tparam T the type whose CatalystZero instance supplies the replacement value
+   * @return the input column coalesced with the type's zero literal, so it never yields null
+   */
+  def nullToZero[T: CatalystZero](nullable: Column): Column = {
+    val zeroExpr = sparkFunctions.lit(implicitly[CatalystZero[T]].zero)
+    val orZero = sparkFunctions.coalesce(nullable, zeroExpr)
+    orZero
+  }
+
   object aggregate extends AggregateFunctions
 
   object nonAggregate extends NonAggregateFunctions
diff --git a/dataset/src/test/scala/frameless/GroupByTests.scala b/dataset/src/test/scala/frameless/GroupByTests.scala
index 7178def3..a5a03bd6 100644
--- a/dataset/src/test/scala/frameless/GroupByTests.scala
+++ b/dataset/src/test/scala/frameless/GroupByTests.scala
@@ -120,7 +120,9 @@ class GroupByTests extends TypedDatasetSuite {
       val listMinC = if(data.isEmpty) implicitly[Numeric[C]].fromInt(0) else data.map(_.c).min
       val listMaxD = if(data.isEmpty) implicitly[Numeric[D]].fromInt(0) else data.map(_.d).max
 
-      datasetSum ?= Vector(if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD))
+      datasetSum ?= {
+        if (data.isEmpty) Vector.empty else Vector((listSumA, listSumB, listMinC, listMaxD))
+      }
     }
 
     check(forAll(prop[Long, Long, Long, Int] _))
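
Usage sketch (not part of the patch): a minimal illustration of how the new CatalystZero
instances and the nullToZero helper are intended to compose once this change applies. The
SparkSession setup, the "n" column, and the NullToZeroSketch object are invented for the
example only.

import org.apache.spark.sql.{Column, SparkSession, functions => sparkFunctions}
import frameless._
import frameless.functions._

object NullToZeroSketch {

  // Zero values are supplied by the CatalystZero instances introduced in this patch.
  val longZero: Long = implicitly[CatalystZero[Long]].zero       // 0L
  val dateZero: SQLDate = implicitly[CatalystZero[SQLDate]].zero // SQLDate(0)

  // Spark's factorial yields null outside 0..20; nullToZero coalesces those nulls
  // with the Long zero so downstream typed columns never observe a null.
  def safeFactorial(c: Column): Column =
    nullToZero[Long](sparkFunctions.factorial(c))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(3, 25).toDF("n") // factorial(25) is null because 25 > 20
    df.select(safeFactorial(sparkFunctions.col("n"))).show()
    // prints 6 for n = 3 and 0 (instead of null) for n = 25
    spark.stop()
  }
}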