Add ClassTags to types that depend on Spark's Serializer. #334

Open · wants to merge 14 commits into base: spark-1.0 · changes from all commits
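The recurring change in this diff is the addition of a [T: ClassTag] context bound to every generic helper that forwards to Spark's serializers, since the Spark 1.0 serializer API takes a ClassTag on serialize/deserialize (visible in the ShuffleSerializer overrides further down). A minimal, standalone sketch of the mechanism, with names that are illustrative and not part of the patch:

import scala.reflect.ClassTag

// Illustrative only: a [T: ClassTag] context bound is sugar for an implicit
// ClassTag[T] parameter, which is what lets a generic wrapper hand the tag
// on to an underlying API that demands one.
object ClassTagPropagation {
  def describe[T: ClassTag](value: T): String = {
    val tag = implicitly[ClassTag[T]]   // supplied automatically at the call site
    s"${tag.runtimeClass.getName}: $value"
  }

  def main(args: Array[String]): Unit = {
    println(describe(Seq(1, 2, 3)))     // the compiler materializes ClassTag[Seq[Int]]
  }
}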
6 changes: 6 additions & 0 deletions project/SharkBuild.scala
@@ -203,6 +203,12 @@ object SharkBuild extends Build {
// See https://code.google.com/p/guava-libraries/issues/detail?id=1095
"com.google.code.findbugs" % "jsr305" % "1.3.+",

// sbt fails to download the javax.servlet artifacts from jetty 8.1:
// http://mvnrepository.com/artifact/org.eclipse.jetty.orbit/javax.servlet/3.0.0.v201112011016
// which may be due to the use of the orbit extension. So, we manually include servlet api
// from a separate source.
"org.mortbay.jetty" % "servlet-api" % "3.0.20100224",

// Hive unit test requirements. These are used by Hadoop to run the tests, but not necessary
// in usual Shark runs.
"commons-io" % "commons-io" % "2.1",
4 changes: 3 additions & 1 deletion src/main/scala/shark/SharkCliDriver.scala
@@ -162,7 +162,9 @@ object SharkCliDriver {
val cli = new SharkCliDriver(reloadRdds)
cli.setHiveVariables(oproc.getHiveVariables())

SharkEnv.fixUncompatibleConf(conf)
if (!ss.isRemoteMode) {
SharkEnv.fixUncompatibleConf(conf)
}

// Execute -i init files (always in silent mode)
cli.processInitFiles(ss)
6 changes: 3 additions & 3 deletions src/main/scala/shark/execution/LateralViewJoinOperator.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.reflect.BeanProperty
import scala.reflect.{BeanProperty, ClassTag}

import org.apache.commons.codec.binary.Base64
import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory}
@@ -174,12 +174,12 @@ object KryoSerializerToString {

@transient val kryoSer = new SparkKryoSerializer(SparkEnv.get.conf)

def serialize[T](o: T): String = {
def serialize[T: ClassTag](o: T): String = {
val bytes = kryoSer.newInstance().serialize(o).array()
new String(Base64.encodeBase64(bytes))
}

def deserialize[T](byteString: String): T = {
def deserialize[T: ClassTag](byteString: String): T = {
val bytes = Base64.decodeBase64(byteString.getBytes())
kryoSer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
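KryoSerializerToString turns Kryo-produced bytes into a Base64 string so the result can be embedded in Hive's plan. A standalone sketch of the same bytes-to-string round trip, using plain Java serialization instead of Spark's Kryo so it runs without a SparkEnv; the [T: ClassTag] bound is kept only to mirror the new signatures:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.commons.codec.binary.Base64
import scala.reflect.ClassTag

// Not the PR's code: same Base64 framing as KryoSerializerToString, but backed
// by java.io serialization so the sketch is self-contained.
object StringSerializerSketch {
  def serialize[T: ClassTag](o: T): String = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o.asInstanceOf[AnyRef])
    oos.close()
    new String(Base64.encodeBase64(bos.toByteArray))
  }

  def deserialize[T: ClassTag](byteString: String): T = {
    val bytes = Base64.decodeBase64(byteString.getBytes())
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    ois.readObject().asInstanceOf[T]
  }
}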
9 changes: 6 additions & 3 deletions src/main/scala/shark/execution/MapSplitPruning.scala
@@ -67,7 +67,7 @@ object MapSplitPruning {
true
}

case _: GenericUDFIn =>
case _: GenericUDFIn if e.children(0).isInstanceOf[ExprNodeColumnEvaluator] =>
testInPredicate(
s,
e.children(0).asInstanceOf[ExprNodeColumnEvaluator],
@@ -91,10 +91,13 @@ object MapSplitPruning {
val columnStats = s.stats(field.fieldID)

if (columnStats != null) {
expEvals.exists {
e =>
expEvals.exists { e =>
if (e.isInstanceOf[ExprNodeConstantEvaluator]) {
val constEval = e.asInstanceOf[ExprNodeConstantEvaluator]
columnStats := constEval.expr.getValue()
} else {
true
}
}
} else {
// If there is no stats on the column, don't prune.
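The new guard only attempts IN-list pruning when the first child is a column evaluator, and inside testInPredicate any non-constant list element now makes the exists call yield true, so the split is kept. A rough standalone sketch of that conservative rule, with made-up method and parameter names rather than Hive's API:

// Illustrative only: a split may be pruned only when every IN-list element is a
// known constant and the column stats rule all of them out; anything non-constant
// (e.g. concat("201", "4") in the new test) counts as "might match".
object InPruningSketch {
  def keepSplit(statsMayContain: Any => Boolean, inList: Seq[Option[Any]]): Boolean =
    inList.exists {
      case Some(constant) => statsMayContain(constant) // constant: consult the stats
      case None           => true                      // non-constant: assume a match
    }
}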
@@ -19,18 +19,20 @@ package shark.execution.serialization

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializer => SparkJavaSerializer}


object JavaSerializer {
@transient val ser = new SparkJavaSerializer(SparkEnv.get.conf)

def serialize[T](o: T): Array[Byte] = {
def serialize[T: ClassTag](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
}

def deserialize[T](bytes: Array[Byte]): T = {
def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
}
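One consequence of the new signatures: callers that are themselves generic must declare the bound too, otherwise the compiler reports that no ClassTag is available for T. A hypothetical caller of the JavaSerializer object above (it runs only once Shark's SparkEnv is initialized, since the object reads SparkEnv.get.conf):

import scala.reflect.ClassTag
import shark.execution.serialization.JavaSerializer

// Hypothetical helper, not in this patch: the [T: ClassTag] bound exists solely
// so it can be forwarded to JavaSerializer.serialize / deserialize.
object JavaSerializerCaller {
  def roundTrip[T: ClassTag](value: T): T =
    JavaSerializer.deserialize[T](JavaSerializer.serialize(value))
}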
@@ -17,13 +17,15 @@

package shark.execution.serialization

import scala.reflect.ClassTag

/**
* A wrapper around otherwise unserializable objects that makes them
* Java-serializable. Internally, Kryo is used for serialization.
*
* Use KryoSerializationWrapper(value) to create a wrapper.
*/
class KryoSerializationWrapper[T] extends Serializable {
class KryoSerializationWrapper[T: ClassTag] extends Serializable {

@transient var value: T = _

@@ -54,7 +56,7 @@ class KryoSerializationWrapper[T] extends Serializable {


object KryoSerializationWrapper {
def apply[T](value: T): KryoSerializationWrapper[T] = {
def apply[T: ClassTag](value: T): KryoSerializationWrapper[T] = {
val wrapper = new KryoSerializationWrapper[T]
wrapper.value = value
wrapper
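Stripped of the Spark-specific parts, the wrapper technique looks roughly like the sketch below: keep the payload in a @transient field and re-serialize it by hand inside writeObject/readObject. This is a structural illustration only; it substitutes plain Java serialization where the real class uses Kryo, and the [T: ClassTag] bound is kept just to mirror the patched signature:

import java.io.{ObjectInputStream, ObjectOutputStream}
import scala.reflect.ClassTag

// Illustration of the wrapper shape, not Shark's class: the payload is @transient
// so default Java serialization skips it, and writeObject/readObject move it by
// hand (the real wrapper hands these bytes to Kryo instead).
class SerializationWrapperSketch[T: ClassTag](@transient var value: T) extends Serializable {

  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    out.writeObject(value.asInstanceOf[AnyRef]) // stand-in for Kryo-produced bytes
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    value = in.readObject().asInstanceOf[T]
  }
}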
@@ -19,6 +19,8 @@ package shark.execution.serialization

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer}

@@ -36,11 +38,11 @@ object KryoSerializer {
new SparkKryoSerializer(sparkConf)
}

def serialize[T](o: T): Array[Byte] = {
def serialize[T: ClassTag](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
}

def deserialize[T](bytes: Array[Byte]): T = {
def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
}
@@ -20,6 +20,8 @@ package shark.execution.serialization
import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.hadoop.io.BytesWritable

import org.apache.spark.SparkConf
@@ -60,11 +62,11 @@ class ShuffleSerializer(conf: SparkConf) extends Serializer with Serializable {

class ShuffleSerializerInstance extends SerializerInstance with Serializable {

override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException

override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = throw new UnsupportedOperationException

override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException

override def serializeStream(s: OutputStream): SerializationStream = {
@@ -79,7 +81,7 @@ class ShuffleSerializerInstance extends SerializerInstance with Serializable {

class ShuffleSerializationStream(stream: OutputStream) extends SerializationStream with Serializable {

override def writeObject[T](t: T): SerializationStream = {
override def writeObject[T: ClassTag](t: T): SerializationStream = {
// On the write-side, the ReduceKey should be of type ReduceKeyMapSide.
val (key, value) = t.asInstanceOf[(ReduceKey, BytesWritable)]
writeUnsignedVarInt(key.length)
@@ -110,7 +112,7 @@ class ShuffleSerializationStream(stream: OutputStream) extends SerializationStre

class ShuffleDeserializationStream(stream: InputStream) extends DeserializationStream with Serializable {

override def readObject[T](): T = {
override def readObject[T: ClassTag](): T = {
// Return type is (ReduceKeyReduceSide, Array[Byte])
val keyLen = readUnsignedVarInt()
if (keyLen < 0) {
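The shuffle stream frames each record with writeUnsignedVarInt/readUnsignedVarInt. For readers unfamiliar with the format, here is a generic sketch of that encoding (7 payload bits per byte, high bit set while more bytes follow); it is not Shark's implementation and omits EOF handling, which the real readUnsignedVarInt evidently signals with a negative length given the keyLen < 0 check above:

import java.io.{InputStream, OutputStream}

// Generic unsigned-varint sketch, illustrative only.
object VarIntSketch {
  def writeUnsignedVarInt(value: Int, out: OutputStream): Unit = {
    var v = value
    while ((v & 0xFFFFFF80) != 0) {
      out.write((v & 0x7F) | 0x80)  // low 7 bits, continuation bit set
      v = v >>> 7
    }
    out.write(v & 0x7F)             // final byte, continuation bit clear
  }

  def readUnsignedVarInt(in: InputStream): Int = {
    var value = 0
    var shift = 0
    var b = in.read()
    while ((b & 0x80) != 0) {
      value |= (b & 0x7F) << shift
      shift += 7
      b = in.read()
    }
    value | ((b & 0x7F) << shift)
  }
}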
1 change: 1 addition & 0 deletions src/main/scala/shark/memstore2/TableRecovery.scala
@@ -58,6 +58,7 @@ object TableRecovery extends LogHelper {
logInfo(logMessage)
}
val cmd = QueryRewriteUtils.cacheToAlterTable("CACHE %s".format(tableName))
cmdRunner(s"use $databaseName")
cmdRunner(cmd)
}
}
@@ -88,7 +88,7 @@ class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Ite
private var _count: Int = 0
private val _current: V = columnType.newWritable()

override def hasNext = buffer.hasRemaining()
override def hasNext = _count < _run || buffer.hasRemaining()

override def next(): V = {
if (_count == _run) {
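The hasNext fix matters because a run is decoded lazily: the buffer can be fully consumed while the last (run-length, value) pair still has copies left to emit, so hasNext must also look at the in-flight run. A minimal standalone iterator showing the same condition (not the RLDecoder itself):

// Illustrative run-length-decoding iterator: hasNext is true while the current
// run still has copies to emit OR more encoded runs remain.
class RunLengthIterator[A](runs: Iterator[(A, Int)]) extends Iterator[A] {
  private var current: Option[A] = None
  private var remaining = 0

  override def hasNext: Boolean = remaining > 0 || runs.hasNext

  override def next(): A = {
    if (remaining == 0) {
      val (value, count) = runs.next()  // start the next run
      current = Some(value)
      remaining = count
    }
    remaining -= 1
    current.get
  }
}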
60 changes: 60 additions & 0 deletions src/main/scala/shark/optimizer/SharkMapJoinProcessor.scala
@@ -0,0 +1,60 @@
/*
* Copyright (C) 2012 The Regents of The University California.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package shark.optimizer

import java.util.{LinkedHashMap => JavaLinkedHashMap}

import org.apache.hadoop.hive.ql.exec.{MapJoinOperator, JoinOperator, Operator}
import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor
import org.apache.hadoop.hive.ql.parse.{ParseContext, QBJoinTree, OpParseContext}
import org.apache.hadoop.hive.ql.plan.OperatorDesc
import org.apache.hadoop.hive.conf.HiveConf

class SharkMapJoinProcessor extends MapJoinProcessor {

/**
* Override generateMapJoinOperator to bypass the step of validating Map Join hints in Hive.
*/
override def generateMapJoinOperator(
pctx: ParseContext,
op: JoinOperator,
joinTree: QBJoinTree,
mapJoinPos: Int): MapJoinOperator = {
val hiveConf: HiveConf = pctx.getConf
val noCheckOuterJoin: Boolean =
HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVEOPTSORTMERGEBUCKETMAPJOIN) &&
HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVEOPTBUCKETMAPJOIN)

val opParseCtxMap: JavaLinkedHashMap[Operator[_ <: OperatorDesc], OpParseContext] =
pctx.getOpParseCtx

// Explicitly set validateMapJoinTree to false to bypass the step of validating
// Map Join hints in Hive.
val validateMapJoinTree = false
val mapJoinOp: MapJoinOperator =
MapJoinProcessor.convertMapJoin(
opParseCtxMap, op, joinTree, mapJoinPos, noCheckOuterJoin, validateMapJoinTree)

// Hive originally uses genSelectPlan to insert a dummy select after the MapJoinOperator.
// We should not need this step.
// create a dummy select to select all columns
// MapJoinProcessor.genSelectPlan(pctx, mapJoinOp)

return mapJoinOp
}
}
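For context, a hedged example of the kind of query this enables once hint validation is skipped. Table names are made up, and it assumes a live SharkContext named sc with the sql API mentioned in the test suite:

// Hypothetical query: the /*+ MAPJOIN(...) */ hint marks the small side of the
// join, as in Hive 0.9; with SharkMapJoinProcessor the hint is no longer rejected.
val results = sc.sql(
  "SELECT /*+ MAPJOIN(dim) */ f.key, dim.name " +
  "FROM fact f JOIN dim ON f.key = dim.key")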
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2012 The Regents of The University California.
* Copyright (C) 2012 The Regents of The University California.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,15 @@
* limitations under the License.
*/

package shark
package shark.optimizer

import java.util.{List => JavaList}

import org.apache.hadoop.hive.ql.optimizer.JoinReorder
import org.apache.hadoop.hive.ql.optimizer.{Optimizer => HiveOptimizer,
SimpleFetchOptimizer, Transform}
import org.apache.hadoop.hive.ql.parse.{ParseContext}
SimpleFetchOptimizer, Transform, MapJoinProcessor => HiveMapJoinProcessor}
import org.apache.hadoop.hive.ql.parse.ParseContext
import shark.LogHelper

class SharkOptimizer extends HiveOptimizer with LogHelper {

@@ -49,6 +50,13 @@ class SharkOptimizer extends HiveOptimizer with LogHelper {
transformation match {
case _: SimpleFetchOptimizer => {}
case _: JoinReorder => {}
case _: HiveMapJoinProcessor => {
// Use SharkMapJoinProcessor to bypass Hive's validation of map-join hints, so that
// hints can still mark the tables to be treated as small tables (as in Hive 0.9).
val sharkMapJoinProcessor = new SharkMapJoinProcessor
pctx = sharkMapJoinProcessor.transform(pctx)
}
case _ => {
pctx = transformation.transform(pctx)
}
3 changes: 2 additions & 1 deletion src/main/scala/shark/parse/SharkSemanticAnalyzer.scala
@@ -38,11 +38,12 @@ import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan._
import org.apache.hadoop.hive.ql.session.SessionState

import shark.{LogHelper, SharkConfVars, SharkOptimizer}
import shark.{LogHelper, SharkConfVars}
import shark.execution.{HiveDesc, Operator, OperatorFactory, ReduceSinkOperator}
import shark.execution.{SharkDDLWork, SparkLoadWork, SparkWork, TerminalOperator}
import shark.memstore2.{CacheType, LazySimpleSerDeWrapper, MemoryMetadataManager}
import shark.memstore2.SharkTblProperties
import shark.optimizer.SharkOptimizer


/**
5 changes: 5 additions & 0 deletions src/test/scala/shark/SQLSuite.scala
@@ -718,6 +718,11 @@ class SQLSuite extends FunSuite {
where year(from_unixtime(k)) between "2013" and "2014" """, Array[String]("0"))
}

test("map pruning with functions in in clause") {
expectSql("""select count(*) from mapsplitfunc_cached
where year(from_unixtime(k)) in ("2013", concat("201", "4")) """, Array[String]("0"))
}

//////////////////////////////////////////////////////////////////////////////
// SharkContext APIs (e.g. sql2rdd, sql)
//////////////////////////////////////////////////////////////////////////////
@@ -78,6 +78,7 @@ class CompressedColumnIteratorSuite extends FunSuite {
}

l.foreach { x =>
assert(iter.hasNext)
iter.next()
assert(t.get(iter.current, oi) === x)
}