Further refactoring, and start of a standalone scheduler backend
mateiz committed Jul 7, 2012
1 parent 4e2fe0b commit 909b325
Showing 13 changed files with 211 additions and 39 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/spark/SparkContext.scala
@@ -42,7 +42,7 @@ import spark.scheduler.DAGScheduler
import spark.scheduler.TaskScheduler
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.ClusterScheduler
- import spark.scheduler.mesos.MesosScheduler
+ import spark.scheduler.mesos.MesosSchedulerBackend
import spark.storage.BlockManagerMaster

class SparkContext(
@@ -90,14 +90,14 @@ class SparkContext(
case _ =>
MesosNativeLibrary.load()
val sched = new ClusterScheduler(this)
- val schedContext = new MesosScheduler(sched, this, master, frameworkName)
+ val schedContext = new MesosSchedulerBackend(sched, this, master, frameworkName)
sched.initialize(schedContext)
sched
/*
if (System.getProperty("spark.mesos.coarse", "false") == "true") {
new CoarseMesosScheduler(this, master, frameworkName)
} else {
- new MesosScheduler(this, master, frameworkName)
+ new MesosSchedulerBackend(this, master, frameworkName)
}
*/
}
4 changes: 2 additions & 2 deletions core/src/main/scala/spark/executor/Executor.scala
@@ -47,11 +47,11 @@ class Executor extends Logging {
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
}

- def launchTask(context: ExecutorContext, taskId: Long, serializedTask: ByteBuffer) {
+ def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
}

- class TaskRunner(context: ExecutorContext, taskId: Long, serializedTask: ByteBuffer)
+ class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {

override def run() {
core/src/main/scala/spark/executor/ExecutorBackend.scala
@@ -4,8 +4,8 @@ import java.nio.ByteBuffer
import spark.TaskState.TaskState

/**
- * Interface used by Executor to send back updates to the cluster scheduler.
+ * A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
- trait ExecutorContext {
+ trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
}
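
For illustration, here is a minimal sketch of what another implementation of this trait could look like: a hypothetical stub that logs updates locally instead of forwarding them to a cluster (the class name and behavior are illustrative, not part of this commit).

import java.nio.ByteBuffer
import spark.TaskState.TaskState

// Hypothetical ExecutorBackend: records status updates locally, showing
// the one-method contract the Executor depends on.
class LoggingExecutorBackend extends ExecutorBackend {
  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    println("task " + taskId + " -> " + state + " (" + data.remaining + " result bytes)")
  }
}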
core/src/main/scala/spark/executor/MesosExecutorBackend.scala
@@ -8,9 +8,9 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging}
import spark.TaskState

- class MesosExecutorRunner(executor: Executor)
+ class MesosExecutorBackend(executor: Executor)
extends MesosExecutor
- with ExecutorContext
+ with ExecutorBackend
with Logging {

var driver: ExecutorDriver = null
@@ -59,11 +59,11 @@ class MesosExecutorRunner(executor: Executor)
/**
* Entry point for Mesos executor.
*/
- object MesosExecutorRunner {
+ object MesosExecutorBackend {
def main(args: Array[String]) {
MesosNativeLibrary.load()
// Create a new Executor and start it running
- val runner = new MesosExecutorRunner(new Executor)
+ val runner = new MesosExecutorBackend(new Executor)
new MesosExecutorDriver(runner).run()
}
}
24 changes: 12 additions & 12 deletions core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -50,23 +50,23 @@ class ClusterScheduler(sc: SparkContext)
// Listener object to pass upcalls into
var listener: TaskSchedulerListener = null

- var schedContext: ClusterSchedulerContext = null
+ var backend: SchedulerBackend = null

val mapOutputTracker = SparkEnv.get.mapOutputTracker

override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}

- def initialize(context: ClusterSchedulerContext) {
- schedContext = context
+ def initialize(context: SchedulerBackend) {
+ backend = context
createJarServer()
}

def newTaskId(): Long = nextTaskId.getAndIncrement()

override def start() {
- schedContext.start()
+ backend.start()

if (System.getProperty("spark.speculation", "false") == "true") {
new Thread("ClusterScheduler speculation check") {
@@ -95,7 +95,7 @@ class ClusterScheduler(sc: SparkContext)
activeTaskSetsQueue += manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
}
- schedContext.reviveOffers()
+ backend.reviveOffers()
}

def taskSetFinished(manager: TaskSetManager) {
@@ -197,11 +197,11 @@ class ClusterScheduler(sc: SparkContext)
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
- schedContext.reviveOffers()
+ backend.reviveOffers()
}
if (taskFailed) {
// Also revive offers if a task had failed for some reason other than host lost
- schedContext.reviveOffers()
+ backend.reviveOffers()
}
}

@@ -227,15 +227,15 @@
}

override def stop() {
- if (schedContext != null) {
- schedContext.stop()
+ if (backend != null) {
+ backend.stop()
}
if (jarServer != null) {
jarServer.stop()
}
}

- override def defaultParallelism() = schedContext.defaultParallelism()
+ override def defaultParallelism() = backend.defaultParallelism()

// Create a server for all the JARs added by the user to SparkContext.
// We first copy the JARs to a temp directory for easier server setup.
@@ -271,7 +271,7 @@
}
}
if (shouldRevive) {
- schedContext.reviveOffers()
+ backend.reviveOffers()
}
}

@@ -288,7 +288,7 @@
}
if (failedHost != None) {
listener.hostLost(failedHost.get)
- schedContext.reviveOffers()
+ backend.reviveOffers()
}
}
}

core/src/main/scala/spark/scheduler/cluster/ClusterSchedulerContext.scala: this file was deleted (its role is taken over by SchedulerBackend below).

15 changes: 15 additions & 0 deletions core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala
@@ -0,0 +1,15 @@
package spark.scheduler.cluster

/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
* ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
trait SchedulerBackend {
def start(): Unit
def stop(): Unit
def reviveOffers(): Unit
def defaultParallelism(): Int

// TODO: Probably want to add a killTask too
}
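
As a hedged sketch of how small a conforming backend can be, here is a hypothetical no-op implementation (for instance, for tests); nothing like it appears in this commit.

package spark.scheduler.cluster

// Hypothetical stub: satisfies the SchedulerBackend contract without any
// real cluster behind it.
class StubSchedulerBackend(cores: Int) extends SchedulerBackend {
  def start() {}                         // no resources to acquire
  def stop() {}                          // nothing to shut down
  def reviveOffers() {}                  // a real backend would re-offer CPUs here
  def defaultParallelism(): Int = cores  // report a fixed core count
}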
core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -1,5 +1,15 @@
package spark.scheduler.cluster

+ import java.nio.channels.Channels
import java.nio.ByteBuffer
+ import java.io.{IOException, EOFException, ObjectOutputStream, ObjectInputStream}
+ import spark.util.SerializableByteBuffer

- class TaskDescription(val taskId: Long, val name: String, val serializedTask: ByteBuffer) {}
+ class TaskDescription(val taskId: Long, val name: String, _serializedTask: ByteBuffer)
+ extends Serializable {
+
+ // Because ByteBuffers are not serializable, we wrap the task in a SerializableByteBuffer
+ private val buffer = new SerializableByteBuffer(_serializedTask)
+
+ def serializedTask: ByteBuffer = buffer.value
+ }
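
The wrapping matters because TaskDescription is about to travel inside Akka messages (see LaunchTask below), which go through Java serialization; a raw ByteBuffer field would fail with NotSerializableException. A round-trip sketch, assuming the wrapper itself is declared Serializable (see the note on SerializableByteBuffer below):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer

// Illustrative check, not part of the commit: a TaskDescription should now
// survive Java serialization with its task bytes intact.
val desc = new TaskDescription(1L, "task 1", ByteBuffer.wrap(Array[Byte](1, 2, 3)))
val out = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(out)
oos.writeObject(desc)
oos.close()
val in = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
val copy = in.readObject().asInstanceOf[TaskDescription]
assert(copy.serializedTask.remaining == 3) // same three bytes, rewound and readable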
core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -15,12 +15,12 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections
import spark.TaskState

- class MesosScheduler(
+ class MesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
frameworkName: String)
- extends ClusterSchedulerContext
+ extends SchedulerBackend
with MScheduler
with Logging {

@@ -58,11 +58,11 @@ class MesosScheduler(

override def start() {
synchronized {
new Thread("MesosScheduler driver") {
new Thread("MesosSchedulerBackend driver") {
setDaemon(true)

override def run() {
- val sched = MesosScheduler.this
+ val sched = MesosSchedulerBackend.this
val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
driver = new MesosSchedulerDriver(sched, fwInfo, master)
try {
core/src/main/scala/spark/scheduler/standalone/StandaloneClusterMessage.scala
@@ -0,0 +1,16 @@
package spark.scheduler.standalone

import spark.TaskState.TaskState
import spark.scheduler.cluster.TaskDescription

sealed trait StandaloneClusterMessage extends Serializable

case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage
case class LaunchTask(slaveId: String, task: TaskDescription) extends StandaloneClusterMessage

case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: Array[Byte])
extends StandaloneClusterMessage

case object ReviveOffers extends StandaloneClusterMessage
case object StopMaster extends StandaloneClusterMessage
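
Because the trait is sealed, actors on each side can match exhaustively over the protocol. This commit only adds the master side; a hypothetical slave-side actor for the same handshake might look like this (names and logic are illustrative).

import akka.actor.{Actor, ActorRef}
import spark.TaskState

// Hypothetical slave: registers its cores with the master, then runs the
// tasks pushed back to it and reports their completion.
class SlaveActor(slaveId: String, host: String, cores: Int, master: ActorRef) extends Actor {
  override def preStart() {
    master ! RegisterSlave(slaveId, host, cores)
  }
  def receive = {
    case LaunchTask(_, task) =>
      // ... deserialize and run task.serializedTask here ...
      master ! StatusUpdate(slaveId, task.taskId, TaskState.FINISHED, Array[Byte]())
  }
}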

core/src/main/scala/spark/scheduler/standalone/StandaloneSchedulerBackend.scala
@@ -0,0 +1,106 @@
package spark.scheduler.standalone

import scala.collection.mutable.{HashMap, HashSet}

import akka.actor.{Props, Actor, ActorRef, ActorSystem}
import akka.util.duration._
import akka.pattern.ask

import spark.{SparkException, Logging, TaskState}
import spark.TaskState.TaskState
import spark.scheduler.cluster.{WorkerOffer, ClusterScheduler, SchedulerBackend}
import akka.dispatch.Await
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger

/**
* A standalone scheduler backend, which waits for standalone executors to connect to it through
* Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
* Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
*/
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend
with Logging {

// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)

class MasterActor extends Actor {
val slaveActor = new HashMap[String, ActorRef]
val slaveHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]

def receive = {
case RegisterSlave(slaveId, host, cores) =>
slaveActor(slaveId) = sender
logInfo("Registered slave: " + sender + " with ID " + slaveId)
slaveHost(slaveId) = host
freeCores(slaveId) = cores
totalCoreCount.addAndGet(cores)
makeOffers()

case StatusUpdate(slaveId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, ByteBuffer.wrap(data))
if (TaskState.isFinished(state)) {
freeCores(slaveId) += 1
makeOffers(slaveId)
}

case LaunchTask(slaveId, task) =>
freeCores(slaveId) -= 1
slaveActor(slaveId) ! LaunchTask(slaveId, task)

case ReviveOffers =>
makeOffers()

case StopMaster =>
sender ! true
context.stop(self)

// TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount)
}

// Make fake resource offers on all slaves
def makeOffers() {
scheduler.resourceOffers(
slaveHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})
}

// Make fake resource offers on just one slave
def makeOffers(slaveId: String) {
scheduler.resourceOffers(
Seq(new WorkerOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
}
}

var masterActor: ActorRef = null
val taskIdsOnSlave = new HashMap[String, HashSet[String]]

def start() {
masterActor = actorSystem.actorOf(
Props(new MasterActor), name = StandaloneSchedulerBackend.ACTOR_NAME)
}

def stop() {
try {
if (masterActor != null) {
val timeout = 5.seconds
val future = masterActor.ask(StopMaster)(timeout)
Await.result(future, timeout)
}
} catch {
case e: Exception =>
throw new SparkException("Error stopping standalone scheduler master actor", e)
}
}

def reviveOffers() {
masterActor ! ReviveOffers
}

def defaultParallelism(): Int = totalCoreCount.get()
}

object StandaloneSchedulerBackend {
val ACTOR_NAME = "StandaloneScheduler"
}
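
Wiring this backend up would mirror the Mesos path in SparkContext above: build a ClusterScheduler, plug the backend in through initialize(), then start the scheduler. A hedged sketch; the ActorSystem setup is an assumption, and SparkContext integration is not part of this commit.

import akka.actor.ActorSystem
import spark.SparkContext
import spark.scheduler.cluster.ClusterScheduler
import spark.scheduler.standalone.StandaloneSchedulerBackend

// Hypothetical wiring, following the scheduler/backend split:
def makeStandaloneScheduler(sc: SparkContext): ClusterScheduler = {
  val actorSystem = ActorSystem("spark")
  val sched = new ClusterScheduler(sc)
  val backend = new StandaloneSchedulerBackend(sched, actorSystem)
  sched.initialize(backend) // sched.start() will call backend.start()
  sched
}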
35 changes: 35 additions & 0 deletions core/src/main/scala/spark/util/SerializableByteBuffer.scala
@@ -0,0 +1,35 @@
package spark.util

import java.nio.ByteBuffer
import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream}
import java.nio.channels.Channels

/**
* A wrapper around java.nio.ByteBuffer to make it serializable through Java serialization.
*/
class SerializableByteBuffer(@transient var buffer: ByteBuffer) {
def value = buffer

private def readObject(in: ObjectInputStream) {
val length = in.readInt()
buffer = ByteBuffer.allocate(length)
var amountRead = 0
val channel = Channels.newChannel(in)
while (amountRead < length) {
val ret = channel.read(buffer)
if (ret == -1) {
throw new EOFException("End of file before fully reading buffer")
}
amountRead += ret
}
buffer.rewind() // Allow us to read it later
}

private def writeObject(out: ObjectOutputStream) {
out.writeInt(buffer.limit())
if (Channels.newChannel(out).write(buffer) != buffer.limit()) {
throw new IOException("Could not fully write buffer to output stream")
}
buffer.rewind() // Allow us to write it again later
}
}
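
A hedged round trip exercising the readObject/writeObject hooks. One caveat: Java serialization only honors these hooks on classes that declare Serializable, which the class above does not yet do, so this sketch assumes that declaration is added.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import spark.util.SerializableByteBuffer

// Illustrative test: writeObject emits length plus contents; readObject
// reallocates, refills, and rewinds the buffer on the way back in.
val wrapped = new SerializableByteBuffer(ByteBuffer.wrap("hello".getBytes))
val out = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(out)
oos.writeObject(wrapped)
oos.close()
val in = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
val restored = in.readObject().asInstanceOf[SerializableByteBuffer]
val bytes = new Array[Byte](restored.value.remaining)
restored.value.get(bytes)
assert(new String(bytes) == "hello") // same payload after the round trip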