Commit b864c36a authored by Denny

Dynamically adding jar files and caching fileSets.

parent f275fb07
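
For context, here is a minimal driver-side sketch of the API this commit introduces (SparkContext.addFile and the new SparkContext.addJar). The paths, app name, and object name are hypothetical, and the sketch assumes a build of this revision on the classpath; it is illustrative, not part of the change itself.

import spark.SparkContext

object AddJarExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[4]", "AddJarExample")
    // Register a data file and a JAR; tasks submitted after this point
    // pick both up through their fileSet/jarSet dependencies.
    sc.addFile("/tmp/lookup-table.txt")  // hypothetical path
    sc.addJar("/tmp/extra-lib.jar")      // hypothetical path
    val sum = sc.parallelize(1 to 4).map(_ * 2).reduce(_ + _)
    println(sum)
    // stop() also cleans up the locally linked files and jars (clearFiles/clearJars).
    sc.stop()
  }
}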
@@ -7,25 +7,39 @@ import org.apache.hadoop.fs.FileUtil
class HttpFileServer extends Logging {
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
fileDir = Utils.createTempDir()
logInfo("HTTP File server directory is " + fileDir)
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(fileDir)
httpServer.start()
serverUri = httpServer.uri
}
def stop() {
httpServer.stop()
}
def addFile(file: File) : String = {
Utils.copyFile(file, new File(fileDir, file.getName))
return serverUri + "/" + file.getName
return addFileToDir(file, fileDir)
}
def stop() {
httpServer.stop()
def addJar(file: File) : String = {
return addFileToDir(file, jarDir)
}
def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
}
\ No newline at end of file
@@ -2,14 +2,14 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
import java.net.{URI, URLClassLoader}
import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -78,8 +78,12 @@ class SparkContext(
isLocal)
SparkEnv.set(env)
// Used to store a URL for each static file together with the file's local timestamp
val files = HashMap[String, Long]()
// Used to store a URL for each static file/jar together with the file's local timestamp
val addedFiles = HashMap[String, Long]()
val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -316,20 +320,40 @@ class SparkContext(
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Adds a file dependency to all Tasks executed in the future.
def addFile(path: String) : String = {
def addFile(path: String) {
val uri = new URI(path)
uri.getScheme match {
// A local file
case null | "file" =>
val file = new File(uri.getPath)
val url = env.httpFileServer.addFile(file)
files(url) = System.currentTimeMillis
logInfo("Added file " + path + " at " + url + " with timestamp " + files(url))
return url
case _ =>
files(path) = System.currentTimeMillis
return path
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
}
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File(""))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
def clearFiles() {
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
// Adds a jar dependency to all Tasks executed in the future.
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
def clearJars() {
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
// Stop the SparkContext
@@ -339,6 +363,9 @@ class SparkContext(
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
// Clean up locally linked files
clearFiles()
clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext")
@@ -5,7 +5,7 @@ import java.net.{InetAddress, URL, URI}
import java.util.{Locale, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import scala.io.Source
@@ -133,20 +133,27 @@ object Utils extends Logging {
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
// Use the java.net library to fetch it
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
} else {
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
case "file" | null =>
// Remove the file if it already exists
targetFile.delete()
// Symlink the file locally
logInfo("Symlinking " + url + " to " + targetFile)
FileUtil.symLink(url, targetFile.toString)
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
package spark.executor
import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader}
import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
import org.apache.hadoop.fs.FileUtil
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@@ -17,11 +17,13 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
class Executor extends Logging {
var classLoader: ClassLoader = null
var urlClassLoader : URLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -40,13 +42,14 @@ class Executor extends Logging {
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(classLoader)
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
// Create our ClassLoader and set it on this thread
urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
}
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
@@ -58,16 +61,16 @@ class Executor extends Logging {
override def run() {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
task.downloadFileDependencies(fileSet)
val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader)
task.downloadDependencies(fileSet, jarSet)
updateClassLoader()
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -101,25 +104,16 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
private def createClassLoader(): ClassLoader = {
var loader = this.getClass.getClassLoader
// If any JAR URIs are given through spark.jar.uris, fetch them to the
// current directory and put them all on the classpath. We assume that
// each URL has a unique file name so that no local filenames will clash
// in this process. This is guaranteed by ClusterScheduler.
val uris = System.getProperty("spark.jar.uris", "")
val localFiles = ArrayBuffer[String]()
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
Utils.downloadFile(url, filename)
localFiles += filename
}
if (localFiles.size > 0) {
val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
loader = new URLClassLoader(urls, loader)
}
private def createClassLoader(): URLClassLoader = {
var loader = this.getClass().getClassLoader()
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = jarSet.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@@ -138,7 +132,23 @@ class Executor extends Logging {
}
}
return loader
return new URLClassLoader(Array(), loader)
}
def updateClassLoader() {
val currentURLs = urlClassLoader.getURLs()
val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL }
// For encapsulation reasons the addURL method in URLClassLoader is protected.
// We'll save ourselves the hassle of subclassing here and use reflection instead.
val m = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL])
m.setAccessible(true)
urlSet.filterNot(currentURLs.contains(_)).foreach { url =>
logInfo("Adding " + url + " to the class loader.")
m.invoke(urlClassLoader, url)
}
}
}
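
The updateClassLoader method above appends new jar URLs to a live URLClassLoader by calling its protected addURL method through reflection. Below is a standalone sketch of that trick, separate from Spark and with a hypothetical jar path, to show the mechanism in isolation.

import java.io.File
import java.net.{URL, URLClassLoader}

object AddUrlViaReflection {
  def main(args: Array[String]) {
    // Start with an empty URLClassLoader so URLs can be appended later.
    val loader = new URLClassLoader(Array[URL](), getClass.getClassLoader)
    // addURL is protected, so make it accessible and invoke it reflectively.
    val addURL = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL])
    addURL.setAccessible(true)
    val jarUrl = new File("/tmp/extra-lib.jar").toURI.toURL  // hypothetical jar
    addURL.invoke(loader, jarUrl)
    println("Class loader now sees: " + loader.getURLs.mkString(", "))
  }
}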
package spark.scheduler
import java.io._
import java.util.HashMap
import java.util.{HashMap => JHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -20,7 +20,8 @@ object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new HashMap[Int, Array[Byte]]
val serializedInfoCache = new JHashMap[Int, Array[Byte]]
val fileSetCache = new JHashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
@@ -40,6 +41,23 @@ object ShuffleMapTask {
}
}
// Since both the JarSet and FileSet have the same format, this is used for both.
def serializeFileSet(set : HashMap[String, Long]) : Array[Byte] = {
val old = fileSetCache.get(set.hashCode)
if (old != null) {
return old
} else {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(set.toArray)
objOut.close()
val bytes = out.toByteArray
fileSetCache.put(set.hashCode, bytes)
return bytes
}
}
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
@@ -54,9 +72,18 @@ object ShuffleMapTask {
}
}
// Since both the JarSet and FileSet have the same format, this is used for both.
def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
return (HashMap(set.toSeq: _*))
}
def clearCache() {
synchronized {
serializedInfoCache.clear()
fileSetCache.clear()
}
}
}
@@ -84,6 +111,14 @@ class ShuffleMapTask(
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet)
out.writeInt(fileSetBytes.length)
out.write(fileSetBytes)
val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet)
out.writeInt(jarSetBytes.length)
out.write(jarSetBytes)
out.writeInt(partition)
out.writeLong(generation)
out.writeObject(split)
@@ -97,6 +132,17 @@ class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
val fileSetNumBytes = in.readInt()
val fileSetBytes = new Array[Byte](fileSetNumBytes)
in.readFully(fileSetBytes)
fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes)
val jarSetNumBytes = in.readInt()
val jarSetBytes = new Array[Byte](jarSetNumBytes)
in.readFully(jarSetBytes)
jarSet = ShuffleMapTask.deserializeFileSet(jarSetBytes)
partition = in.readInt()
generation = in.readLong()
split = in.readObject().asInstanceOf[Split]
@@ -110,7 +156,7 @@ class ShuffleMapTask(
val bucketIterators =
if (aggregator.mapSideCombine) {
// Apply combiners (map-side aggregation) to the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(k)
package spark.scheduler
import scala.collection.mutable.HashMap
import scala.collection.mutable.{HashMap}
import spark.HttpFileServer
import spark.Utils
import java.io.File
@@ -14,20 +14,29 @@ abstract class Task[T](val stageId: Int) extends Serializable {
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
// Stores file dependencies for this task.
// Stores jar and file dependencies for this task.
var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
var jarSet : HashMap[String, Long] = new HashMap[String, Long]()
// Downloads all file dependencies from the Master file server
def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
// Find files that either don't exist or have an earlier timestamp
val missingFiles = fileSet.filter { case(k,v) =>
!currentFileSet.isDefinedAt(k) || currentFileSet(k) <= v
}
// Fetch each missing file
missingFiles.foreach { case (k,v) =>
def downloadDependencies(currentFileSet : HashMap[String, Long],
currentJarSet : HashMap[String, Long]) {
// Fetch missing file dependencies
fileSet.filter { case(k,v) =>
!currentFileSet.contains(k) || currentFileSet(k) <= v
}.foreach { case (k,v) =>
Utils.fetchFile(k, new File(System.getProperty("user.dir")))
currentFileSet(k) = v
}
// Fetch missing jar dependencies
jarSet.filter { case(k,v) =>
!currentJarSet.contains(k) || currentJarSet(k) <= v
}.foreach { case (k,v) =>
Utils.fetchFile(k, new File(System.getProperty("user.dir")))
currentJarSet(k) = v
}
}
}
@@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext)
def initialize(context: SchedulerBackend) {
backend = context
createJarServer()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
@@ -88,7 +87,10 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
tasks.foreach { task => task.fileSet ++= sc.files }
tasks.foreach { task =>
task.fileSet ++= sc.addedFiles
task.jarSet ++= sc.addedJars
}
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
@@ -237,25 +239,6 @@ class ClusterScheduler(sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
// Copies all the JARs added by the user to the SparkContext
// to the fileserver directory.
private def createJarServer() {
val fileServerDir = SparkEnv.get.httpFileServer.fileDir
val fileServerUri = SparkEnv.get.httpFileServer.serverUri
val filenames = ArrayBuffer[String]()
for ((path, index) <- sc.jars.zipWithIndex) {
val file = new File(path)
if (file.exists) {
val filename = index + "_" + file.getName
Utils.copyFile(file, new File(fileServerDir, filename))
filenames += filename
}
}
jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
System.setProperty("spark.jar.uris", jarUris)
logInfo("JARs available at " + jarUris)
}
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
package spark.scheduler.local
import java.io.File
import java.net.URLClassLoader
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
@@ -18,10 +20,11 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
// TODO: Need to take into account stage priority in scheduling
override def start() {}
override def start() { }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
@@ -32,7 +35,8 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
task.fileSet ++= sc.files
task.fileSet ++= sc.addedFiles
task.jarSet ++= sc.addedJars
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@@ -45,7 +49,9 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
task.downloadFileDependencies(fileSet)
task.downloadDependencies(fileSet, jarSet)
// Create a new class loader for the downloaded JARs
Thread.currentThread.setContextClassLoader(createClassLoader())
try {
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -90,5 +96,14 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
threadPool.shutdownNow()
}
private def createClassLoader() : ClassLoader = {
val currentLoader = Thread.currentThread.getContextClassLoader()
val urls = jarSet.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
logInfo("Creating ClassLoader with jars: " + urls.mkString)
return new URLClassLoader(urls, currentLoader)
}
override def defaultParallelism() = threads
}
File added
package spark
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter}
import SparkContext._
class FileServerSuite extends FunSuite with BeforeAndAfter {
var sc: SparkContext = _
var tmpFile : File = _
var testJarFile : File = _
before {
// Create a sample text file
val pw = new PrintWriter(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
val tmpdir = new File(Files.createTempDir(), "test")
tmpdir.mkdir()
tmpFile = new File(tmpdir, "FileServerSuite.txt")
val pw = new PrintWriter(tmpFile)
pw.println("100")
pw.close()
}
@@ -21,7 +28,6 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc = null
}
// Clean up downloaded file
val tmpFile = new File("FileServerSuite.txt")
if (tmpFile.exists) {
tmpFile.delete()
}
@@ -29,15 +35,30 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
test("Distributing files") {
sc = new SparkContext("local[4]", "test")
sc.addFile(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
val testRdd = sc.parallelize(List(1,2,3,4))
val result = testRdd.map { x =>
val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile))
val fileVal = in.readLine().toInt
in.close()
fileVal
}.reduce(_ + _)
assert(result == 400)
_ * fileVal + _ * fileVal
}.collect
println(result)
assert(result.toSet === Set((1,200), (2,300), (3,500)))
}
test ("Dynamically adding JARS") {
sc = new SparkContext("local[4]", "test")
val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile()
sc.addJar(sampleJarFile)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
val result = sc.parallelize(testData).reduceByKey { (x,y) =>
val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int])
val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
a + b
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
}
\ No newline at end of file