diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3659de02c7fc6b6dd4f853f2119d5cec8b5495ee
--- /dev/null
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -0,0 +1,31 @@
+package spark
+
+import java.io.{File, PrintWriter}
+import java.net.URL
+import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.FileUtil
+
+class HttpFileServer extends Logging {
+
+  var fileDir : File = null
+  var httpServer : HttpServer = null
+  var serverUri : String = null
+
+  def initialize() {
+    fileDir = Utils.createTempDir()
+    logInfo("HTTP File server directory is " + fileDir)
+    httpServer = new HttpServer(fileDir)
+    httpServer.start()
+    serverUri = httpServer.uri
+  }
+
+  def addFile(file: File) : String = {
+    Utils.copyFile(file, new File(fileDir, file.getName))
+    return serverUri + "/" + file.getName
+  }
+
+  def stop() {
+    httpServer.stop()
+  }
+
+}
\ No newline at end of file
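Note (not part of the patch): a minimal sketch of the intended driver-side lifecycle of HttpFileServer. The file path is illustrative; the class, methods and URI layout come from the code above.

    // Illustrative only: exercise the new HttpFileServer on the driver.
    val fileServer = new HttpFileServer()
    fileServer.initialize()                                        // creates a temp dir and starts an HttpServer on it
    val url = fileServer.addFile(new java.io.File("/tmp/dep.txt")) // copies dep.txt into the served directory
    // url == fileServer.serverUri + "/dep.txt"; executors can fetch it over HTTP
    fileServer.stop()
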
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5d0f2950d61c3e58b317bf57bee7ea8ecd219671..dee7cd49251c49407cdbe4d148d7228cada32d59 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,11 +2,12 @@ package spark
 
 import java.io._
 import java.util.concurrent.atomic.AtomicInteger
+import java.net.URI
 
 import akka.actor.Actor
 import akka.actor.Actor._
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
@@ -76,7 +77,10 @@ class SparkContext(
     true,
     isLocal)
   SparkEnv.set(env)
-  
+
+  // Used to store a URL for each static file together with the file's local timestamp
+  val files = HashMap[String, Long]()
+
   // Create and start the scheduler
   private var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
@@ -90,13 +94,13 @@ class SparkContext(
 
     master match {
       case "local" =>
-        new LocalScheduler(1, 0)
+        new LocalScheduler(1, 0, this)
 
       case LOCAL_N_REGEX(threads) =>
-        new LocalScheduler(threads.toInt, 0)
+        new LocalScheduler(threads.toInt, 0, this)
 
       case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
-        new LocalScheduler(threads.toInt, maxFailures.toInt)
+        new LocalScheduler(threads.toInt, maxFailures.toInt, this)
 
       case SPARK_REGEX(sparkUrl) =>
         val scheduler = new ClusterScheduler(this)
@@ -131,7 +135,7 @@ class SparkContext(
   taskScheduler.start()
 
   private var dagScheduler = new DAGScheduler(taskScheduler)
-  
+
   // Methods for creating RDDs
 
   def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
@@ -310,7 +314,24 @@ class SparkContext(
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
-  
+
+  // Adds a file dependency to all Tasks executed in the future.
+  def addFile(path: String) : String = {
+    val uri = new URI(path)
+    uri.getScheme match {
+      // A local file
+      case null | "file" =>
+        val file = new File(uri.getPath)
+        val url = env.httpFileServer.addFile(file)
+        files(url) = System.currentTimeMillis
+        logInfo("Added file " + path + " at " + url + " with timestamp " + files(url))
+        return url
+      case _ =>
+        files(path) = System.currentTimeMillis
+        return path
+    }
+  }
+
   // Stop the SparkContext
   def stop() {
     dagScheduler.stop()
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index add8fcec51e65e174b0dc28b5b3e9357fa85b7d2..a95d1bc8ea8cb9a1f7fd2b6f00d0fe88d0bce81b 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,15 +19,17 @@ class SparkEnv (
     val shuffleManager: ShuffleManager,
     val broadcastManager: BroadcastManager,
     val blockManager: BlockManager,
-    val connectionManager: ConnectionManager
+    val connectionManager: ConnectionManager,
+    val httpFileServer: HttpFileServer
   ) {
 
   /** No-parameter constructor for unit tests. */
   def this() = {
-    this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
+    this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null, null)
  }
 
   def stop() {
+    httpFileServer.stop()
     mapOutputTracker.stop()
     cacheTracker.stop()
     shuffleFetcher.stop()
@@ -95,7 +97,11 @@ object SparkEnv {
       System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
     val shuffleFetcher =
       Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
-    
+
+    val httpFileServer = new HttpFileServer()
+    httpFileServer.initialize()
+    System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+
     /*
     if (System.getProperty("spark.stream.distributed", "false") == "true") {
       val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
@@ -126,6 +132,7 @@ object SparkEnv {
       shuffleManager,
       broadcastManager,
       blockManager,
-      connectionManager)
+      connectionManager,
+      httpFileServer)
   }
 }
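Note (not part of the patch): illustrative driver-side usage of the new SparkContext.addFile API. The paths and host names are made up; the described behaviour follows the code above.

    val sc = new SparkContext("local[4]", "demo")
    sc.addFile("/home/user/lookup.txt")              // local path: copied to the driver's HttpFileServer; URL and timestamp recorded in sc.files
    sc.addFile("hdfs://namenode:9000/data/geo.dat")  // already-remote URI: recorded as-is with a timestamp
    // Each scheduler later copies sc.files into every Task's fileSet before the task is shipped.
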
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5eda1011f9f7cb42b093341ca4dd02719fdcfe6c..eb0a4c99bbc51510617842a4d22f45cc077ed561 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,18 +1,19 @@
 package spark
 
 import java.io._
-import java.net.InetAddress
+import java.net.{InetAddress, URL, URI}
+import java.util.{Locale, UUID}
 import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
-import java.util.{Locale, UUID}
 import scala.io.Source
 
 /**
  * Various utility methods used by Spark.
  */
-object Utils {
+object Utils extends Logging {
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
@@ -115,6 +116,47 @@ object Utils {
     val out = new FileOutputStream(dest)
     copyStream(in, out, true)
   }
+
+
+
+  /* Download a file from a given URL to the local filesystem */
+  def downloadFile(url: URL, localPath: String) {
+    val in = url.openStream()
+    val out = new FileOutputStream(localPath)
+    Utils.copyStream(in, out, true)
+  }
+
+  /**
+   * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+   */
+  def fetchFile(url: String, targetDir: File) {
+    val filename = url.split("/").last
+    val targetFile = new File(targetDir, filename)
+    if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
+      // Use the java.net library to fetch it
+      logInfo("Fetching " + url + " to " + targetFile)
+      val in = new URL(url).openStream()
+      val out = new FileOutputStream(targetFile)
+      Utils.copyStream(in, out, true)
+    } else {
+      // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+      val uri = new URI(url)
+      val conf = new Configuration()
+      val fs = FileSystem.get(uri, conf)
+      val in = fs.open(new Path(uri))
+      val out = new FileOutputStream(targetFile)
+      Utils.copyStream(in, out, true)
+    }
+    // Decompress the file if it's a .tar or .tar.gz
+    if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+      logInfo("Untarring " + filename)
+      Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+    } else if (filename.endsWith(".tar")) {
+      logInfo("Untarring " + filename)
+      Utils.execute(Seq("tar", "-xf", filename), targetDir)
+    }
+  }
 
   /**
    * Shuffle the elements of a collection into a random order, returning the
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 1740a42a7eff69bffd4ccbe3cfe881d534a8793f..704336102019cf706c011911a508d33865fd70a6 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -65,38 +65,6 @@ class ExecutorRunner(
     }
   }
 
-  /**
-   * Download a file requested by the executor. Supports fetching the file in a variety of ways,
-   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
-   */
-  def fetchFile(url: String, targetDir: File) {
-    val filename = url.split("/").last
-    val targetFile = new File(targetDir, filename)
-    if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
-      // Use the java.net library to fetch it
-      logInfo("Fetching " + url + " to " + targetFile)
-      val in = new URL(url).openStream()
-      val out = new FileOutputStream(targetFile)
-      Utils.copyStream(in, out, true)
-    } else {
-      // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
-      val uri = new URI(url)
-      val conf = new Configuration()
-      val fs = FileSystem.get(uri, conf)
-      val in = fs.open(new Path(uri))
-      val out = new FileOutputStream(targetFile)
-      Utils.copyStream(in, out, true)
-    }
-    // Decompress the file if it's a .tar or .tar.gz
-    if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
-      logInfo("Untarring " + filename)
-      Utils.execute(Seq("tar", "-xzf", filename), targetDir)
-    } else if (filename.endsWith(".tar")) {
-      logInfo("Untarring " + filename)
-      Utils.execute(Seq("tar", "-xf", filename), targetDir)
-    }
-  }
-
   /** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
   def substituteVariables(argument: String): String = argument match {
     case "{{SLAVEID}}" => workerId
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index dba209ac2726febf78e774ef1cc8b7fd6e7903a3..ce3aa4972627e03a5ca31a2f477562609fbeae2d 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -4,7 +4,9 @@ import java.io.{File, FileOutputStream}
 import java.net.{URL, URLClassLoader}
 import java.util.concurrent._
 
-import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.FileUtil
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 
 import spark.broadcast._
 import spark.scheduler._
@@ -18,6 +20,8 @@ class Executor extends Logging {
   var classLoader: ClassLoader = null
   var threadPool: ExecutorService = null
   var env: SparkEnv = null
+
+  val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
 
   val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
@@ -63,6 +67,7 @@ class Executor extends Logging {
         Thread.currentThread.setContextClassLoader(classLoader)
         Accumulators.clear()
         val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
+        task.downloadFileDependencies(fileSet)
         logInfo("Its generation is " + task.generation)
         env.mapOutputTracker.updateGeneration(task.generation)
         val value = task.run(taskId.toInt)
@@ -108,7 +113,7 @@ class Executor extends Logging {
       for (uri <- uris.split(",").filter(_.size > 0)) {
         val url = new URL(uri)
         val filename = url.getPath.split("/").last
-        downloadFile(url, filename)
+        Utils.downloadFile(url, filename)
         localFiles += filename
       }
       if (localFiles.size > 0) {
@@ -136,10 +141,4 @@ class Executor extends Logging {
     return loader
   }
 
-  // Download a file from a given URL to the local filesystem
-  private def downloadFile(url: URL, localPath: String) {
-    val in = url.openStream()
-    val out = new FileOutputStream(localPath)
-    Utils.copyStream(in, out, true)
-  }
 }
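Note (not part of the patch): with fetchFile moved into Utils, the standalone worker and the executor share one download path. A sketch of how an executor is expected to call it; the URL and port are illustrative.

    // Illustrative only: fetch a served file into the current working directory, as the executor does.
    Utils.fetchFile("http://driver-host:48609/lookup.txt", new java.io.File(System.getProperty("user.dir")))
    // http/https/ftp URLs are read with java.net; other schemes (file, hdfs, s3) go through Hadoop's FileSystem,
    // and .tar/.tgz/.tar.gz archives are unpacked into the target directory afterwards.
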
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index f84d8d9c4f5c904ba959e4b6a5cbcd25fa531a1b..faf042ad02adb0580d082d6aa44d617fc72ed599 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,5 +1,10 @@
 package spark.scheduler
 
+import scala.collection.mutable.HashMap
+import spark.HttpFileServer
+import spark.Utils
+import java.io.File
+
 /**
  * A task to execute on a worker node.
  */
@@ -8,4 +13,21 @@ abstract class Task[T](val stageId: Int) extends Serializable {
   def preferredLocations: Seq[String] = Nil
 
   var generation: Long = -1   // Map output tracker generation. Will be set by TaskScheduler.
+
+  // Stores file dependencies for this task.
+  var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
+
+  // Downloads any file dependencies that the executor does not already have up to date
+  def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
+    // Find files that are either missing locally or older than the version the driver registered
+    val missingFiles = fileSet.filter { case (k, v) =>
+      !currentFileSet.isDefinedAt(k) || currentFileSet(k) < v
+    }
+    // Fetch each missing file and record the timestamp we now hold
+    missingFiles.foreach { case (k, v) =>
+      Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+      currentFileSet(k) = v
+    }
+  }
+
 }
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 5b59479682f2b48cef2d93da39f75021b1552ca9..a9ab82040c4144b58d33f37df6a50294339f2bf1 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -88,6 +88,7 @@ class ClusterScheduler(sc: SparkContext)
 
   def submitTasks(taskSet: TaskSet) {
     val tasks = taskSet.tasks
+    tasks.foreach { task => task.fileSet ++= sc.files }
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = new TaskSetManager(this, taskSet)
@@ -235,30 +236,24 @@ class ClusterScheduler(sc: SparkContext)
   }
 
   override def defaultParallelism() = backend.defaultParallelism()
-  
-  // Create a server for all the JARs added by the user to SparkContext.
-  // We first copy the JARs to a temp directory for easier server setup.
+
+  // Copies all the JARs added by the user to the SparkContext
+  // into the fileserver directory.
   private def createJarServer() {
-    val jarDir = Utils.createTempDir()
-    logInfo("Temp directory for JARs: " + jarDir)
+    val fileServerDir = SparkEnv.get.httpFileServer.fileDir
+    val fileServerUri = SparkEnv.get.httpFileServer.serverUri
     val filenames = ArrayBuffer[String]()
-    // Copy each JAR to a unique filename in the jarDir
     for ((path, index) <- sc.jars.zipWithIndex) {
       val file = new File(path)
       if (file.exists) {
         val filename = index + "_" + file.getName
-        Utils.copyFile(file, new File(jarDir, filename))
+        Utils.copyFile(file, new File(fileServerDir, filename))
         filenames += filename
       }
     }
-    // Create the server
-    jarServer = new HttpServer(jarDir)
-    jarServer.start()
-    // Build up the jar URI list
-    val serverUri = jarServer.uri
-    jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+    jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
     System.setProperty("spark.jar.uris", jarUris)
-    logInfo("JAR server started at " + serverUri)
+    logInfo("JARs available at " + jarUris)
   }
 
   // Check for speculatable tasks in all our active jobs.
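Note (not part of the patch): how the timestamps are meant to interact, with invented values. The driver records when each file was added; an executor fetches a URL only when it has no local copy or its copy predates the driver's entry.

    // Illustrative only: URL and timestamps are made up.
    val driverFiles = scala.collection.mutable.HashMap("http://driver:48609/lookup.txt" -> 1000L)  // sc.files on the driver
    val executorFiles = new scala.collection.mutable.HashMap[String, Long]()                       // fileSet on an executor
    // submitTasks does task.fileSet ++= driverFiles for every task; on the executor,
    // downloadFileDependencies(executorFiles) fetches the URL once, records 1000L,
    // and skips it for later tasks until the driver re-adds the file with a newer timestamp.
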
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb47988f0cdfb000936c17a13985dae317147e4d..4bd9d13637d2de92b6bfd6101717a11955f9d733 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -2,6 +2,7 @@ package spark.scheduler.local
 
 import java.util.concurrent.Executors
 import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable.HashMap
 
 import spark._
 import spark.scheduler._
@@ -11,12 +12,13 @@ import spark.scheduler._
  * the scheduler also allows each task to fail up to maxFailures times, which is useful for
  * testing fault recovery.
  */
-class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging {
   var attemptId = new AtomicInteger(0)
   var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
   val env = SparkEnv.get
   var listener: TaskSchedulerListener = null
-  
+  val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+
   // TODO: Need to take into account stage priority in scheduling
 
   override def start() {}
@@ -30,6 +32,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
     val failCount = new Array[Int](tasks.size)
 
     def submitTask(task: Task[_], idInJob: Int) {
+      task.fileSet ++= sc.files
      val myAttemptId = attemptId.getAndIncrement()
      threadPool.submit(new Runnable {
        def run() {
@@ -42,6 +45,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
          logInfo("Running task " + idInJob)
          // Set the Spark execution environment for the worker thread
          SparkEnv.set(env)
+          task.downloadFileDependencies(fileSet)
          try {
            // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
            // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -81,6 +85,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
     }
   }
 
+
   override def stop() {
     threadPool.shutdownNow()
   }
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..883149fecac72f332f945e46279f14c1b604e3f1
--- /dev/null
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -0,0 +1,43 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import java.io.{File, PrintWriter}
+
+class FileServerSuite extends FunSuite with BeforeAndAfter {
+
+  var sc: SparkContext = _
+
+  before {
+    // Create a sample text file in the system temp directory
+    val pw = new PrintWriter(new File(System.getProperty("java.io.tmpdir"), "FileServerSuite.txt"))
+    pw.println("100")
+    pw.close()
+  }
+
+  after {
+    if (sc != null) {
+      sc.stop()
+      sc = null
+    }
+    // Clean up the copy downloaded into the working directory
+    val tmpFile = new File("FileServerSuite.txt")
+    if (tmpFile.exists) {
+      tmpFile.delete()
+    }
+  }
+
+  test("Distributing files") {
+    sc = new SparkContext("local[4]", "test")
+    sc.addFile(new File(System.getProperty("java.io.tmpdir"), "FileServerSuite.txt").getAbsolutePath)
+    val testRdd = sc.parallelize(List(1,2,3,4))
+    val result = testRdd.map { x =>
+      val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
+      val fileVal = in.readLine().toInt
+      in.close()
+      fileVal
+    }.reduce(_ + _)
+    assert(result == 400)
+  }
+
+}
\ No newline at end of file
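Note (not part of the patch): the same round trip as the test above, written as it would appear in a user program. The file name and contents are illustrative; the file is fetched into the working directory of each executor before the closure runs.

    val sc = new SparkContext("local[4]", "demo")
    sc.addFile("/tmp/lookup.txt")                      // a text file whose single line is "100"
    val sum = sc.parallelize(List(1, 2, 3, 4)).map { _ =>
      scala.io.Source.fromFile("lookup.txt").getLines().next().toInt
    }.reduce(_ + _)
    // sum == 400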