From 897e499649b80bb8e326267e66218593b1e788ac Mon Sep 17 00:00:00 2001 From: Itamar Ravid Date: Fri, 22 Sep 2017 11:59:09 +0300 Subject: [PATCH] Add a typed `col` function for creating column references Resolves #186. --- .../src/main/scala/frameless/functions/package.scala | 10 ++++++++++ dataset/src/test/scala/frameless/SelectTests.scala | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/dataset/src/main/scala/frameless/functions/package.scala b/dataset/src/main/scala/frameless/functions/package.scala index ccecdb49..7f6b58ed 100644 --- a/dataset/src/main/scala/frameless/functions/package.scala +++ b/dataset/src/main/scala/frameless/functions/package.scala @@ -2,6 +2,8 @@ package frameless import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.functions.{ col => sparkCol } +import shapeless.Witness package object functions extends Udf with UnaryFunctions { object aggregate extends AggregateFunctions @@ -17,4 +19,12 @@ package object functions extends Udf with UnaryFunctions { new TypedColumn(expr) } } + + def col[T, A](column: Witness.Lt[Symbol])( + implicit + exists: TypedColumn.Exists[T, column.T, A], + encoder: TypedEncoder[A]): TypedColumn[T, A] = { + val untypedExpr = sparkCol(column.value.name).as[A](TypedExpressionEncoder[A]) + new TypedColumn[T, A](untypedExpr) + } } diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala index d90466cc..9ab842ad 100644 --- a/dataset/src/test/scala/frameless/SelectTests.scala +++ b/dataset/src/test/scala/frameless/SelectTests.scala @@ -18,9 +18,10 @@ class SelectTests extends TypedDatasetSuite { val A = dataset.col[A]('a) val dataset2 = dataset.select(A).collect().run().toVector + val symDataset2 = dataset.select(functions.col('a)).collect().run().toVector val data2 = data.map { case X4(a, _, _, _) => a } - dataset2 ?= data2 + (dataset2 ?= data2) && (symDataset2 ?= data2) } check(forAll(prop[Int, Int, Int, Int] _))