diff --git a/core/src/main/scala/org/apache/spark/sql/execution/datasources/hbase/HBaseFilter.scala b/core/src/main/scala/org/apache/spark/sql/execution/datasources/hbase/HBaseFilter.scala
index 8ee281bb..7b9e2530 100644
--- a/core/src/main/scala/org/apache/spark/sql/execution/datasources/hbase/HBaseFilter.scala
+++ b/core/src/main/scala/org/apache/spark/sql/execution/datasources/hbase/HBaseFilter.scala
@@ -369,13 +369,22 @@ object HBaseFilter extends Logging{
       case In(attribute: String, values: Array[Any]) =>
         //converting a "key in (x1, x2, x3..) filter to (key == x1) or (key == x2) or ...
         val ranges = new ArrayBuffer[ScanRange[Array[Byte]]]()
+        val typedFilters = ArrayBuffer[TypedFilter]()
         values.foreach{
           value =>
             val sparkFilter = EqualTo(attribute, value)
             val hbaseFilter = buildFilter(sparkFilter, relation)
-            ranges ++= hbaseFilter.ranges
+            ranges ++= hbaseFilter.ranges.filter(_ != ScanRange.empty[Array[Byte]])
+            typedFilters += hbaseFilter.tf
         }
-        HRF[Array[Byte]](ranges.toArray, TypedFilter.empty, handled = true)
+        val resultingTypedFilter = typedFilters.foldLeft(TypedFilter.empty){
+          (acc, tf) => acc match {
+            case TypedFilter(None, FilterType.Und) => tf
+            case _ => TypedFilter.or(acc, tf)
+          }
+        }
+        val resultingRanges = if (ranges.isEmpty) Array(ScanRange.empty[Array[Byte]]) else ranges.toArray
+        HRF[Array[Byte]](resultingRanges, resultingTypedFilter, handled = true)
       case Not(In(attribute: String, values: Array[Any])) =>
         //converting a "not(key in (x1, x2, x3..)) filter to (key != x1) and (key != x2) and ..
         val hrf = values.map{v => buildFilter(Not(EqualTo(attribute, v)),relation)}
diff --git a/core/src/test/scala/org/apache/spark/sql/DefaultSourceSuite.scala b/core/src/test/scala/org/apache/spark/sql/DefaultSourceSuite.scala
index 1093a119..1b1b62f7 100644
--- a/core/src/test/scala/org/apache/spark/sql/DefaultSourceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/sql/DefaultSourceSuite.scala
@@ -153,6 +153,18 @@ class DefaultSourceSuite extends SHC with Logging {
     assert(c == 256)
   }
 
+  test("IN filter for column") {
+    val df = withCatalog(catalog)
+    val s = df.filter($"col4" isin (4, 5, 6)).select("col0")
+    assert(s.count() == 3)
+  }
+
+  test("IN filter for rowkey") {
+    val df = withCatalog(catalog)
+    val s = df.filter($"col0" isin ("row005", "row001", "row002")).select("col0")
+    assert(s.count() == 3)
+  }
+
   test("IN and Not IN filter1") {
     val df = withCatalog(catalog)
     val s = df.filter(($"col0" isin ("row005", "row001", "row002")) and !($"col0" isin ("row001", "row002")))
@@ -171,7 +183,7 @@ class DefaultSourceSuite extends SHC with Logging {
     assert(s.count() == 1)
   }
 
-  test("IN filter stack overflow") {
+  test("IN filter rowkey stack overflow") {
     val df = withCatalog(catalog)
     val items = (0 to 2000).map{i => s"xaz$i"}
     val filterInItems = Seq("row001") ++: items
@@ -182,6 +194,18 @@ class DefaultSourceSuite extends SHC with Logging {
     assert(s.count() == 1)
   }
 
+  test("IN filter column stack overflow") {
+    val df = withCatalog(catalog)
+    val dfSize = df.count()
+    val items = (0 to 2000).map(_ + dfSize + 1)
+    val filterInItems = Seq(1) ++: items
+
+    val s = df.filter($"col4" isin(filterInItems:_*)).select("col0")
+    s.explain(true)
+    s.show()
+    assert(s.count() == 1)
+  }
+
   test("NOT IN filter stack overflow") {
     val df = withCatalog(catalog)
     val items = (0 to 2000).map{i => s"xaz$i"}
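
Note on the change above: the foldLeft is the heart of it. A "col IN (x1, x2, ...)" filter is rewritten as per-value EqualTo filters whose TypedFilters get OR-combined, but the seed TypedFilter.empty must not become an operand of the OR, so the first real filter replaces it. Below is a minimal, self-contained sketch of that fold shape; TypedFilter, FilterType, and or() here are simplified stand-ins for illustration, not SHC's actual classes:

object InFoldSketch extends App {
  // Simplified stand-ins for SHC's FilterType and TypedFilter.
  sealed trait FilterType
  object FilterType {
    case object Und extends FilterType // marker carried by the empty filter
    case object Or  extends FilterType
  }

  case class TypedFilter(filter: Option[String], ft: FilterType)
  object TypedFilter {
    val empty: TypedFilter = TypedFilter(None, FilterType.Und)
    def or(a: TypedFilter, b: TypedFilter): TypedFilter =
      TypedFilter(Some(s"(${a.filter.get} OR ${b.filter.get})"), FilterType.Or)
  }

  // One filter per IN value, as buildFilter(EqualTo(...)) would produce.
  val perValue = Seq("col4 == 4", "col4 == 5", "col4 == 6")
    .map(p => TypedFilter(Some(p), FilterType.Or))

  // Same fold shape as the patch: the empty seed is replaced by the first
  // real filter instead of being OR'ed in as a spurious operand.
  val combined = perValue.foldLeft(TypedFilter.empty) { (acc, tf) =>
    acc match {
      case TypedFilter(None, FilterType.Und) => tf
      case _ => TypedFilter.or(acc, tf)
    }
  }

  println(combined.filter.get) // ((col4 == 4 OR col4 == 5) OR col4 == 6)
}

The ranges side follows the same idea: per-value empty ranges are dropped while collecting, and a single ScanRange.empty is restored at the end only if no concrete rowkey ranges were produced, so a non-rowkey IN still scans the full key space and relies on the combined pushed-down filter for row selection.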