diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index 7b954a477570f5c42e0b8ee4529b7234f751667e..9c37fadb78d2f260839061deee0f0dd89eb748f6 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -38,7 +38,6 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } test("halting by voting") { diff --git a/bin/pyspark b/bin/pyspark index d6810f4686bf56ae9f9caa1c620c2a2633135f21..ed6f8da73035a5176e1f184d42e1a113d725e92c 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -59,12 +59,7 @@ if [ -n "$IPYTHON_OPTS" ]; then fi if [[ "$IPYTHON" = "1" ]] ; then - # IPython <1.0.0 doesn't honor PYTHONSTARTUP, while 1.0.0+ does. - # Hence we clear PYTHONSTARTUP and use the -c "%run $IPYTHONSTARTUP" command which works on all versions - # We also force interactive mode with "-i" - IPYTHONSTARTUP=$PYTHONSTARTUP - PYTHONSTARTUP= - exec ipython "$IPYTHON_OPTS" -i -c "%run $IPYTHONSTARTUP" + exec ipython $IPYTHON_OPTS else exec "$PYSPARK_PYTHON" "$@" fi diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index d72dbadc3904f327effddf99594045067be2f529..f7f853559468ae105b3b66cc8d856eaaaa25b044 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,8 +1,11 @@ # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose +# Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/pom.xml b/core/pom.xml index aac0a9d11e12dafcb8bb053dc1f97a40be0ab4e2..9e5a450d57a477ec3b7020d518b5f461b72c2bdb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -98,6 +98,11 @@ <groupId>${akka.group}</groupId> <artifactId>akka-slf4j_${scala.binary.version}</artifactId> </dependency> + <dependency> + <groupId>${akka.group}</groupId> + <artifactId>akka-testkit_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> <dependency> <groupId>org.scala-lang</groupId> <artifactId>scala-library</artifactId> @@ -165,6 +170,11 @@ <artifactId>scalatest_${scala.binary.version}</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> <dependency> <groupId>org.scalacheck</groupId> <artifactId>scalacheck_${scala.binary.version}</artifactId> diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index d72dbadc3904f327effddf99594045067be2f529..f7f853559468ae105b3b66cc8d856eaaaa25b044 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,8 +1,11 @@ # Set everything to be logged to the 
console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose +# Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 5f73d234aa0506447f23d59ca0e91b68c3ab5fb7..e89ac28b8eedfa111afc0f4219b8df48a0251e14 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -218,7 +218,7 @@ private object Accumulators { def newId: Long = synchronized { lastId += 1 - return lastId + lastId } def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 1a2ec55876c35089d3b448a078d2212cafda08f4..8b30cd4bfe69de89fe8b838e838728d743cbf2d0 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.util.AppendOnlyMap +import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} /** * A set of functions used to aggregate data. @@ -31,30 +31,51 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { + private val sparkConf = SparkEnv.get.conf + private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) + def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (iter.hasNext) { + kv = iter.next() + combiners.changeValue(kv._1, update) + } + combiners.iterator + } else { + val combiners = + new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + while (iter.hasNext) { + val (k, v) = iter.next() + combiners.insert(k, v) + } + combiners.iterator } - combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kc: (K, C) = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kc: Product2[K, C] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + } + while (iter.hasNext) { + kc = iter.next() + combiners.changeValue(kc._1, update) + } + combiners.iterator + } else { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, 
mergeCombiners) + while (iter.hasNext) { + val (k, c) = iter.next() + combiners.insert(k, c) + } + combiners.iterator } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator } } - diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 519ecde50a163507c97340c6de3b308cd6277324..8e5dd8a85020d4d3eb9bdf259c652638df6834a9 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -38,7 +38,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values - return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) + new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => // Mark the split as loading (unless someone else marks it first) @@ -74,7 +74,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val elements = new ArrayBuffer[Any] elements ++= computedValues blockManager.put(key, elements, storageLevel, tellMaster = true) - return elements.iterator.asInstanceOf[Iterator[T]] + elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index ad1ee20045f46e5591067364692a4f3e0de60b27..a885898ad48d45e6e8d49ac099f9938cd71630c9 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -47,17 +47,17 @@ private[spark] class HttpFileServer extends Logging { def addFile(file: File) : String = { addFileToDir(file, fileDir) - return serverUri + "/files/" + file.getName + serverUri + "/files/" + file.getName } def addJar(file: File) : String = { addFileToDir(file, jarDir) - return serverUri + "/jars/" + file.getName + serverUri + "/jars/" + file.getName } def addFileToDir(file: File, dir: File) : String = { Files.copy(file, new File(dir, file.getName)) - return dir + "/" + file.getName + dir + "/" + file.getName } } diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 4a34989e50e57255a6520a2a8e0a852ddda6e2e2..9063cae87e14099e388b97bdffc6ee267c0f7832 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -41,7 +41,7 @@ trait Logging { } log_ = LoggerFactory.getLogger(className) } - return log_ + log_ } // Log methods that take only a String diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 77b8ca1cce80b58b3d248fc81dc5d4e157b7cf09..30d182b008930af670b7e9120e78ed69189febd3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -32,15 +32,16 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} private[spark] sealed trait MapOutputTrackerMessage -private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) +private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case 
object StopMapOutputTracker extends MapOutputTrackerMessage private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) extends Actor with Logging { def receive = { - case GetMapOutputStatuses(shuffleId: Int, requester: String) => - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) + case GetMapOutputStatuses(shuffleId: Int) => + val hostPort = sender.path.address.hostPort + logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) sender ! tracker.getSerializedMapOutputStatuses(shuffleId) case StopMapOutputTracker => @@ -119,11 +120,10 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort(conf) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] + askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) @@ -139,7 +139,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } - else{ + else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 31b0773bfe06c6ca4cebc75e4a13430857940dd9..fc0a7498820b5039c54615044792207bf4ca276c 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -53,15 +53,16 @@ object Partitioner { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { - return new HashPartitioner(rdd.context.defaultParallelism) + new HashPartitioner(rdd.context.defaultParallelism) } else { - return new HashPartitioner(bySize.head.partitions.size) + new HashPartitioner(bySize.head.partitions.size) } } } /** - * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using + * Java's `Object.hashCode`. * * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will @@ -84,8 +85,8 @@ class HashPartitioner(partitions: Int) extends Partitioner { } /** - * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. - * Determines the ranges by sampling the RDD passed in. + * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly + * equal ranges. The ranges are determined by sampling the content of the RDD passed in. 
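A minimal sketch of the two partitioners documented above, assuming a running SparkContext `sc` and the 0.9-era API; the data is illustrative:

```scala
import org.apache.spark.{HashPartitioner, RangePartitioner}
import org.apache.spark.SparkContext._   // rddToPairRDDFunctions, for partitionBy

val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3)))

// Hash partitioning: a key's hashCode modulo the partition count picks its partition.
val hashed = pairs.partitionBy(new HashPartitioner(4))

// Range partitioning: bucket boundaries are estimated by sampling the RDD's keys,
// which is the sampling the RangePartitioner scaladoc above refers to.
val ranged = pairs.partitionBy(new RangePartitioner(4, pairs))

hashed.partitioner   // Some(<the HashPartitioner>)
```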
*/ class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e47f4e442927418cefd5cbc316d1ede06f93933..55ac76bf63909b4b01f8914c4fb065bad0b8d4cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -31,9 +31,9 @@ import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, -FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} + FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, -TextInputFormat} + TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary @@ -49,7 +49,7 @@ import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType, -ClosureCleaner} + ClosureCleaner} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -116,7 +116,7 @@ class SparkContext( throw new SparkException("An application must be set in your configuration") } - if (conf.get("spark.logConf", "false").toBoolean) { + if (conf.getBoolean("spark.logConf", false)) { logInfo("Spark configuration:\n" + conf.toDebugString) } @@ -244,6 +244,10 @@ class SparkContext( localProperties.set(new Properties()) } + /** + * Set a local property that affects jobs submitted from this thread, such as the + * Spark fair scheduler pool. + */ def setLocalProperty(key: String, value: String) { if (localProperties.get() == null) { localProperties.set(new Properties()) @@ -255,6 +259,10 @@ class SparkContext( } } + /** + * Get a local property set in this thread, or null if it is missing. See + * [[org.apache.spark.SparkContext.setLocalProperty]]. + */ def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) @@ -265,7 +273,7 @@ class SparkContext( } /** - * Assigns a group id to all the jobs started by this thread until the group id is set to a + * Assigns a group ID to all the jobs started by this thread until the group ID is set to a * different value or cleared. * * Often, a unit of execution in an application consists of multiple Spark actions or jobs. @@ -288,7 +296,7 @@ class SparkContext( setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) } - /** Clear the job group id and its description. */ + /** Clear the current thread's job group ID and its description. */ def clearJobGroup() { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) @@ -337,29 +345,42 @@ class SparkContext( } /** - * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any - * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, - * etc). 
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other + * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), + * using the older MapReduce API (`org.apache.hadoop.mapred`). + * + * @param conf JobConf for setting up the dataset + * @param inputFormatClass Class of the [[InputFormat]] + * @param keyClass Class of the keys + * @param valueClass Class of the values + * @param minSplits Minimum number of Hadoop Splits to generate. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. */ - def hadoopRDD[K, V]( + def hadoopRDD[K: ClassTag, V: ClassTag]( conf: JobConf, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int = defaultMinSplits + minSplits: Int = defaultMinSplits, + cloneRecords: Boolean = true ): RDD[(K, V)] = { // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) - new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) + new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords) } /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ - def hadoopFile[K, V]( + def hadoopFile[K: ClassTag, V: ClassTag]( path: String, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int = defaultMinSplits + minSplits: Int = defaultMinSplits, + cloneRecords: Boolean = true ): RDD[(K, V)] = { // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) @@ -371,7 +392,8 @@ class SparkContext( inputFormatClass, keyClass, valueClass, - minSplits) + minSplits, + cloneRecords) } /** @@ -382,14 +404,15 @@ class SparkContext( * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) * }}} */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) - : RDD[(K, V)] = { + def hadoopFile[K, V, F <: InputFormat[K, V]] + (path: String, minSplits: Int, cloneRecords: Boolean = true) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { hadoopFile(path, - fm.runtimeClass.asInstanceOf[Class[F]], - km.runtimeClass.asInstanceOf[Class[K]], - vm.runtimeClass.asInstanceOf[Class[V]], - minSplits) + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + minSplits, + cloneRecords) } /** @@ -400,61 +423,67 @@ class SparkContext( * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} */ - def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) + def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, cloneRecords: Boolean = true) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = - hadoopFile[K, V, F](path, defaultMinSplits) + hadoopFile[K, V, F](path, defaultMinSplits, cloneRecords) /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. 
*/ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] + (path: String, cloneRecords: Boolean = true) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { newAPIHadoopFile( - path, - fm.runtimeClass.asInstanceOf[Class[F]], - km.runtimeClass.asInstanceOf[Class[K]], - vm.runtimeClass.asInstanceOf[Class[V]]) + path, + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], + cloneRecords = cloneRecords) } /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ - def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( + def newAPIHadoopFile[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]]( path: String, fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration, + cloneRecords: Boolean = true): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration - new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) + new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf, cloneRecords) } /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ - def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( + def newAPIHadoopRDD[K: ClassTag, V: ClassTag, F <: NewInputFormat[K, V]]( conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], - vClass: Class[V]): RDD[(K, V)] = { - new NewHadoopRDD(this, fClass, kClass, vClass, conf) + vClass: Class[V], + cloneRecords: Boolean = true): RDD[(K, V)] = { + new NewHadoopRDD(this, fClass, kClass, vClass, conf, cloneRecords) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, + def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V], - minSplits: Int + minSplits: Int, + cloneRecords: Boolean = true ): RDD[(K, V)] = { val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] - hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) + hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits, cloneRecords) } /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ - def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = - sequenceFile(path, keyClass, valueClass, defaultMinSplits) + def sequenceFile[K: ClassTag, V: ClassTag](path: String, keyClass: Class[K], valueClass: Class[V], + cloneRecords: Boolean = true): RDD[(K, V)] = + sequenceFile(path, keyClass, valueClass, defaultMinSplits, cloneRecords) /** * Version of sequenceFile() for types implicitly convertible to Writables through a @@ -472,17 +501,18 @@ class SparkContext( * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. 
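To illustrate the new `cloneRecords` flag threaded through the Hadoop input methods above — a sketch only, assuming a running SparkContext `sc`; the HDFS path is a placeholder. Cloning defaults to true because Hadoop RecordReaders reuse Writable instances across records:

```scala
import org.apache.hadoop.io.{IntWritable, Text}

// Default (cloneRecords = true): each pair is a distinct object, so caching or
// aggregating the records downstream sees the right values.
val cloned = sc.sequenceFile("hdfs://nn:8020/data/events", classOf[Text], classOf[IntWritable])

// Opt out of cloning for a single streaming pass, where reused Writables are harmless.
val positives = sc
  .sequenceFile("hdfs://nn:8020/data/events", classOf[Text], classOf[IntWritable],
    cloneRecords = false)
  .filter { case (_, v) => v.get > 0 }
  .count()
```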
*/ - def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassTag[K], vm: ClassTag[V], - kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) + def sequenceFile[K, V] + (path: String, minSplits: Int = defaultMinSplits, cloneRecords: Boolean = true) + (implicit km: ClassTag[K], vm: ClassTag[V], + kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { val kc = kcf() val vc = vcf() val format = classOf[SequenceFileInputFormat[Writable, Writable]] val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], - vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits) - writables.map{case (k,v) => (kc.convert(k), vc.convert(v))} + vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits, cloneRecords) + writables.map { case (k, v) => (kc.convert(k), vc.convert(v)) } } /** @@ -517,15 +547,15 @@ class SparkContext( // Methods for creating shared variables /** - * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the driver can access the accumulator's `value`. + * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" + * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** - * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`. - * Only the driver can access the accumuable's `value`. + * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values + * with `+=`. Only the driver can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ @@ -538,14 +568,16 @@ class SparkContext( * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by * standard mutable collections. So you can use this with mutable Map, Set, etc. */ - def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + (initialValue: R) = { val param = new GrowableAccumulableParam[R,T] new Accumulable(initialValue, param) } /** - * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for - * reading it in distributed functions. The variable will be sent to each cluster only once. + * Broadcast a read-only variable to the cluster, returning a + * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. + * The variable will be sent to each cluster only once. */ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) @@ -667,10 +699,10 @@ class SparkContext( key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (SparkHadoopUtil.get.isYarnMode()) { - // In order for this to work on yarn the user must specify the --addjars option to - // the client to upload the file into the distributed cache to make it show up in the - // current working directory. 
+ if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") { + // In order for this to work in yarn standalone mode the user must specify the + // --addjars option to the client to upload the file into the distributed cache + // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { env.httpFileServer.addJar(new File(fileName)) @@ -754,8 +786,11 @@ class SparkContext( private[spark] def getCallSite(): String = { val callSite = getLocalProperty("externalCallSite") - if (callSite == null) return Utils.formatSparkCallSite - callSite + if (callSite == null) { + Utils.formatSparkCallSite + } else { + callSite + } } /** @@ -905,7 +940,7 @@ class SparkContext( */ private[spark] def clean[F <: AnyRef](f: F): F = { ClosureCleaner.clean(f) - return f + f } /** @@ -917,7 +952,7 @@ class SparkContext( val path = new Path(dir, UUID.randomUUID().toString) val fs = path.getFileSystem(hadoopConfiguration) fs.mkdirs(path) - fs.getFileStatus(path).getPath().toString + fs.getFileStatus(path).getPath.toString } } @@ -1010,7 +1045,8 @@ object SparkContext { implicit def stringToText(s: String) = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = { + private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) + : ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], @@ -1033,7 +1069,9 @@ object SparkContext { implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get) - implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + implicit def bytesWritableConverter() = { + simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) + } implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) @@ -1049,7 +1087,8 @@ object SparkContext { if (uri != null) { val uriStr = uri.toString if (uriStr.startsWith("jar:file:")) { - // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar + // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", + // so pull out the /path/foo.jar List(uriStr.substring("jar:file:".length, uriStr.indexOf('!'))) } else { Nil @@ -1072,7 +1111,7 @@ object SparkContext { * parameters that are passed as the default value of null, instead of throwing an exception * like SparkConf would. 
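To ground the accumulator and broadcast scaladoc touched above, a small sketch (assumes a running SparkContext `sc`; the data is illustrative):

```scala
import org.apache.spark.SparkContext._   // brings in the implicit AccumulatorParam[Int]

// Broadcast: the lookup table is shipped to each node once rather than with every task.
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))

// Accumulator: tasks may only "add" with +=; only the driver can read .value.
val misses = sc.accumulator(0)

sc.parallelize(Seq("a", "b", "c")).foreach { k =>
  if (!lookup.value.contains(k)) misses += 1
}
println(misses.value)   // 1
```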
*/ - private def updatedConf( + private[spark] def updatedConf( conf: SparkConf, master: String, appName: String, @@ -1203,7 +1242,7 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new TaskSchedulerImpl(sc) - val coarseGrained = sc.conf.get("spark.mesos.coarse", "false").toBoolean + val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false) val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2e36ccb9a0f076b92c59ac00793b91237320f5a2..ed788560e79f17c2e22c6a9527f4b9b29513124a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -54,7 +54,11 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) { + val conf: SparkConf) extends Logging { + + // A mapping of thread ID to amount of memory used for shuffle in bytes + // All accesses should be manually synchronized + val shuffleMemoryMap = mutable.HashMap[Long, Long]() private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -128,16 +132,6 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", boundPort.toString) } - // set only if unset until now. - if (!conf.contains("spark.hostPort")) { - if (!isDriver){ - // unexpected - Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set") - } - Utils.checkHost(hostname) - conf.set("spark.hostPort", hostname + ":" + boundPort) - } - val classLoader = Thread.currentThread.getContextClassLoader // Create an instance of the class named by the given Java system property, or by @@ -162,7 +156,7 @@ object SparkEnv extends Logging { actorSystem.actorOf(Props(newActor), name = name) } else { val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.get("spark.driver.port", "7077").toInt + val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = s"akka.tcp://spark@$driverHost:$driverPort/user/$name" val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 618d95015f7475253614dfcfab8cb418bde91a80..4e63117a5133461991efac6136949c85b31fee20 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -134,28 +134,28 @@ class SparkHadoopWriter(@transient jobConf: JobConf) format = conf.value.getOutputFormat() .asInstanceOf[OutputFormat[AnyRef,AnyRef]] } - return format + format } private def getOutputCommitter(): OutputCommitter = { if (committer == null) { committer = conf.value.getOutputCommitter } - return committer + committer } private def getJobContext(): JobContext = { if (jobContext == null) { jobContext = newJobContext(conf.value, jID.value) } - return jobContext + jobContext } private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { taskContext = newTaskAttemptContext(conf.value, taID.value) } - return taskContext + taskContext } private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { @@ -182,19 +182,18 @@ object SparkHadoopWriter 
{ def createJobID(time: Date, id: Int): JobID = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) - return new JobID(jobtrackerID, id) + new JobID(jobtrackerID, id) } def createPathFromString(path: String, conf: JobConf): Path = { if (path == null) { throw new IllegalArgumentException("Output path is null") } - var outputPath = new Path(path) + val outputPath = new Path(path) val fs = outputPath.getFileSystem(conf) if (outputPath == null || fs == null) { throw new IllegalArgumentException("Incorrectly formatted output path") } - outputPath = outputPath.makeQualified(fs) - return outputPath + outputPath.makeQualified(fs) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index da30cf619a1d0ecfabf501faecc0e2b0f0a64738..b0dedc6f4eb135f5751414921d21eb4fd2110080 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -207,13 +207,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * e.g. for the array * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1,0,0 - * + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. - * buckets array must be at least two elements + * buckets array must be at least two elements * All NaN entries are treated the same. If you have a NaN bucket it must be * the maximum value of the last position and all NaN entries will be counted * in that bucket. 
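As a concrete reading of the histogram semantics described above — a sketch assuming a running SparkContext `sc`; the Scala API is shown, but `JavaDoubleRDD.histogram` behaves the same way:

```scala
import org.apache.spark.SparkContext._   // doubleRDDToDoubleRDDFunctions

val xs = sc.parallelize(Seq(1.0, 10.0, 20.0, 50.0))

// Buckets [1,10), [10,20), [20,50]; the spacing is uneven, so evenBuckets stays false.
xs.histogram(Array(1.0, 10.0, 20.0, 50.0), evenBuckets = false)
// => Array(1, 1, 2): the top bucket is closed, so 50.0 is counted in it.
```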
@@ -225,6 +225,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { srdd.histogram(buckets.map(_.toDouble), evenBuckets) } + + /** Assign a name to this RDD */ + def setName(name: String): JavaDoubleRDD = { + srdd.setName(name) + this + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 55c87450ac65ae5256cba054cf4ce7684773c592..0fb7e195b34c4a9b530e2f718fa98f86d391303d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -647,6 +647,12 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaRDD[(K, Long)] = { rdd.countApproxDistinctByKey(relativeSD, numPartitions) } + + /** Assign a name to this RDD */ + def setName(name: String): JavaPairRDD[K, V] = { + rdd.setName(name) + this + } } object JavaPairRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 037cd1c774691e6cd2122d03dea488d1a57f0517..7d48ce01cf2cc9606c786e14018ad8bacf2f7876 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -127,6 +127,12 @@ JavaRDDLike[T, JavaRDD[T]] { wrapRDD(rdd.subtract(other, p)) override def toString = rdd.toString + + /** Assign a name to this RDD */ + def setName(name: String): JavaRDD[T] = { + rdd.setName(name) + this + } } object JavaRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 924d8af0602f4a7fa8e33b637ebd5af23059ad81..ebbbbd88061a1df7c205638f0872d7b9f724e107 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -244,6 +244,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } + /** + * Return an array that contains all of the elements in this RDD. + */ + def toArray(): JList[T] = collect() + /** * Return an array that contains all of the elements in a specific partition of this RDD. */ @@ -455,4 +460,5 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def countApproxDistinct(relativeSD: Double = 0.05): Long = rdd.countApproxDistinct(relativeSD) + def name(): String = rdd.name } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e93b10fd7eecb71b59908a54a765d0e23b7da7b1..7a6f044965027d26bd7af468eb225784c76ccc31 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -425,6 +425,51 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork def clearCallSite() { sc.clearCallSite() } + + /** + * Set a local property that affects jobs submitted from this thread, such as the + * Spark fair scheduler pool. + */ + def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value) + + /** + * Get a local property set in this thread, or null if it is missing. 
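The `setName` helpers added to the Java wrappers above mirror `RDD.setName` in the Scala API; a sketch of the intended use (assumes a running context `sc`; the path is a placeholder). Named RDDs are much easier to pick out on the web UI's storage page:

```scala
val users = sc.textFile("hdfs://nn:8020/data/users.txt")
  .setName("users")   // shows up as "users" instead of an anonymous RDD id in the UI
  .cache()

users.name   // "users"
```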
See + * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + */ + def getLocalProperty(key: String): String = sc.getLocalProperty(key) + + /** + * Assigns a group ID to all the jobs started by this thread until the group ID is set to a + * different value or cleared. + * + * Often, a unit of execution in an application consists of multiple Spark actions or jobs. + * Application programmers can use this method to group all those jobs together and give a + * group description. Once set, the Spark web UI will associate such jobs with this group. + * + * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * to cancel all running jobs in this group. For example, + * {{{ + * // In the main thread: + * sc.setJobGroup("some_job_to_cancel", "some job description"); + * rdd.map(...).count(); + * + * // In a separate thread: + * sc.cancelJobGroup("some_job_to_cancel"); + * }}} + */ + def setJobGroup(groupId: String, description: String): Unit = sc.setJobGroup(groupId, description) + + /** Clear the current thread's job group ID and its description. */ + def clearJobGroup(): Unit = sc.clearJobGroup() + + /** + * Cancel active jobs for the specified group. See + * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + */ + def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) + + /** Cancel all jobs that have been scheduled or are running. */ + def cancelAllJobs(): Unit = sc.cancelAllJobs() } object JavaSparkContext { @@ -436,5 +481,12 @@ object JavaSparkContext { * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. */ - def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray + def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray + + /** + * Find the JAR that contains the class of a particular object, to make it easy for users + * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in + * your driver program. 
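`jarOfObject` above complements `jarOfClass`; both resolve the jar a class was loaded from so it can be shipped to executors. A sketch of the usual pattern (the master URL, app name, and surrounding object are illustrative):

```scala
import org.apache.spark.SparkContext

object MyApp {
  def main(args: Array[String]) {
    // Find the jar containing this application's classes, if it was launched from one.
    val jars = SparkContext.jarOfObject(this)
    val sc = new SparkContext("spark://master:7077", "MyApp", System.getenv("SPARK_HOME"), jars)
    try {
      // ... build RDDs and run jobs ...
    } finally {
      sc.stop()
    }
  }
}
```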
+ */ + def jarOfObject(obj: AnyRef): Array[String] = SparkContext.jarOfObject(obj).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 32cc70e8c9dda4d98e44b8d5165ad5375507e460..82527fe6638482248e590fada0fd4ad78e46c313 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ private[spark] class PythonRDD[T: ClassTag]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.get("spark.buffer.size", "65536").toInt + val bufferSize = conf.getInt("spark.buffer.size", 65536) override def getPartitions = parent.partitions @@ -95,7 +95,7 @@ private[spark] class PythonRDD[T: ClassTag]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - return new Iterator[Array[Byte]] { + val stdoutIterator = new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj if (hasNext) { @@ -156,6 +156,7 @@ private[spark] class PythonRDD[T: ClassTag]( def hasNext = _nextObj.length != 0 } + stdoutIterator } val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) @@ -250,7 +251,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Utils.checkHost(serverHost, "Expected hostname") - val bufferSize = SparkEnv.get.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index db596d5fcc05413debab51d1ca9dd8dca0321090..0eacda3d7dc2b4422bd0a46f1b54078c95c2a431 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -92,8 +92,8 @@ private object HttpBroadcast extends Logging { def initialize(isDriver: Boolean, conf: SparkConf) { synchronized { if (!initialized) { - bufferSize = conf.get("spark.buffer.size", "65536").toInt - compress = conf.get("spark.broadcast.compress", "true").toBoolean + bufferSize = conf.getInt("spark.buffer.size", 65536) + compress = conf.getBoolean("spark.broadcast.compress", true) if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 95309382786a902fceedfeac5e712e8ca021b952..1d295c62bcb6c2bb82b1685098641f0b823d11fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -180,7 +180,7 @@ extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.get("spark.broadcast.blockSize", "4096").toInt * 1024 + lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) @@ -203,16 +203,16 @@ extends Logging { } bais.close() - var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) tInfo.hasBlocks = blockNum - return tInfo + tInfo } def 
unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], totalBytes: Int, totalBlocks: Int): T = { - var retByteArray = new Array[Byte](totalBytes) + val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala new file mode 100644 index 0000000000000000000000000000000000000000..e133893f6ca5bab05fea99d4319265fa4c43db74 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Map +import scala.concurrent._ + +import akka.actor._ +import akka.pattern.ask +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.{DriverState, Master} +import org.apache.spark.util.{AkkaUtils, Utils} +import akka.actor.Actor.emptyBehavior +import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} + +/** + * Proxy that relays messages to the driver. + */ +class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { + var masterActor: ActorSelection = _ + val timeout = AkkaUtils.askTimeout(conf) + + override def preStart() = { + masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) + + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + + println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") + + driverArgs.cmd match { + case "launch" => + // TODO: We could add an env variable here and intercept it in `sc.addJar` that would + // truncate filesystem paths similar to what YARN does. For now, we just require + // people call `addJar` assuming the jar is in the same directory. + val env = Map[String, String]() + System.getenv().foreach{case (k, v) => env(k) = v} + + val mainClass = "org.apache.spark.deploy.worker.DriverWrapper" + val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++ + driverArgs.driverOptions, env) + + val driverDescription = new DriverDescription( + driverArgs.jarUrl, + driverArgs.memory, + driverArgs.cores, + driverArgs.supervise, + command) + + masterActor ! RequestSubmitDriver(driverDescription) + + case "kill" => + val driverId = driverArgs.driverId + val killFuture = masterActor ! RequestKillDriver(driverId) + } + } + + /* Find out driver status then exit the JVM */ + def pollAndReportStatus(driverId: String) { + println(s"... 
waiting before polling master for driver state") + Thread.sleep(5000) + println("... polling master for driver state") + val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) + .mapTo[DriverStatusResponse] + val statusResponse = Await.result(statusFuture, timeout) + + statusResponse.found match { + case false => + println(s"ERROR: Cluster master did not recognize $driverId") + System.exit(-1) + case true => + println(s"State of $driverId is ${statusResponse.state.get}") + // Worker node, if present + (statusResponse.workerId, statusResponse.workerHostPort, statusResponse.state) match { + case (Some(id), Some(hostPort), Some(DriverState.RUNNING)) => + println(s"Driver running on $hostPort ($id)") + case _ => + } + // Exception, if present + statusResponse.exception.map { e => + println(s"Exception from cluster was: $e") + System.exit(-1) + } + System.exit(0) + } + } + + override def receive = { + + case SubmitDriverResponse(success, driverId, message) => + println(message) + if (success) pollAndReportStatus(driverId.get) else System.exit(-1) + + case KillDriverResponse(driverId, success, message) => + println(message) + if (success) pollAndReportStatus(driverId) else System.exit(-1) + + case DisassociatedEvent(_, remoteAddress, _) => + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + System.exit(-1) + + case AssociationErrorEvent(cause, _, remoteAddress, _) => + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + println(s"Cause was: $cause") + System.exit(-1) + } +} + +/** + * Executable utility for starting and terminating drivers inside of a standalone cluster. + */ +object Client { + def main(args: Array[String]) { + val conf = new SparkConf() + val driverArgs = new ClientArguments(args) + + if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { + conf.set("spark.akka.logLifecycleEvents", "true") + } + conf.set("spark.akka.askTimeout", "10") + conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) + Logger.getRootLogger.setLevel(driverArgs.logLevel) + + // TODO: See if we can initialize akka so return messages are sent back using the same TCP + // flow. Else, this (sadly) requires the DriverClient be routable from the Master. + val (actorSystem, _) = AkkaUtils.createActorSystem( + "driverClient", Utils.localHostName(), 0, false, conf) + + actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + + actorSystem.awaitTermination() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala new file mode 100644 index 0000000000000000000000000000000000000000..db67c6d1bb55c58069a23a7d2d3abb01a4d2a1ea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.net.URL + +import scala.collection.mutable.ListBuffer + +import org.apache.log4j.Level + +/** + * Command-line parser for the driver client. + */ +private[spark] class ClientArguments(args: Array[String]) { + val defaultCores = 1 + val defaultMemory = 512 + + var cmd: String = "" // 'launch' or 'kill' + var logLevel = Level.WARN + + // launch parameters + var master: String = "" + var jarUrl: String = "" + var mainClass: String = "" + var supervise: Boolean = false + var memory: Int = defaultMemory + var cores: Int = defaultCores + private var _driverOptions = ListBuffer[String]() + def driverOptions = _driverOptions.toSeq + + // kill parameters + var driverId: String = "" + + parse(args.toList) + + def parse(args: List[String]): Unit = args match { + case ("--cores" | "-c") :: value :: tail => + cores = value.toInt + parse(tail) + + case ("--memory" | "-m") :: value :: tail => + memory = value.toInt + parse(tail) + + case ("--supervise" | "-s") :: tail => + supervise = true + parse(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case ("--verbose" | "-v") :: tail => + logLevel = Level.INFO + parse(tail) + + case "launch" :: _master :: _jarUrl :: _mainClass :: tail => + cmd = "launch" + + try { + new URL(_jarUrl) + } catch { + case e: Exception => + println(s"Jar url '${_jarUrl}' is not a valid URL.") + println(s"Jar must be in URL format (e.g. hdfs://XX, file://XX)") + printUsageAndExit(-1) + } + + jarUrl = _jarUrl + master = _master + mainClass = _mainClass + _driverOptions ++= tail + + case "kill" :: _master :: _driverId :: tail => + cmd = "kill" + master = _master + driverId = _driverId + + case _ => + printUsageAndExit(1) + } + + /** + * Print usage and exit JVM with the given exit code. + */ + def printUsageAndExit(exitCode: Int) { + // TODO: It wouldn't be too hard to allow users to submit their app and dependency jars + // separately similar to in the YARN client. 
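The parser above consumes its switches before the launch/kill verb; in practice the class is driven through something like `./bin/spark-class org.apache.spark.deploy.Client --memory 1024 --supervise launch <master> <jar-url> <main-class>`. A sketch of the parsing rules (ClientArguments is private[spark], so this illustrates behavior rather than a public API; the URLs and class name are placeholders):

```scala
val parsed = new ClientArguments(Array(
  "--memory", "1024",         // switches are consumed first by the recursive parse()
  "--supervise",
  "launch",                   // then the verb and its positional arguments
  "spark://master:7077",
  "file:///path/to/app.jar",  // must be a well-formed URL, as checked above
  "com.example.MyDriver",
  "appArg1"))                 // anything after the main class becomes a driver option

assert(parsed.cmd == "launch" && parsed.memory == 1024 && parsed.supervise)
assert(parsed.driverOptions == Seq("appArg1"))
```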
+ val usage = + s""" + |Usage: DriverClient [options] launch <active-master> <jar-url> <main-class> [driver options] + |Usage: DriverClient kill <active-master> <driver-id> + | + |Options: + | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -s, --supervise Whether to restart the driver on failure + | -v, --verbose Print more debugging output + """.stripMargin + System.err.println(usage) + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 275331724afba010988baf082c6364fbcce9b5bf..5e824e1a678b647c1ef95387ae8674736e6565cd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -20,12 +20,12 @@ package org.apache.spark.deploy import scala.collection.immutable.List import org.apache.spark.deploy.ExecutorState.ExecutorState -import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo} +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.util.Utils - private[deploy] sealed trait DeployMessage extends Serializable /** Contains messages sent between Scheduler actor nodes. */ @@ -54,7 +54,14 @@ private[deploy] object DeployMessages { exitStatus: Option[Int]) extends DeployMessage - case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription]) + case class DriverStateChanged( + driverId: String, + state: DriverState, + exception: Option[Exception]) + extends DeployMessage + + case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], + driverIds: Seq[String]) case class Heartbeat(workerId: String) extends DeployMessage @@ -76,14 +83,18 @@ private[deploy] object DeployMessages { sparkHome: String) extends DeployMessage - // Client to Master + case class LaunchDriver(driverId: String, driverDesc: DriverDescription) extends DeployMessage + + case class KillDriver(driverId: String) extends DeployMessage + + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage case class MasterChangeAcknowledged(appId: String) - // Master to Client + // Master to AppClient case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage @@ -97,11 +108,28 @@ private[deploy] object DeployMessages { case class ApplicationRemoved(message: String) - // Internal message in Client + // DriverClient <-> Master + + case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage + + case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + extends DeployMessage + + case class RequestKillDriver(driverId: String) extends DeployMessage + + case class KillDriverResponse(driverId: String, success: Boolean, message: String) + extends DeployMessage + + case class RequestDriverStatus(driverId: String) extends DeployMessage + + case class DriverStatusResponse(found: Boolean, state: Option[DriverState], + workerId: Option[String], workerHostPort: Option[String], exception: Option[Exception]) + + 
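Putting the new messages together: a sketch of the payload the ClientActor shown earlier submits for a driver. These classes are internal to Spark, so this illustrates the protocol rather than a user-facing API; the jar URL, main class, and environment entry are placeholders:

```scala
import org.apache.spark.deploy.{Command, DriverDescription}
import org.apache.spark.deploy.DeployMessages.RequestSubmitDriver

// The wrapper main class and the "{{WORKER_URL}}" argument mirror what ClientActor builds.
val command = new Command(
  "org.apache.spark.deploy.worker.DriverWrapper",
  Seq("{{WORKER_URL}}", "com.example.MyDriver"),
  Map("SPARK_JAVA_OPTS" -> "-Dspark.akka.askTimeout=10"))

val driver = new DriverDescription(
  "hdfs://nn:8020/jars/app.jar",   // jarUrl
  1024,                            // mem (MB)
  2,                               // cores
  true,                            // supervise: restart the driver if it fails
  command)

// masterActor ! RequestSubmitDriver(driver)
// The master answers with SubmitDriverResponse(success, driverId, message).
```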
// Internal message in AppClient - case object StopClient + case object StopAppClient - // Master to Worker & Client + // Master to Worker & AppClient case class MasterChanged(masterUrl: String, masterWebUiUrl: String) @@ -113,6 +141,7 @@ private[deploy] object DeployMessages { case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], status: MasterState) { Utils.checkHost(host, "Required hostname") @@ -128,14 +157,15 @@ private[deploy] object DeployMessages { // Worker to WorkerWebUI case class WorkerStateResponse(host: String, port: Int, workerId: String, - executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, + executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], + drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { Utils.checkHost(host, "Required hostname") assert (port > 0) } - // Actor System to Worker + // Liveness checks in various places case object SendHeartbeat } diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala new file mode 100644 index 0000000000000000000000000000000000000000..58c95dc4f9116ae2a1eefd0e9773ec78c631deaa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +private[spark] class DriverDescription( + val jarUrl: String, + val mem: Int, + val cores: Int, + val supervise: Boolean, + val command: Command) + extends Serializable { + + override def toString: String = s"DriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala similarity index 94% rename from core/src/main/scala/org/apache/spark/deploy/client/Client.scala rename to core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 481026eaa2106e2e4310067906e9adf18c65f479..1415e2f3d1886974cdc6a725b973ffbbe6ccc7e8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -33,16 +33,17 @@ import org.apache.spark.deploy.master.Master import org.apache.spark.util.AkkaUtils /** - * The main class used to talk to a Spark deploy cluster. 
Takes a master URL, an app description, - * and a listener for cluster events, and calls back the listener when various events occur. + * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, + * an app description, and a listener for cluster events, and calls back the listener when various + * events occur. * * @param masterUrls Each url should look like spark://host:port. */ -private[spark] class Client( +private[spark] class AppClient( actorSystem: ActorSystem, masterUrls: Array[String], appDescription: ApplicationDescription, - listener: ClientListener, + listener: AppClientListener, conf: SparkConf) extends Logging { @@ -155,7 +156,7 @@ private[spark] class Client( case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") - case StopClient => + case StopAppClient => markDead() sender ! true context.stop(self) @@ -188,7 +189,7 @@ private[spark] class Client( if (actor != null) { try { val timeout = AkkaUtils.askTimeout(conf) - val future = actor.ask(StopClient)(timeout) + val future = actor.ask(StopAppClient)(timeout) Await.result(future, timeout) } catch { case e: TimeoutException => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala rename to core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala index be7a11bd1553724376d334245bf14a30d6a06a35..55d4ef1b31aaacbe49878ba4bb3cabebfc7946a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala @@ -24,7 +24,7 @@ package org.apache.spark.deploy.client * * Users of this API should *not* block inside the callback methods. */ -private[spark] trait ClientListener { +private[spark] trait AppClientListener { def connected(appId: String): Unit /** Disconnection may be a temporary state, as we fail over to a new Master. 
*/ diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 28ebbdc66bbb885f9b7aa8e545852f8c4a18d6ee..ffa909c26b64aa26c725b57cf9dfc80a7cd749fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -23,7 +23,7 @@ import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { - class TestListener extends ClientListener with Logging { + class TestListener extends AppClientListener with Logging { def connected(id: String) { logInfo("Connected to master, got app ID " + id) } @@ -51,7 +51,7 @@ private[spark] object TestClient { "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") val listener = new TestListener - val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..33377931d69931eb822af6a3183bc117e3ebdb3a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import java.util.Date + +import org.apache.spark.deploy.DriverDescription + +private[spark] class DriverInfo( + val startTime: Long, + val id: String, + val desc: DriverDescription, + val submitDate: Date) + extends Serializable { + + @transient var state: DriverState.Value = DriverState.SUBMITTED + /* If we fail when launching the driver, the exception is stored here. */ + @transient var exception: Option[Exception] = None + /* Most recent worker assigned to this driver */ + @transient var worker: Option[WorkerInfo] = None +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala new file mode 100644 index 0000000000000000000000000000000000000000..26a68bade3c60eb5003b9a1fc0dfab4485cb6dc7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +private[spark] object DriverState extends Enumeration { + + type DriverState = Value + + // SUBMITTED: Submitted but not yet scheduled on a worker + // RUNNING: Has been allocated to a worker to run + // FINISHED: Previously ran and exited cleanly + // RELAUNCHING: Exited non-zero or due to worker failure, but has not yet started running again + // UNKNOWN: The state of the driver is temporarily not known due to master failure recovery + // KILLED: A user manually killed this driver + // FAILED: The driver exited non-zero and was not supervised + // ERROR: Unable to run or restart due to an unrecoverable error (e.g. missing jar file) + val SUBMITTED, RUNNING, FINISHED, RELAUNCHING, UNKNOWN, KILLED, FAILED, ERROR = Value +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index 043945a211f26948d9ad7a3d99b80951acb5b23a..74bb9ebf1db4a3e52fedb4c14db7cf63d9f6d22f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.master import java.io._ -import scala.Serializable - import akka.serialization.Serialization import org.apache.spark.Logging @@ -47,6 +45,15 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + "app_" + app.id).delete() } + override def addDriver(driver: DriverInfo) { + val driverFile = new File(dir + File.separator + "driver_" + driver.id) + serializeIntoFile(driverFile, driver) + } + + override def removeDriver(driver: DriverInfo) { + new File(dir + File.separator + "driver_" + driver.id).delete() + } + override def addWorker(worker: WorkerInfo) { val workerFile = new File(dir + File.separator + "worker_" + worker.id) serializeIntoFile(workerFile, worker) @@ -56,13 +63,15 @@ private[spark] class FileSystemPersistenceEngine( new File(dir + File.separator + "worker_" + worker.id).delete() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { val sortedFiles = new File(dir).listFiles().sortBy(_.getName) val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_")) + val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, workers) + (apps, drivers, workers) } private def serializeIntoFile(file: File, value: AnyRef) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala 
b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 6617b7100f44bfa5d33bf25b88fca04bd170470e..d9ea96afcf52a2e4719f571c1c27e5dbf91c6389 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -23,19 +23,22 @@ import java.util.Date import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Await import scala.concurrent.duration._ +import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException} -import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} + +import org.apache.spark.{SparkConf, Logging, SparkException} +import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.master.DriverState.DriverState private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -43,13 +46,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val conf = new SparkConf val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = conf.get("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = conf.get("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = conf.get("spark.dead.worker.persistence", "15").toInt + val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") - var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] val actorToWorker = new HashMap[ActorRef, WorkerInfo] @@ -59,9 +61,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val idToApp = new HashMap[String, ApplicationInfo] val actorToApp = new HashMap[ActorRef, ApplicationInfo] val addressToApp = new HashMap[Address, ApplicationInfo] - val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] + var nextAppNumber = 0 + + val drivers = new HashSet[DriverInfo] + val completedDrivers = new ArrayBuffer[DriverInfo] + val waitingDrivers = new ArrayBuffer[DriverInfo] // Drivers currently spooled for scheduling + var nextDriverNumber = 0 Utils.checkHost(host, "Expected hostname") @@ -142,14 +149,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act override def receive = { case ElectedLeader => { - val (storedApps, storedWorkers) = persistenceEngine.readPersistedData() - state = if (storedApps.isEmpty && storedWorkers.isEmpty) + val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() + state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) RecoveryState.ALIVE else 
RecoveryState.RECOVERING logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { - beginRecovery(storedApps, storedWorkers) + beginRecovery(storedApps, storedDrivers, storedWorkers) context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } } } @@ -176,6 +183,69 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." + sender ! SubmitDriverResponse(false, None, msg) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + sender ! SubmitDriverResponse(true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}") + } + } + + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"Can only kill drivers in ALIVE state. Current state: $state." + sender ! KillDriverResponse(driverId, success = false, msg) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self ! DriverStateChanged(driverId, DriverState.KILLED, None) + } + else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. + d.worker.foreach { w => + w.actor ! KillDriver(driverId) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + sender ! KillDriverResponse(driverId, success = true, msg) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + sender ! KillDriverResponse(driverId, success = false, msg) + } + } + } + + case RequestDriverStatus(driverId) => { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + sender ! DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + case None => + sender ! 
DriverStatusResponse(found = false, None, None, None, None) + } + } + case RegisterApplication(description) => { if (state == RecoveryState.STANDBY) { // ignore, don't send response @@ -218,6 +288,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + case DriverStateChanged(driverId, state, exception) => { + state match { + case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => + removeDriver(driverId, state, exception) + case _ => + throw new Exception(s"Received unexpected state update for driver $driverId: $state") + } + } + case Heartbeat(workerId) => { idToWorker.get(workerId) match { case Some(workerInfo) => @@ -239,7 +318,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act if (canCompleteRecovery) { completeRecovery() } } - case WorkerSchedulerStateResponse(workerId, executors) => { + case WorkerSchedulerStateResponse(workerId, executors, driverIds) => { idToWorker.get(workerId) match { case Some(worker) => logInfo("Worker has been re-registered: " + workerId) @@ -252,6 +331,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act worker.addExecutor(execInfo) execInfo.copyState(exec) } + + for (driverId <- driverIds) { + drivers.find(_.id == driverId).foreach { driver => + driver.worker = Some(worker) + driver.state = DriverState.RUNNING + worker.drivers(driverId) = driver + } + } case None => logWarning("Scheduler state from unknown worker: " + workerId) } @@ -269,7 +356,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act case RequestMasterState => { sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - state) + drivers.toArray, completedDrivers.toArray, state) } case CheckForWorkerTimeOut => { @@ -285,7 +372,8 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 - def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) { + def beginRecovery(storedApps: Seq[ApplicationInfo], storedDrivers: Seq[DriverInfo], + storedWorkers: Seq[WorkerInfo]) { for (app <- storedApps) { logInfo("Trying to recover app: " + app.id) try { @@ -297,6 +385,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + for (driver <- storedDrivers) { + // Here we just read in the list of drivers. Any drivers associated with now-lost workers + // will be re-launched when we detect that the worker is missing. 
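+      // (completeRecovery() below handles drivers whose workers never re-register: supervised
+      // drivers are re-launched, the rest are removed with final state ERROR)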
+ drivers += driver + } + for (worker <- storedWorkers) { logInfo("Trying to recover worker: " + worker.id) try { @@ -320,6 +414,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) + // Reschedule drivers which were not claimed by any workers + drivers.filter(_.worker.isEmpty).foreach { d => + logWarning(s"Driver ${d.id} was not found after master recovery") + if (d.desc.supervise) { + logWarning(s"Re-launching ${d.id}") + relaunchDriver(d) + } else { + removeDriver(d.id, DriverState.ERROR, None) + logWarning(s"Did not re-launch ${d.id} because it was not supervised") + } + } + state = RecoveryState.ALIVE schedule() logInfo("Recovery complete - resuming operations!") @@ -340,6 +446,18 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act */ def schedule() { if (state != RecoveryState.ALIVE) { return } + + // First schedule drivers, they take strict precedence over applications + val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers + for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) { + for (driver <- waitingDrivers) { + if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) { + launchDriver(worker, driver) + waitingDrivers -= driver + } + } + } + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { @@ -426,9 +544,25 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act exec.id, ExecutorState.LOST, Some("worker lost"), None) exec.application.removeExecutor(exec) } + for (driver <- worker.drivers.values) { + if (driver.desc.supervise) { + logInfo(s"Re-launching ${driver.id}") + relaunchDriver(driver) + } else { + logInfo(s"Not re-launching ${driver.id} because it was not supervised") + removeDriver(driver.id, DriverState.ERROR, None) + } + } persistenceEngine.removeWorker(worker) } + def relaunchDriver(driver: DriverInfo) { + driver.worker = None + driver.state = DriverState.RELAUNCHING + waitingDrivers += driver + schedule() + } + def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) @@ -508,6 +642,41 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } } + + def newDriverId(submitDate: Date): String = { + val appId = "driver-%s-%04d".format(DATE_FORMAT.format(submitDate), nextDriverNumber) + nextDriverNumber += 1 + appId + } + + def createDriver(desc: DriverDescription): DriverInfo = { + val now = System.currentTimeMillis() + val date = new Date(now) + new DriverInfo(now, newDriverId(date), desc, date) + } + + def launchDriver(worker: WorkerInfo, driver: DriverInfo) { + logInfo("Launching driver " + driver.id + " on worker " + worker.id) + worker.addDriver(driver) + driver.worker = Some(worker) + worker.actor ! 
LaunchDriver(driver.id, driver.desc) + driver.state = DriverState.RUNNING + } + + def removeDriver(driverId: String, finalState: DriverState, exception: Option[Exception]) { + drivers.find(d => d.id == driverId) match { + case Some(driver) => + logInfo(s"Removing driver: $driverId") + drivers -= driver + completedDrivers += driver + persistenceEngine.removeDriver(driver) + driver.state = finalState + driver.exception = exception + driver.worker.foreach(w => w.removeDriver(driver)) + case None => + logWarning(s"Asked to remove unknown driver: $driverId") + } + } } private[spark] object Master { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 94b986caf283518e2e809d5b6742541bc481ead7..e3640ea4f7e640365bea55cce33fc74f3d30fbd5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -35,11 +35,15 @@ private[spark] trait PersistenceEngine { def removeWorker(worker: WorkerInfo) + def addDriver(driver: DriverInfo) + + def removeDriver(driver: DriverInfo) + /** * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) + def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) def close() {} } @@ -49,5 +53,8 @@ private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { override def removeApplication(app: ApplicationInfo) {} override def addWorker(worker: WorkerInfo) {} override def removeWorker(worker: WorkerInfo) {} - override def readPersistedData() = (Nil, Nil) + override def addDriver(driver: DriverInfo) {} + override def removeDriver(driver: DriverInfo) {} + + override def readPersistedData() = (Nil, Nil, Nil) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e05f587b58c6437ce869fd26d7021c8dc20338b0..c5fa9cf7d7c2d549aa4dbc421942af1f493020bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -17,8 +17,10 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef import scala.collection.mutable + +import akka.actor.ActorRef + import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -35,7 +37,8 @@ private[spark] class WorkerInfo( Utils.checkHost(host, "Expected hostname") assert (port > 0) - @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info + @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info + @transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info @transient var state: WorkerState.Value = _ @transient var coresUsed: Int = _ @transient var memoryUsed: Int = _ @@ -54,6 +57,7 @@ private[spark] class WorkerInfo( private def init() { executors = new mutable.HashMap + drivers = new mutable.HashMap state = WorkerState.ALIVE coresUsed = 0 memoryUsed = 0 @@ -83,6 +87,18 @@ private[spark] class WorkerInfo( executors.values.exists(_.application == app) } + def addDriver(driver: DriverInfo) { + drivers(driver.id) = driver + memoryUsed += driver.desc.mem + coresUsed += driver.desc.cores + } + + def removeDriver(driver: DriverInfo) { + drivers -= driver.id + memoryUsed 
-= driver.desc.mem + coresUsed -= driver.desc.cores + } + def webUiAddress : String = { "http://" + this.publicAddress + ":" + this.webUiPort } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 52000d4f9c11caa42a0900f0271d44e9d8f1d95d..f24f49ea8ad9ff939593da4bdcc4db552e3f1fff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -49,6 +49,14 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) zk.delete(WORKING_DIR + "/app_" + app.id) } + override def addDriver(driver: DriverInfo) { + serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver) + } + + override def removeDriver(driver: DriverInfo) { + zk.delete(WORKING_DIR + "/driver_" + driver.id) + } + override def addWorker(worker: WorkerInfo) { serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) } @@ -61,13 +69,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) zk.close() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted val appFiles = sortedFiles.filter(_.startsWith("app_")) val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val driverFiles = sortedFiles.filter(_.startsWith("driver_")) + val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) val workerFiles = sortedFiles.filter(_.startsWith("worker_")) val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, workers) + (apps, drivers, workers) } private def serializeIntoFile(path: String, value: AnyRef) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index dbb0cb90f51862ba44135af89ebe26f534557b46..9485bfd89eb5797c21bdd06ea217c6cb0bb89508 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -67,11 +67,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) { <li><strong>User:</strong> {app.desc.user}</li> <li><strong>Cores:</strong> { - if (app.desc.maxCores == Integer.MAX_VALUE) { + if (app.desc.maxCores == None) { "Unlimited (%s granted)".format(app.coresGranted) } else { "%s (%s granted, %s left)".format( - app.desc.maxCores, app.coresGranted, app.coresLeft) + app.desc.maxCores.get, app.coresGranted, app.coresLeft) } } </li> diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index 4ef762892c1e26aaf5220dbb8ea8759d3a13213f..a9af8df5525d68375b552be8526e3f1a685da111 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.master.ui import scala.concurrent.Await +import scala.concurrent.duration._ import scala.xml.Node import akka.pattern.ask @@ -26,7 +27,7 @@ import net.liftweb.json.JsonAST.JValue import org.apache.spark.deploy.{DeployWebUI, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, 
RequestMasterState} -import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo} +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils @@ -56,6 +57,16 @@ private[spark] class IndexPage(parent: MasterWebUI) { val completedApps = state.completedApps.sortBy(_.endTime).reverse val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) + val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") + val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse + val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers) + val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse + val completedDriversTable = UIUtils.listingTable(driverHeaders, driverRow, completedDrivers) + + // For now we only show driver information if the user has submitted drivers to the cluster. + // This is until we integrate the notion of drivers and applications in the UI. + def hasDrivers = activeDrivers.length > 0 || completedDrivers.length > 0 + val content = <div class="row-fluid"> <div class="span12"> @@ -70,6 +81,9 @@ private[spark] class IndexPage(parent: MasterWebUI) { <li><strong>Applications:</strong> {state.activeApps.size} Running, {state.completedApps.size} Completed </li> + <li><strong>Drivers:</strong> + {state.activeDrivers.size} Running, + {state.completedDrivers.size} Completed </li> </ul> </div> </div> @@ -84,17 +98,39 @@ private[spark] class IndexPage(parent: MasterWebUI) { <div class="row-fluid"> <div class="span12"> <h4> Running Applications </h4> - {activeAppsTable} </div> </div> + <div> + {if (hasDrivers) + <div class="row-fluid"> + <div class="span12"> + <h4> Running Drivers </h4> + {activeDriversTable} + </div> + </div> + } + </div> + <div class="row-fluid"> <div class="span12"> <h4> Completed Applications </h4> {completedAppsTable} </div> + </div> + + <div> + {if (hasDrivers) + <div class="row-fluid"> + <div class="span12"> + <h4> Completed Drivers </h4> + {completedDriversTable} + </div> + </div> + } </div>; + UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) } @@ -134,4 +170,20 @@ private[spark] class IndexPage(parent: MasterWebUI) { <td>{DeployWebUI.formatDuration(app.duration)}</td> </tr> } + + def driverRow(driver: DriverInfo): Seq[Node] = { + <tr> + <td>{driver.id} </td> + <td>{driver.submitDate}</td> + <td>{driver.worker.map(w => <a href={w.webUiAddress}>{w.id.toString}</a>).getOrElse("None")}</td> + <td>{driver.state}</td> + <td sorttable_customkey={driver.desc.cores.toString}> + {driver.desc.cores} + </td> + <td sorttable_customkey={driver.desc.mem.toString}> + {Utils.megabytesToString(driver.desc.mem.toLong)} + </td> + <td>{driver.desc.command.arguments(1)}</td> + </tr> + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..7507bf8ad0e6c56b8f09381066ed5168add24d24 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -0,0 +1,63 @@ +package org.apache.spark.deploy.worker + +import java.io.{File, FileOutputStream, IOException, InputStream} +import java.lang.System._ + +import org.apache.spark.Logging +import org.apache.spark.deploy.Command +import org.apache.spark.util.Utils + +/** + ** Utilities for running commands with the spark classpath. 
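+ * Shared by ExecutorRunner and DriverRunner: buildCommandSeq assembles the java command line
+ * (compute-classpath output, SPARK_JAVA_OPTS, library path and -Xms/-Xmx flags), while
+ * redirectStream copies a child process's stdout/stderr into files in its work directory.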
+ */ +object CommandUtils extends Logging { + private[spark] def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { + val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java") + + // SPARK-698: do not call the run.cmd script, as process.destroy() + // fails to kill a process tree on Windows + Seq(runner) ++ buildJavaOpts(command, memory, sparkHome) ++ Seq(command.mainClass) ++ + command.arguments + } + + private def getEnv(key: String, command: Command): Option[String] = + command.environment.get(key).orElse(Option(System.getenv(key))) + + /** + * Attention: this must always be aligned with the environment variables in the run scripts and + * the way the JAVA_OPTS are assembled there. + */ + def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { + val libraryOpts = getEnv("SPARK_LIBRARY_PATH", command) + .map(p => List("-Djava.library.path=" + p)) + .getOrElse(Nil) + val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) + val userOpts = getEnv("SPARK_JAVA_OPTS", command).map(Utils.splitCommandString).getOrElse(Nil) + val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") + + // Figure out our classpath with the external compute-classpath script + val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" + val classPath = Utils.executeAndGetOutput( + Seq(sparkHome + "/bin/compute-classpath" + ext), + extraEnvironment=command.environment) + + Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts + } + + /** Spawn a thread that will redirect a given stream to a file */ + def redirectStream(in: InputStream, file: File) { + val out = new FileOutputStream(file, true) + // TODO: It would be nice to add a shutdown hook here that explains why the output is + // terminating. Otherwise if the worker dies the executor logs will silently stop. + new Thread("redirect output to " + file) { + override def run() { + try { + Utils.copyStream(in, out, true) + } catch { + case e: IOException => + logInfo("Redirection to " + file + " closed: " + e.getMessage) + } + } + }.start() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala new file mode 100644 index 0000000000000000000000000000000000000000..b4df1a0dd47184c12b24d9a460faa63eb3547ca1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.worker + +import java.io._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.Map + +import akka.actor.ActorRef +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileUtil, Path} + +import org.apache.spark.Logging +import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.DeployMessages.DriverStateChanged +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.master.DriverState.DriverState + +/** + * Manages the execution of one driver, including automatically restarting the driver on failure. + */ +private[spark] class DriverRunner( + val driverId: String, + val workDir: File, + val sparkHome: File, + val driverDesc: DriverDescription, + val worker: ActorRef, + val workerUrl: String) + extends Logging { + + @volatile var process: Option[Process] = None + @volatile var killed = false + + // Populated once finished + var finalState: Option[DriverState] = None + var finalException: Option[Exception] = None + var finalExitCode: Option[Int] = None + + // Decoupled for testing + private[deploy] def setClock(_clock: Clock) = clock = _clock + private[deploy] def setSleeper(_sleeper: Sleeper) = sleeper = _sleeper + private var clock = new Clock { + def currentTimeMillis(): Long = System.currentTimeMillis() + } + private var sleeper = new Sleeper { + def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) + } + + /** Starts a thread to run and manage the driver. */ + def start() = { + new Thread("DriverRunner for " + driverId) { + override def run() { + try { + val driverDir = createWorkingDirectory() + val localJarFilename = downloadUserJar(driverDir) + + // Make sure user application jar is on the classpath + // TODO: If we add ability to submit multiple jars they should also be added here + val env = Map(driverDesc.command.environment.toSeq: _*) + env("SPARK_CLASSPATH") = env.getOrElse("SPARK_CLASSPATH", "") + s":$localJarFilename" + val newCommand = Command(driverDesc.command.mainClass, + driverDesc.command.arguments.map(substituteVariables), env) + val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, + sparkHome.getAbsolutePath) + launchDriver(command, env, driverDir, driverDesc.supervise) + } + catch { + case e: Exception => finalException = Some(e) + } + + val state = + if (killed) { DriverState.KILLED } + else if (finalException.isDefined) { DriverState.ERROR } + else { + finalExitCode match { + case Some(0) => DriverState.FINISHED + case _ => DriverState.FAILED + } + } + + finalState = Some(state) + + worker ! DriverStateChanged(driverId, state, finalException) + } + }.start() + } + + /** Terminate this driver (or prevent it from ever starting if not yet started) */ + def kill() { + synchronized { + process.foreach(p => p.destroy()) + killed = true + } + } + + /** Replace variables in a command argument passed to us */ + private def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl + case other => other + } + + /** + * Creates the working directory for this driver. + * Will throw an exception if there are errors preparing the directory. 
+ */ + private def createWorkingDirectory(): File = { + val driverDir = new File(workDir, driverId) + if (!driverDir.exists() && !driverDir.mkdirs()) { + throw new IOException("Failed to create directory " + driverDir) + } + driverDir + } + + /** + * Download the user jar into the supplied directory and return its local path. + * Will throw an exception if there are errors downloading the jar. + */ + private def downloadUserJar(driverDir: File): String = { + + val jarPath = new Path(driverDesc.jarUrl) + + val emptyConf = new Configuration() + val jarFileSystem = jarPath.getFileSystem(emptyConf) + + val destPath = new File(driverDir.getAbsolutePath, jarPath.getName) + val jarFileName = jarPath.getName + val localJarFile = new File(driverDir, jarFileName) + val localJarFilename = localJarFile.getAbsolutePath + + if (!localJarFile.exists()) { // May already exist if running multiple workers on one node + logInfo(s"Copying user jar $jarPath to $destPath") + FileUtil.copy(jarFileSystem, jarPath, destPath, false, emptyConf) + } + + if (!localJarFile.exists()) { // Verify copy succeeded + throw new Exception(s"Did not see expected jar $jarFileName in $driverDir") + } + + localJarFilename + } + + private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File, + supervise: Boolean) { + val builder = new ProcessBuilder(command: _*).directory(baseDir) + envVars.map{ case(k,v) => builder.environment().put(k, v) } + + def initialize(process: Process) = { + // Redirect stdout and stderr to files + val stdout = new File(baseDir, "stdout") + CommandUtils.redirectStream(process.getInputStream, stdout) + + val stderr = new File(baseDir, "stderr") + val header = "Launch Command: %s\n%s\n\n".format( + command.mkString("\"", "\" \"", "\""), "=" * 40) + Files.append(header, stderr, Charsets.UTF_8) + CommandUtils.redirectStream(process.getErrorStream, stderr) + } + runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) + } + + private[deploy] def runCommandWithRetry(command: ProcessBuilderLike, initialize: Process => Unit, + supervise: Boolean) { + // Time to wait between submission retries. + var waitSeconds = 1 + // A run of this many seconds resets the exponential back-off. 
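+    // (with supervise enabled, consecutive fast failures are retried after 1, 2, 4, 8, ...
+    // seconds; any run lasting longer than this resets the wait back to 1 second)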
+ val successfulRunDuration = 5 + + var keepTrying = !killed + + while (keepTrying) { + logInfo("Launch Command: " + command.command.mkString("\"", "\" \"", "\"")) + + synchronized { + if (killed) { return } + process = Some(command.start()) + initialize(process.get) + } + + val processStart = clock.currentTimeMillis() + val exitCode = process.get.waitFor() + if (clock.currentTimeMillis() - processStart > successfulRunDuration * 1000) { + waitSeconds = 1 + } + + if (supervise && exitCode != 0 && !killed) { + logInfo(s"Command exited with status $exitCode, re-launching after $waitSeconds s.") + sleeper.sleep(waitSeconds) + waitSeconds = waitSeconds * 2 // exponential back-off + } + + keepTrying = supervise && exitCode != 0 && !killed + finalExitCode = Some(exitCode) + } + } +} + +private[deploy] trait Clock { + def currentTimeMillis(): Long +} + +private[deploy] trait Sleeper { + def sleep(seconds: Int) +} + +// Needed because ProcessBuilder is a final class and cannot be mocked +private[deploy] trait ProcessBuilderLike { + def start(): Process + def command: Seq[String] +} + +private[deploy] object ProcessBuilderLike { + def apply(processBuilder: ProcessBuilder) = new ProcessBuilderLike { + def start() = processBuilder.start() + def command = processBuilder.command() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..1640d5fee0f77404fc93c884bd2c8ecd8212d54b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -0,0 +1,31 @@ +package org.apache.spark.deploy.worker + +import akka.actor._ + +import org.apache.spark.SparkConf +import org.apache.spark.util.{AkkaUtils, Utils} + +/** + * Utility object for launching driver programs such that they share fate with the Worker process. 
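+ * Before invoking the user's main class it starts a WorkerWatcher actor bound to the supplied
+ * worker URL, so the driver JVM exits if its parent Worker process goes away.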
+ */ +object DriverWrapper { + def main(args: Array[String]) { + args.toList match { + case workerUrl :: mainClass :: extraArgs => + val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + Utils.localHostName(), 0, false, new SparkConf()) + actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + + // Delegate to supplied main class + val clazz = Class.forName(args(1)) + val mainMethod = clazz.getMethod("main", classOf[Array[String]]) + mainMethod.invoke(null, extraArgs.toArray[String]) + + actorSystem.shutdown() + + case _ => + System.err.println("Usage: DriverWrapper <workerUrl> <driverMainClass> [options]") + System.exit(-1) + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff9cb60c78498b2643af10a311c63b3b85607bb..18885d7ca6daa2ea4124c9e838fc631da7a8be5b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -18,17 +18,15 @@ package org.apache.spark.deploy.worker import java.io._ -import java.lang.System.getenv import akka.actor.ActorRef import com.google.common.base.Charsets import com.google.common.io.Files -import org.apache.spark.{Logging} -import org.apache.spark.deploy.{ExecutorState, ApplicationDescription} +import org.apache.spark.Logging +import org.apache.spark.deploy.{ExecutorState, ApplicationDescription, Command} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged -import org.apache.spark.util.Utils /** * Manages the execution of one executor process. @@ -44,16 +42,17 @@ private[spark] class ExecutorRunner( val host: String, val sparkHome: File, val workDir: File, + val workerUrl: String, var state: ExecutorState.Value) extends Logging { val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null - var shutdownHook: Thread = null - private def getAppEnv(key: String): Option[String] = - appDesc.command.environment.get(key).orElse(Option(getenv(key))) + // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might + // make sense to remove this in the future. + var shutdownHook: Thread = null def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { @@ -92,55 +91,17 @@ private[spark] class ExecutorRunner( /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */ def substituteVariables(argument: String): String = argument match { + case "{{WORKER_URL}}" => workerUrl case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString case other => other } - def buildCommandSeq(): Seq[String] = { - val command = appDesc.command - val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") - // SPARK-698: do not call the run.cmd script, as process.destroy() - // fails to kill a process tree on Windows - Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - (command.arguments ++ Seq(appId)).map(substituteVariables) - } - - /** - * Attention: this must always be aligned with the environment variables in the run scripts and - * the way the JAVA_OPTS are assembled there. 
- */ - def buildJavaOpts(): Seq[String] = { - val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH") - .map(p => List("-Djava.library.path=" + p)) - .getOrElse(Nil) - val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil) - val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil) - val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M") - - // Figure out our classpath with the external compute-classpath script - val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" - val classPath = Utils.executeAndGetOutput( - Seq(sparkHome + "/bin/compute-classpath" + ext), - extraEnvironment=appDesc.command.environment) - - Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts - } - - /** Spawn a thread that will redirect a given stream to a file */ - def redirectStream(in: InputStream, file: File) { - val out = new FileOutputStream(file, true) - new Thread("redirect output to " + file) { - override def run() { - try { - Utils.copyStream(in, out, true) - } catch { - case e: IOException => - logInfo("Redirection to " + file + " closed: " + e.getMessage) - } - } - }.start() + def getCommandSeq = { + val command = Command(appDesc.command.mainClass, + appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), appDesc.command.environment) + CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) } /** @@ -155,7 +116,7 @@ private[spark] class ExecutorRunner( } // Launch the process - val command = buildCommandSeq() + val command = getCommandSeq logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) val builder = new ProcessBuilder(command: _*).directory(executorDir) val env = builder.environment() @@ -172,11 +133,11 @@ private[spark] class ExecutorRunner( // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") - redirectStream(process.getInputStream, stdout) + CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(executorDir, "stderr") Files.write(header, stderr, Charsets.UTF_8) - redirectStream(process.getErrorStream, stderr) + CommandUtils.redirectStream(process.getErrorStream, stderr) // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run // long-lived processes only. 
However, in the future, we might restart the executor a few diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index fcaf4e92b18585f1a714a54eb33b71e20d35f33a..5182dcbb2abfdccc0f1bd5468d105a917f83df84 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -26,10 +26,12 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.deploy.master.Master +import org.apache.spark.deploy.master.{DriverState, Master} +import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} @@ -44,6 +46,8 @@ private[spark] class Worker( cores: Int, memory: Int, masterUrls: Array[String], + actorSystemName: String, + actorName: String, workDirPath: String = null, val conf: SparkConf) extends Actor with Logging { @@ -55,7 +59,7 @@ private[spark] class Worker( val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = conf.get("spark.worker.timeout", "60").toLong * 1000 / 4 + val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -68,6 +72,7 @@ private[spark] class Worker( var masterAddress: Address = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" + val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) @volatile var registered = false @volatile var connected = false val workerId = generateWorkerId() @@ -75,6 +80,9 @@ private[spark] class Worker( var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] + val drivers = new HashMap[String, DriverRunner] + val finishedDrivers = new HashMap[String, DriverRunner] + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -185,7 +193,10 @@ private[spark] class Worker( val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList) + sender ! 
WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + + case Heartbeat => + logInfo(s"Received heartbeat from driver ${sender.path}") case RegisterWorkerFailed(message) => if (!registered) { @@ -199,7 +210,7 @@ private[spark] class Worker( } else { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, - self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING) + self, workerId, host, new File(execSparkHome_), workDir, akkaUrl, ExecutorState.RUNNING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ @@ -219,8 +230,8 @@ private[spark] class Worker( logInfo("Executor " + fullId + " finished with state " + state + message.map(" message " + _).getOrElse("") + exitStatus.map(" exitStatus " + _).getOrElse("")) - finishedExecutors(fullId) = executor executors -= fullId + finishedExecutors(fullId) = executor coresUsed -= executor.cores memoryUsed -= executor.memory } @@ -239,13 +250,52 @@ private[spark] class Worker( } } + case LaunchDriver(driverId, driverDesc) => { + logInfo(s"Asked to launch driver $driverId") + val driver = new DriverRunner(driverId, workDir, sparkHome, driverDesc, self, akkaUrl) + drivers(driverId) = driver + driver.start() + + coresUsed += driverDesc.cores + memoryUsed += driverDesc.mem + } + + case KillDriver(driverId) => { + logInfo(s"Asked to kill driver $driverId") + drivers.get(driverId) match { + case Some(runner) => + runner.kill() + case None => + logError(s"Asked to kill unknown driver $driverId") + } + } + + case DriverStateChanged(driverId, state, exception) => { + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + } + masterLock.synchronized { + master ! DriverStateChanged(driverId, state, exception) + } + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + case x: DisassociatedEvent if x.remoteAddress == masterAddress => logInfo(s"$x Disassociated !") masterDisconnected() case RequestWorkerState => { sender ! 
WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, activeMasterUrl, cores, memory, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, coresUsed, memoryUsed, activeMasterWebUiUrl) } } @@ -282,10 +332,11 @@ private[spark] object Worker { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") + val actorName = "Worker" val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, workDir, conf), name = "Worker") + masterUrls, systemName, actorName, workDir, conf), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala new file mode 100644 index 0000000000000000000000000000000000000000..0e0d0cd6264cfdf7bd7cdaf03f20bfb126acc92e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -0,0 +1,55 @@ +package org.apache.spark.deploy.worker + +import akka.actor.{Actor, Address, AddressFromURIString} +import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} + +import org.apache.spark.Logging +import org.apache.spark.deploy.DeployMessages.SendHeartbeat + +/** + * Actor which connects to a worker process and terminates the JVM if the connection is severed. + * Provides fate sharing between a worker and its associated child processes. + */ +private[spark] class WorkerWatcher(workerUrl: String) extends Actor + with Logging { + override def preStart() { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + + logInfo(s"Connecting to worker $workerUrl") + val worker = context.actorSelection(workerUrl) + worker ! SendHeartbeat // need to send a message here to initiate connection + } + + // Used to avoid shutting down JVM during tests + private[deploy] var isShutDown = false + private[deploy] def setTesting(testing: Boolean) = isTesting = testing + private var isTesting = false + + // Lets us filter events only from the worker's actor system + private val expectedHostPort = AddressFromURIString(workerUrl).hostPort + private def isWorker(address: Address) = address.hostPort == expectedHostPort + + def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) + + override def receive = { + case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + logInfo(s"Successfully connected to $workerUrl") + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + if isWorker(remoteAddress) => + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + + case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + // This log message will never be seen + logError(s"Lost connection to worker actor $workerUrl. 
Exiting.") + exitNonZero() + + case e: AssociationEvent => + // pass through association events relating to other remote actor systems + + case e => logWarning(s"Received unexpected actor system event: $e") + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index 0d59048313079132a9a5d1d7162befc3b1e3cbc2..925c6fb1832d7eff6adea8d51c91baf937ef0a97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -17,24 +17,20 @@ package org.apache.spark.deploy.worker.ui -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import scala.concurrent.duration._ import scala.concurrent.Await +import scala.xml.Node import akka.pattern.ask - +import javax.servlet.http.HttpServletRequest import net.liftweb.json.JsonAST.JValue import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils - private[spark] class IndexPage(parent: WorkerWebUI) { val workerActor = parent.worker.self val worker = parent.worker @@ -56,6 +52,16 @@ private[spark] class IndexPage(parent: WorkerWebUI) { val finishedExecutorTable = UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors) + val driverHeaders = Seq("DriverID", "Main Class", "State", "Cores", "Memory", "Logs", "Notes") + val runningDrivers = workerState.drivers.sortBy(_.driverId).reverse + val runningDriverTable = UIUtils.listingTable(driverHeaders, driverRow, runningDrivers) + val finishedDrivers = workerState.finishedDrivers.sortBy(_.driverId).reverse + def finishedDriverTable = UIUtils.listingTable(driverHeaders, driverRow, finishedDrivers) + + // For now we only show driver information if the user has submitted drivers to the cluster. + // This is until we integrate the notion of drivers and applications in the UI. 
+ def hasDrivers = runningDrivers.length > 0 || finishedDrivers.length > 0 + val content = <div class="row-fluid"> <!-- Worker Details --> <div class="span12"> @@ -79,11 +85,33 @@ private[spark] class IndexPage(parent: WorkerWebUI) { </div> </div> + <div> + {if (hasDrivers) + <div class="row-fluid"> <!-- Running Drivers --> + <div class="span12"> + <h4> Running Drivers {workerState.drivers.size} </h4> + {runningDriverTable} + </div> + </div> + } + </div> + <div class="row-fluid"> <!-- Finished Executors --> <div class="span12"> <h4> Finished Executors </h4> {finishedExecutorTable} </div> + </div> + + <div> + {if (hasDrivers) + <div class="row-fluid"> <!-- Finished Drivers --> + <div class="span12"> + <h4> Finished Drivers </h4> + {finishedDriverTable} + </div> + </div> + } </div>; UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( @@ -111,6 +139,27 @@ private[spark] class IndexPage(parent: WorkerWebUI) { .format(executor.appId, executor.execId)}>stderr</a> </td> </tr> + } + def driverRow(driver: DriverRunner): Seq[Node] = { + <tr> + <td>{driver.driverId}</td> + <td>{driver.driverDesc.command.arguments(1)}</td> + <td>{driver.finalState.getOrElse(DriverState.RUNNING)}</td> + <td sorttable_customkey={driver.driverDesc.cores.toString}> + {driver.driverDesc.cores.toString} + </td> + <td sorttable_customkey={driver.driverDesc.mem.toString}> + {Utils.megabytesToString(driver.driverDesc.mem)} + </td> + <td> + <a href={s"logPage?driverId=${driver.driverId}&logType=stdout"}>stdout</a> + <a href={s"logPage?driverId=${driver.driverId}&logType=stderr"}>stderr</a> + </td> + <td> + {driver.finalException.getOrElse("")} + </td> + </tr> + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index c382034c99e000adbf4b003e0a9fb06565b907bb..8daa47b2b24352ab6925175f19db6bfa014d332e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -69,30 +69,48 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I def log(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") + + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) val logType = request.getParameter("logType") val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val path = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + s"${workDir.getPath}/$appId/$executorId/$logType" + case (None, None, Some(d)) => + s"${workDir.getPath}/$driverId/$logType" + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } val (startByte, endByte) = getByteRange(path, offset, byteLength) val file = new File(path) val logLength = file.length - val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n" - .format(startByte, endByte, logLength, appId, executorId, logType) + val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" pre + Utils.offsetBytes(path, startByte, endByte) } def logPage(request: HttpServletRequest): 
Seq[scala.xml.Node] = { val defaultBytes = 100 * 1024 - val appId = request.getParameter("appId") - val executorId = request.getParameter("executorId") + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) val logType = request.getParameter("logType") val offset = Option(request.getParameter("offset")).map(_.toLong) val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType) + + val (path, params) = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") + case (None, None, Some(d)) => + (s"${workDir.getPath}/$d/$logType", s"driverId=$d") + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } val (startByte, endByte) = getByteRange(path, offset, byteLength) val file = new File(path) @@ -106,9 +124,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val backButton = if (startByte > 0) { - <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s" - .format(appId, executorId, logType, math.max(startByte-byteLength, 0), - byteLength)}> + <a href={"?%s&logType=%s&offset=%s&byteLength=%s" + .format(params, logType, math.max(startByte-byteLength, 0), byteLength)}> <button type="button" class="btn btn-default"> Previous {Utils.bytesToString(math.min(byteLength, startByte))} </button> @@ -122,8 +139,8 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val nextButton = if (endByte < logLength) { - <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s". - format(appId, executorId, logType, endByte, byteLength)}> + <a href={"?%s&logType=%s&offset=%s&byteLength=%s". 
+ format(params, logType, endByte, byteLength)}> <button type="button" class="btn btn-default"> Next {Utils.bytesToString(math.min(byteLength, logLength-endByte))} </button> diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 53a2b94a52aa32f4917eeb8508475e7dd59ea56f..45b43b403dd8c087fa730300b53a29e4352b4027 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,8 +24,9 @@ import akka.remote._ import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.TaskState.TaskState +import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -91,7 +92,8 @@ private[spark] class CoarseGrainedExecutorBackend( } private[spark] object CoarseGrainedExecutorBackend { - def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { + def run(driverUrl: String, executorId: String, hostname: String, cores: Int, + workerUrl: Option[String]) { // Debug code Utils.checkHost(hostname) @@ -101,21 +103,27 @@ private[spark] object CoarseGrainedExecutorBackend { indestructible = true, conf = new SparkConf) // set it val sparkHostPort = hostname + ":" + boundPort -// conf.set("spark.hostPort", sparkHostPort) actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") + workerUrl.foreach{ url => + actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + } actorSystem.awaitTermination() } def main(args: Array[String]) { - if (args.length < 4) { - //the reason we allow the last appid argument is to make it easy to kill rogue executors - System.err.println( - "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + - "[<appid>]") - System.exit(1) + args.length match { + case x if x < 4 => + System.err.println( + // Worker url is used in spark standalone mode to enforce fate-sharing with worker + "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> " + + "<cores> [<workerUrl>]") + System.exit(1) + case 4 => + run(args(0), args(1), args(2), args(3).toInt, None) + case x if x > 4 => + run(args(0), args(1), args(2), args(3).toInt, Some(args(4))) } - run(args(0), args(1), args(2), args(3).toInt) } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e51d274d338748cd48a162d39458326012b921cb..7f31d7e6f8aecc2d790fae650bb157f623e2485d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -57,7 +57,7 @@ private[spark] class Executor( Utils.setCustomHostname(slaveHostname) // Set spark.* properties from executor arg - val conf = new SparkConf(false) + val conf = new SparkConf(true) conf.setAll(properties) // If we are in yarn mode, systems can have different disk layouts so we must set it @@ -279,6 +279,11 @@ private[spark] class Executor( //System.exit(1) } } finally { + // TODO: Unregister shuffle memory only for ShuffleMapTask + val shuffleMemoryMap = env.shuffleMemoryMap + 
shuffleMemoryMap.synchronized { + shuffleMemoryMap.remove(Thread.currentThread().getId) + } runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index a1e98845f6a848b8ab0651aab5e3b0d484d32a02..59801773205bdab204e25bb227995361feebf812 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -71,7 +71,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.get("spark.io.compression.snappy.block.size", "32768").toInt + val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index f736bb3713061f5258fddde2331a973dbd77ed76..fb4c65909a9e2c62a4049d01c4e781c4e6cb53ef 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -46,7 +46,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } - if (size == 0 && gotChunkForSendingOnce == false) { + if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) gotChunkForSendingOnce = true diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 95cb0206acd62e67a80edcb057743d29615e4500..cba8477ed572336bdcd74d3119adba1948794570 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -330,7 +330,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Is highly unlikely unless there was an unclean close of socket, etc registerInterest() logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - return true + true } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) @@ -385,7 +385,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } } // should not happen - to keep scala compiler happy - return true + true } // This is a hack to determine if remote socket was closed or not. 
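The configuration reads in the surrounding hunks move from the untyped `conf.get(key, "default").toInt` idiom to the typed `SparkConf.getInt(key, default)` accessor (the CompressionCodec hunk above and the ConnectionManager, ShuffleCopier, CheckpointRDD and TaskResultGetter hunks below all get the same treatment). The following is a minimal, self-contained sketch of that pattern, not part of the patch: the object name and the explicitly set value are illustrative only, while the keys are taken from the hunks themselves.

    import org.apache.spark.SparkConf

    object ConfGetterSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        // Values are stored as strings; the typed getter parses them and falls back
        // to the supplied default when the key is unset.
        conf.set("spark.io.compression.snappy.block.size", "65536")
        val blockSize      = conf.getInt("spark.io.compression.snappy.block.size", 32768) // parsed: 65536
        val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000)    // unset: default 60000
        println(s"blockSize=$blockSize, connectTimeout=$connectTimeout")
      }
    }

`SparkConf.getBoolean` (used for `spark.shuffle.externalSorting` elsewhere in this patch) follows the same convention, which keeps the string-to-primitive conversion out of every call site.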
@@ -559,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S } } // should not happen - to keep scala compiler happy - return true + true } def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 46c40d0a2a02959ffb50ea7541c7092ac0a9bab2..e6e01783c889524c561f0addc6f136c91f6eae57 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -54,22 +54,22 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() private val handleMessageExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.handler.threads.min", "20").toInt, - conf.get("spark.core.connection.handler.threads.max", "60").toInt, - conf.get("spark.core.connection.handler.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.handler.threads.min", 20), + conf.getInt("spark.core.connection.handler.threads.max", 60), + conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val handleReadWriteExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.io.threads.min", "4").toInt, - conf.get("spark.core.connection.io.threads.max", "32").toInt, - conf.get("spark.core.connection.io.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.io.threads.min", 4), + conf.getInt("spark.core.connection.io.threads.max", 32), + conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap private val handleConnectExecutor = new ThreadPoolExecutor( - conf.get("spark.core.connection.connect.threads.min", "1").toInt, - conf.get("spark.core.connection.connect.threads.max", "8").toInt, - conf.get("spark.core.connection.connect.threads.keepalive", "60").toInt, TimeUnit.SECONDS, + conf.getInt("spark.core.connection.connect.threads.min", 1), + conf.getInt("spark.core.connection.connect.threads.max", 8), + conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index f2ecc6d439aaad826578cb4f5c86d4ada559ee07..2612884bdbe158fdc8e11866995919d02b9a18eb 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -61,7 +61,7 @@ private[spark] object Message { if (dataBuffers.exists(_ == null)) { throw new Exception("Attempting to create buffer message with null buffer") } - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) + new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) } def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = @@ -69,9 +69,9 @@ private[spark] object Message { def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { if (dataBuffer == null) { - return 
createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) + createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) } else { - return createBufferMessage(Array(dataBuffer), ackId) + createBufferMessage(Array(dataBuffer), ackId) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index b729eb11c514268798ecd21a679b6abbb3202e90..d87157e12c4876201371663ee94dc1f4a8c6a32d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -36,7 +36,7 @@ private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.get("spark.shuffle.netty.connect.timeout", "60000").toInt + val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) val fc = new FileClient(handler, connectTimeout) try { diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 546d921067175a2db075f8896e101782eb6c40c4..44204a8c46572169c950a09b23f54d7b452fdbec 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -64,7 +64,7 @@ private[spark] object ShuffleSender { val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) val file = new File(subDir, blockId.name) - return new FileSegment(file, 0, file.length()) + new FileSegment(file, 0, file.length()) } } val sender = new ShuffleSender(port, pResovler) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 6d4f46125f1a60bf872dfe27baad4cf1d07070ad..83109d1a6f853f0aa9c98c3eed86c8165557f3ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -97,7 +97,7 @@ private[spark] object CheckpointRDD extends Logging { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") } - val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) @@ -131,7 +131,7 @@ private[spark] object CheckpointRDD extends Logging { ): Iterator[T] = { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) - val bufferSize = env.conf.get("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4ba4696fef52159c8fd6c7141e8cce1ac5fd3056..a73714abcaf7262aa1067617116202f9ab812b51 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer import 
org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} -import org.apache.spark.util.AppendOnlyMap - +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } - /** * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. @@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). + // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. + // CoGroupValue is the intermediate state of each value before being merged in compute. + private type CoGroup = ArrayBuffer[Any] + private type CoGroupValue = (Any, Int) // Int is dependency number + private type CoGroupCombiner = Seq[CoGroup] + + private val sparkConf = SparkEnv.get.conf private var serializerClass: String = null def setSerializer(cls: String): CoGroupedRDD[K] = { @@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size - // e.g. 
for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) - } - - val getSeq = (k: K) => { - map.changeValue(k, update) - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + // A list of (rdd iterator, dependency number) pairs + val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } + val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + rddIterators += ((it, depNum)) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 + val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + rddIterators += ((it, depNum)) + } + } + + if (!externalSorting) { + val map = new AppendOnlyMap[K, CoGroupCombiner] + val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) + } + val getCombiner: K => CoGroupCombiner = key => { + map.changeValue(key, update) + } + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + getCombiner(kv._1)(depNum) += kv._2 } } + new InterruptibleIterator(context, map.iterator) + } else { + val map = createExternalMap(numRdds) + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + map.insert(kv._1, new CoGroupValue(kv._2, depNum)) + } + } + new InterruptibleIterator(context, map.iterator) + } + } + + private def createExternalMap(numRdds: Int) + : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = { + + val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { + val newCombiner = Array.fill(numRdds)(new CoGroup) + value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner } - new InterruptibleIterator(context, map.iterator) + val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = + (combiner, value) => { + value match { case (v, depNum) => combiner(depNum) += v } + combiner + } + val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = + (combiner1, combiner2) => { + combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + } + new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( + createCombiner, mergeValue, mergeCombiners) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 98da35763b9d15c96d9cd1d84330c729aec759bc..cefcc3d2d9420178c816937a6e8c615143debd8a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -295,10 +295,10 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc val prefPartActual = prefPart.get - if (minPowerOfTwo.size + slack <= prefPartActual.size) 
// more imbalance than the slack allows - return minPowerOfTwo // prefer balance over locality - else { - return prefPartActual // prefer locality over balance + if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows + minPowerOfTwo // prefer balance over locality + } else { + prefPartActual // prefer locality over balance } } @@ -331,7 +331,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc */ def run(): Array[PartitionGroup] = { setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) - throwBalls() // assign partitions (balls) to each group (bins) + throwBalls() // assign partitions (balls) to each group (bins) getPartitions } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 53f77a38f55f685e215c5c70b467a17db4a1116c..5cdb80be1ddd8026e537e62b11494d333c753441 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,10 @@ package org.apache.spark.rdd import java.io.EOFException -import org.apache.hadoop.mapred.FileInputFormat +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configuration, Configurable} +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf @@ -31,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.NextIterator -import org.apache.hadoop.conf.{Configuration, Configurable} +import org.apache.spark.util.Utils.cloneWritables /** @@ -42,14 +45,14 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp val inputSplit = new SerializableWritable[InputSplit](s) - override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt + override def hashCode(): Int = 41 * (41 + rddId) + idx override val index: Int = idx } /** * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3). + * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed @@ -61,15 +64,21 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. 
*/ -class HadoopRDD[K, V]( +class HadoopRDD[K: ClassTag, V: ClassTag]( sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int) + minSplits: Int, + cloneRecords: Boolean) extends RDD[(K, V)](sc, Nil) with Logging { def this( @@ -78,7 +87,8 @@ class HadoopRDD[K, V]( inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - minSplits: Int) = { + minSplits: Int, + cloneRecords: Boolean) = { this( sc, sc.broadcast(new SerializableWritable(conf)) @@ -87,7 +97,8 @@ class HadoopRDD[K, V]( inputFormatClass, keyClass, valueClass, - minSplits) + minSplits, + cloneRecords) } protected val jobConfCacheKey = "rdd_%d_job_conf".format(id) @@ -99,11 +110,11 @@ class HadoopRDD[K, V]( val conf: Configuration = broadcastedConf.value.value if (conf.isInstanceOf[JobConf]) { // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - return conf.asInstanceOf[JobConf] + conf.asInstanceOf[JobConf] } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { // getJobConf() has been called previously, so there is already a local cache of the JobConf // needed by this RDD. - return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). @@ -111,7 +122,7 @@ class HadoopRDD[K, V]( val newJobConf = new JobConf(broadcastedConf.value.value) initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - return newJobConf + newJobConf } } @@ -127,7 +138,7 @@ class HadoopRDD[K, V]( newInputFormat.asInstanceOf[Configurable].setConf(conf) } HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) - return newInputFormat + newInputFormat } override def getPartitions: Array[Partition] = { @@ -158,10 +169,10 @@ class HadoopRDD[K, V]( // Register an on-task-completion callback to close the input stream. 
context.addOnCompleteCallback{ () => closeIfNeeded() } - val key: K = reader.createKey() + val keyCloneFunc = cloneWritables[K](jobConf) val value: V = reader.createValue() - + val valueCloneFunc = cloneWritables[V](jobConf) override def getNext() = { try { finished = !reader.next(key, value) @@ -169,7 +180,11 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } - (key, value) + if (cloneRecords) { + (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable])) + } else { + (key, value) + } } override def close() { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 73d15b90822addbb26062193e2f323c7f41a525a..992bd4aa0ad5dbc17283048a1f0613a3ea3144b6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -20,11 +20,14 @@ package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.util.Utils.cloneWritables private[spark] @@ -33,15 +36,31 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = (41 * (41 + rddId) + index) + override def hashCode(): Int = 41 * (41 + rddId) + index } -class NewHadoopRDD[K, V]( +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * + * @param sc The SparkContext to associate the RDD with. + * @param inputFormatClass Storage format of the data to be read. + * @param keyClass Class of the key associated with the inputFormatClass. + * @param valueClass Class of the value associated with the inputFormatClass. + * @param conf The Hadoop configuration. + * @param cloneRecords If true, Spark will clone the records produced by Hadoop RecordReader. + * Most RecordReader implementations reuse wrapper objects across multiple + * records, and can cause problems in RDD collect or aggregation operations. + * By default the records are cloned in Spark. However, application + * programmers can explicitly disable the cloning for better performance. + */ +class NewHadoopRDD[K: ClassTag, V: ClassTag]( sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], - @transient conf: Configuration) + @transient conf: Configuration, + cloneRecords: Boolean) extends RDD[(K, V)](sc, Nil) with SparkHadoopMapReduceUtil with Logging { @@ -88,7 +107,8 @@ class NewHadoopRDD[K, V]( // Register an on-task-completion callback to close the input stream. 
context.addOnCompleteCallback(() => close()) - + val keyCloneFunc = cloneWritables[K](conf) + val valueCloneFunc = cloneWritables[V](conf) var havePair = false var finished = false @@ -105,7 +125,13 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false - (reader.getCurrentKey, reader.getCurrentValue) + val key = reader.getCurrentKey + val value = reader.getCurrentValue + if (cloneRecords) { + (keyCloneFunc(key.asInstanceOf[Writable]), valueCloneFunc(value.asInstanceOf[Writable])) + } else { + (key, value) + } } private def close() { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 2bf7c5b8d65e9e3ce44a84494a726741cfcc762c..f6719ec57cbf7a413bea5438fbd81ddf52c9f648 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -18,35 +18,34 @@ package org.apache.spark.rdd import java.nio.ByteBuffer -import java.util.Date import java.text.SimpleDateFormat +import java.util.Date import java.util.{HashMap => JHashMap} -import scala.collection.{mutable, Map} +import scala.collection.Map +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.reflect.{ClassTag, classTag} -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.SequenceFile.CompressionType -import org.apache.hadoop.mapred.FileOutputFormat -import org.apache.hadoop.mapred.OutputFormat +import org.apache.hadoop.io.compress.CompressionCodec +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import com.clearspring.analytics.stream.cardinality.HyperLogLog +// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark. +import org.apache.hadoop.mapred.SparkHadoopWriter +import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark._ import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} -import org.apache.spark.Aggregator -import org.apache.spark.Partitioner import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.util.SerializableHyperLogLog @@ -100,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) }, preservesPartitioning = true) } else { // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. 
- assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) values.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) @@ -120,9 +117,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = { // Serialize the zero value to a byte array so that we can get a new clone of it on each key @@ -138,18 +135,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = { foldByKey(zeroValue, new HashPartitioner(numPartitions))(func) } /** - * Merge the values for each key using an associative function and a neutral "zero value" which may - * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for - * list concatenation, 0 for addition, or 1 for multiplication.). + * Merge the values for each key using an associative function and a neutral "zero value" which + * may be added to the result an arbitrary number of times, and must not change the result + * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication.). */ def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = { foldByKey(zeroValue, defaultPartitioner(self))(func) @@ -226,7 +223,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** - * Return approximate number of distinct values for each key in this RDD. + * Return approximate number of distinct values for each key in this RDD. * The accuracy of approximation can be controlled through the relative standard deviation * (relativeSD) parameter, which also controls the amount of memory used. Lower values result in * more accurate counts but increase the memory footprint and vise versa. HashPartitions the @@ -268,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) // into a hash table, leading to more objects in the old gen. 
def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) + createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } @@ -340,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { + : RDD[(K, C)] = { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } @@ -579,7 +577,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec) + val runtimeClass = fm.runtimeClass + saveAsHadoopFile(path, getKeyClass, getValueClass, runtimeClass.asInstanceOf[Class[F]], codec) } /** @@ -599,7 +598,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration = self.context.hadoopConfiguration) { + conf: Configuration = self.context.hadoopConfiguration) + { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) @@ -668,7 +668,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) codec: Option[Class[_ <: CompressionCodec]] = None) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) - // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? 
+ // conf.setOutputFormat(outputFormatClass) conf.set("mapred.output.format.class", outputFormatClass.getName) for (c <- codec) { conf.setCompressMapOutput(true) @@ -702,7 +704,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) throw new SparkException("Output value class not set") } - logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") + logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + + valueClass.getSimpleName+ ")") val writer = new SparkHadoopWriter(conf) writer.preSetup() diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 1dbbe39898c3ea179f92f0468435d086328eea89..d4f396afb5d2bae34acd2c600c32f44118ed7056 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -96,7 +96,7 @@ class PipedRDD[T: ClassTag]( // Return an iterator that read lines from the process's stdout val lines = Source.fromInputStream(proc.getInputStream).getLines - return new Iterator[String] { + new Iterator[String] { def next() = lines.next() def hasNext = { if (lines.hasNext) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2142ae730e9ffff2475376de752d21813e950ba3..cd90a1561a975e4424690491188daf007ddbf2c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -23,7 +23,6 @@ import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.BytesWritable @@ -52,11 +51,13 @@ import org.apache.spark._ * partitioned collection of elements that can be operated on in parallel. This class contains the * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, * [[org.apache.spark.rdd.PairRDDFunctions]] contains operations available only on RDDs of key-value - * pairs, such as `groupByKey` and `join`; [[org.apache.spark.rdd.DoubleRDDFunctions]] contains - * operations available only on RDDs of Doubles; and [[org.apache.spark.rdd.SequenceFileRDDFunctions]] - * contains operations available on RDDs that can be saved as SequenceFiles. These operations are - * automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit - * conversions when you `import org.apache.spark.SparkContext._`. + * pairs, such as `groupByKey` and `join`; + * [[org.apache.spark.rdd.DoubleRDDFunctions]] contains operations available only on RDDs of + * Doubles; and + * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that + * can be saved as SequenceFiles. + * These operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] + * through implicit conversions when you `import org.apache.spark.SparkContext._`. * * Internally, each RDD is characterized by five main properties: * @@ -235,12 +236,9 @@ abstract class RDD[T: ClassTag]( /** * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. 
*/ - private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { - if (isCheckpointed) { - firstParent[T].iterator(split, context) - } else { - compute(split, context) - } + private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = + { + if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context) } // Transformations (return a new RDD) @@ -268,6 +266,9 @@ abstract class RDD[T: ClassTag]( def distinct(numPartitions: Int): RDD[T] = map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ def distinct(): RDD[T] = distinct(partitions.size) /** @@ -280,7 +281,7 @@ abstract class RDD[T: ClassTag]( * which can avoid performing a shuffle. */ def repartition(numPartitions: Int): RDD[T] = { - coalesce(numPartitions, true) + coalesce(numPartitions, shuffle = true) } /** @@ -651,7 +652,8 @@ abstract class RDD[T: ClassTag]( } /** - * Reduces the elements of this RDD using the specified commutative and associative binary operator. + * Reduces the elements of this RDD using the specified commutative and + * associative binary operator. */ def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) @@ -767,7 +769,7 @@ abstract class RDD[T: ClassTag]( val entry = iter.next() m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) } - return m1 + m1 } val myResult = mapPartitions(countPartition).reduce(mergeMaps) myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map @@ -845,7 +847,7 @@ abstract class RDD[T: ClassTag]( partsScanned += numPartsToTry } - return buf.toArray + buf.toArray } /** @@ -958,7 +960,7 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - @transient private[spark] val origin = sc.getCallSite + @transient private[spark] val origin = sc.getCallSite() private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 043e01dbfbf28d930eb08ad6e03a12be5ab5a0f2..7046c06d2057d68e6e16f4d6bc2e15e5cbda42db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -106,7 +106,7 @@ class DAGScheduler( // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50.milliseconds + val RESUBMIT_TIMEOUT = 200.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages @@ -133,7 +133,8 @@ class DAGScheduler( private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] - private[spark] val listenerBus = new SparkListenerBus() + // An async scheduler event bus. The bus should be stopped when DAGSCheduler is stopped. 
+ private[spark] val listenerBus = new SparkListenerBus // Contains the locations that each RDD's partitions are cached on private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] @@ -196,7 +197,7 @@ class DAGScheduler( */ def receive = { case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) + logTrace("Got event of type " + event.getClass.getName) /** * All events are forwarded to `processEvent()`, so that the event processing logic can @@ -1121,5 +1122,6 @@ class DAGScheduler( } metadataCleaner.cancel() taskSched.stop() + listenerBus.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 90eb8a747f91c88436aa9f6ffd205eade6380f8e..cc10cc0849bc78ff3440d9e7e0323efee398620c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } - return retval.toSet + retval.toSet } // This method does not expect failures, since validate has already passed ... @@ -121,18 +121,18 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem) ) - return retval.toSet + retval.toSet } private def findPreferredLocations(): Set[SplitInfo] = { logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + ", inputFormatClazz : " + inputFormatClazz) if (mapreduceInputFormat) { - return prefLocsFromMapreduceInputFormat() + prefLocsFromMapreduceInputFormat() } else { assert(mapredInputFormat) - return prefLocsFromMapredInputFormat() + prefLocsFromMapredInputFormat() } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 17912422150782a72252385ee696044355397f87..4bc13c23d980be000d5fc3c3cad9216aab51d680 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -75,12 +75,12 @@ private[spark] class Pool( return schedulableNameToSchedulable(schedulableName) } for (schedulable <- schedulableQueue) { - var sched = schedulable.getSchedulableByName(schedulableName) + val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched } } - return null + null } override def executorLost(executorId: String, host: String) { @@ -92,7 +92,7 @@ private[spark] class Pool( for (schedulable <- schedulableQueue) { shouldRevive |= schedulable.checkSpeculatableTasks() } - return shouldRevive + shouldRevive } override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { @@ -101,7 +101,7 @@ private[spark] class Pool( for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue() } - return sortedTaskSetQueue + sortedTaskSetQueue } def increaseRunningTasks(taskNum: Int) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 3418640b8c59ec00c16af51ab876b08808a5c784..5e62c8468f0070a8e51b74cea9a5b3a5c0a067b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -37,9 +37,9 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { res = math.signum(stageId1 - stageId2) } if (res < 0) { - return true + true } else { - return false + false } } } @@ -56,7 +56,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var res:Boolean = true var compare:Int = 0 if (s1Needy && !s2Needy) { @@ -70,11 +69,11 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { } if (compare < 0) { - return true + true } else if (compare > 0) { - return false + false } else { - return s1.name < s2.name + s1.name < s2.name } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 627995c826e2b844cdc6a9656ee2d45c51a204df..55a40a92c96521ebdeece352ebf65324cbeee59f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -43,6 +43,9 @@ case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], propertie case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) extends SparkListenerEvents +/** An event used in the listener to shutdown the listener daemon thread. */ +private[scheduler] case object SparkListenerShutdown extends SparkListenerEvents + trait SparkListener { /** * Called when a stage is completed, with information on the completed stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e7defd768b2c3857e03b8747b6026dfcb9428dd6..17b1328b86788b1cae1ed3654fd5f8954fdf494a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -24,15 +24,17 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import org.apache.spark.Logging /** Asynchronously passes SparkListenerEvents to registered SparkListeners. */ -private[spark] class SparkListenerBus() extends Logging { - private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener] +private[spark] class SparkListenerBus extends Logging { + private val sparkListeners = new ArrayBuffer[SparkListener] with SynchronizedBuffer[SparkListener] /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 + private val EVENT_QUEUE_CAPACITY = 10000 private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY) private var queueFullErrorMessageLogged = false + // Create a new daemon thread to listen for events. This thread is stopped when it receives + // a SparkListenerShutdown event, using the stop method. 
new Thread("SparkListenerBus") { setDaemon(true) override def run() { @@ -53,6 +55,9 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult)) case taskEnd: SparkListenerTaskEnd => sparkListeners.foreach(_.onTaskEnd(taskEnd)) + case SparkListenerShutdown => + // Get out of the while loop and shutdown the daemon thread + return case _ => } } @@ -80,7 +85,7 @@ private[spark] class SparkListenerBus() extends Logging { */ def waitUntilEmpty(timeoutMillis: Int): Boolean = { val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty()) { + while (!eventQueue.isEmpty) { if (System.currentTimeMillis > finishTime) { return false } @@ -88,6 +93,8 @@ private[spark] class SparkListenerBus() extends Logging { * add overhead in the general case. */ Thread.sleep(10) } - return true + true } + + def stop(): Unit = post(SparkListenerShutdown) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 7cb3fe46e5baf6c9880d804b80cb3015b64a3045..c60e9896dee4ffe0ef2254deed52fc157bbe7b3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -96,7 +96,7 @@ private[spark] class Stage( def newAttemptId(): Int = { val id = nextAttemptId nextAttemptId += 1 - return id + id } val name = callSite.getOrElse(rdd.origin) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index e80cc6b0f64e81ff3588c92bfea1cfb4acb1624d..9d3e6158266b8bd4f70f8dffe2891d9ea3298d4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -74,6 +74,6 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long def value(): T = { val resultSer = SparkEnv.get.serializer.newInstance() - return resultSer.deserialize(valueBytes) + resultSer.deserialize(valueBytes) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index e22b1e53e80482149fc19d7c65d5930a71a352d8..35e9544718eb2ea9e5842d58d0d572977151421d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -31,13 +31,13 @@ import org.apache.spark.util.Utils private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl) extends Logging { - private val THREADS = sparkEnv.conf.get("spark.resultGetter.threads", "4").toInt + private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( THREADS, "Result resolver thread") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { - return sparkEnv.closureSerializer.newInstance() + sparkEnv.closureSerializer.newInstance() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0c8ed6275991a4cd7f8def626244c1f73ed8cd28..d4f74d3e1854344af2186d063d80090760b4bf40 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
@@ -51,15 +51,15 @@ private[spark] class TaskSchedulerImpl( isLocal: Boolean = false) extends TaskScheduler with Logging { - def this(sc: SparkContext) = this(sc, sc.conf.get("spark.task.maxFailures", "4").toInt) + def this(sc: SparkContext) = this(sc, sc.conf.getInt("spark.task.maxFailures", 4)) val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.get("spark.speculation.interval", "100").toLong + val SPECULATION_INTERVAL = conf.getLong("spark.speculation.interval", 100) // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.get("spark.starvation.timeout", "15000").toLong + val STARVATION_TIMEOUT = conf.getLong("spark.starvation.timeout", 15000) // TaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. @@ -125,7 +125,7 @@ private[spark] class TaskSchedulerImpl( override def start() { backend.start() - if (!isLocal && conf.get("spark.speculation", "false").toBoolean) { + if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 6dd1469d8f801b829a739f878bec88c1b8746b97..fc0ee070897ddac1cfd601deace897b5d401cfc1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -57,11 +57,11 @@ private[spark] class TaskSetManager( val conf = sched.sc.conf // CPUs to request per task - val CPUS_PER_TASK = conf.get("spark.task.cpus", "1").toInt + val CPUS_PER_TASK = conf.getInt("spark.task.cpus", 1) // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = conf.get("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = conf.get("spark.speculation.multiplier", "1.5").toDouble + val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) + val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) // Serializer for closures and tasks. 
val env = SparkEnv.get @@ -116,7 +116,7 @@ private[spark] class TaskSetManager( // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = - conf.get("spark.logging.exceptionPrintInterval", "10000").toLong + conf.getLong("spark.logging.exceptionPrintInterval", 10000) // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception @@ -228,7 +228,7 @@ private[spark] class TaskSetManager( return Some(index) } } - return None + None } /** Check whether a task is currently running an attempt on a given host */ @@ -291,7 +291,7 @@ private[spark] class TaskSetManager( } } - return None + None } /** @@ -332,7 +332,7 @@ private[spark] class TaskSetManager( } // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(execId, host, locality) + findSpeculativeTask(execId, host, locality) } /** @@ -387,7 +387,7 @@ private[spark] class TaskSetManager( case _ => } } - return None + None } /** @@ -584,7 +584,7 @@ private[spark] class TaskSetManager( } override def getSchedulableByName(name: String): Schedulable = { - return null + null } override def addSchedulable(schedulable: Schedulable) {} @@ -594,7 +594,7 @@ private[spark] class TaskSetManager( override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this) sortedTaskSetQueue += this - return sortedTaskSetQueue + sortedTaskSetQueue } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ @@ -669,7 +669,7 @@ private[spark] class TaskSetManager( } } } - return foundTasks + foundTasks } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2f5bcafe40394d24a28bfd66915bb98436652f2f..0208388e86680754ca01e9a6a3db224a35f73e39 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -63,7 +63,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.get("spark.scheduler.revive.interval", "1000").toLong + val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -165,7 +165,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A override def start() { val properties = new ArrayBuffer[(String, String)] for ((key, value) <- scheduler.sc.conf.getAll) { - if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { + if (key.startsWith("spark.")) { properties += ((key, value)) } } @@ -209,8 +209,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } override def defaultParallelism(): Int = { - conf.getOption("spark.default.parallelism").map(_.toInt).getOrElse( - math.max(totalCoreCount.get(), 2)) + conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2)) } // Called by subclasses when notified of a lost 
worker diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index b44d1e43c85c770b5f7d3d76c5d29d556be2e93d..d99c76117c168a63597d56482f4eeb90cd9e3a9a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -33,7 +33,7 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) - val maxCores = conf.get("spark.simr.executor.cores", "1").toInt + val maxCores = conf.getInt("spark.simr.executor.cores", 1) override def start() { super.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 73fc37444e18f19c54e1a602a355b978bba29304..faa6e1ebe886f46b030f2114844a8229cee357c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster import scala.collection.mutable.HashMap import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.deploy.client.{Client, ClientListener} +import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.deploy.{Command, ApplicationDescription} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl} import org.apache.spark.util.Utils @@ -31,10 +31,10 @@ private[spark] class SparkDeploySchedulerBackend( masters: Array[String], appName: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) - with ClientListener + with AppClientListener with Logging { - var client: Client = null + var client: AppClient = null var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ @@ -47,14 +47,14 @@ private[spark] class SparkDeploySchedulerBackend( val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") val command = Command( "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) - client = new Client(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d46fceba8918a1cc76deea05b938315072f86ff3..c27049bdb520834787b2b30efb5b0cf9d0b916db 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class CoarseMesosSchedulerBackend( "Spark home 
is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val extraCoresPerSlave = conf.get("spark.mesos.extra.cores", "0").toInt + val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) var nextMesosTaskId = 0 @@ -140,7 +140,7 @@ private[spark] class CoarseMesosSchedulerBackend( .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } - return command.build() + command.build() } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index ae8d527352f733b5f8e002372b0a422e0b33c8c6..49781485d9f967638f5de41d581760e8c8d1bc37 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -141,13 +141,13 @@ private[spark] class MesosSchedulerBackend( // Serialize the map as an array of (String, String) pairs execArgs = Utils.serialize(props.toArray) } - return execArgs + execArgs } private def setClassLoader(): ClassLoader = { val oldClassLoader = Thread.currentThread.getContextClassLoader Thread.currentThread.setContextClassLoader(classLoader) - return oldClassLoader + oldClassLoader } private def restoreClassLoader(oldClassLoader: ClassLoader) { @@ -255,7 +255,7 @@ private[spark] class MesosSchedulerBackend( .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder().setValue(1).build()) .build() - return MesosTaskInfo.newBuilder() + MesosTaskInfo.newBuilder() .setTaskId(taskId) .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setExecutor(createExecutorInfo(slaveId)) @@ -340,5 +340,5 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.get("spark.default.parallelism", "8").toInt + override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index a24a3b04b87ccf57f365143781b98226b4c256a0..c14cd4755698776cff65a4abe24be5f1a2b26d5c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} */ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { private val bufferSize = { - conf.get("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 } def newKryoOutput() = new KryoOutput(bufferSize) @@ -48,7 +48,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. 
- kryo.setReferences(conf.get("spark.kryo.referenceTracking", "true").toBoolean) + kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true)) for (cls <- KryoSerializer.toRegister) kryo.register(cls) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 47478631a11f0289282250fe9ba4d00622479236..4fa2ab96d97255c0d8235e0077697a33caa97889 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -327,7 +327,7 @@ object BlockFetcherIterator { fetchRequestsSync.put(request) } - copiers = startCopiers(conf.get("spark.shuffle.copier.threads", "6").toInt) + copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7156d855d873b3ec63710ca72204922aa564a5d8..301d784b350a3a8e77a595bf80eab03ed5489fcf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -17,12 +17,14 @@ package org.apache.spark.storage +import java.util.UUID + /** * Identifies a particular Block of data, usually associated with a single file. * A Block can be uniquely identified by its filename, but each type of Block has a different * set of keys which produce its unique name. * - * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method. + * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method. */ private[spark] sealed abstract class BlockId { /** A globally unique identifier for this Block. Can be used for ser/de. */ @@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def name = "broadcast_" + broadcastId } -private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { +private[spark] +case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { def name = broadcastId.name + "_" + hType } @@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B def name = "input-" + streamId + "-" + uniqueId } +/** Id associated with temporary data managed as blocks. Not serializable. */ +private[spark] case class TempBlockId(id: UUID) extends BlockId { + def name = "temp_" + id +} + // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6d2cda97b04ebb5c9e08b2ee33c6450a2499a449..6f1345c57a295609ef53eebeaaecfd2c505c4efa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -58,8 +58,8 @@ private[spark] class BlockManager( // If we use Netty for shuffle, start a new Netty-based shuffle sender service. 
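The TempBlockId introduced above identifies temporary, intermediate data and simply prefixes a UUID with "temp_". A tiny sketch of the naming, ignoring the private[spark] visibility for illustration:

import java.util.UUID
import org.apache.spark.storage.TempBlockId

object TempBlockIdExample {
  def main(args: Array[String]) {
    val id = TempBlockId(UUID.randomUUID())
    println(id.name)  // e.g. temp_3f2504e0-4f89-11d3-9a0c-0305e82c3301
  }
}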
private val nettyPort: Int = { - val useNetty = conf.get("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = conf.get("spark.shuffle.sender.port", "0").toInt + val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } @@ -72,19 +72,17 @@ private[spark] class BlockManager( // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = - conf.get("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 // Whether to compress broadcast variables that are stored - val compressBroadcast = conf.get("spark.broadcast.compress", "true").toBoolean + val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) // Whether to compress shuffle output that are stored - val compressShuffle = conf.get("spark.shuffle.compress", "true").toBoolean + val compressShuffle = conf.getBoolean("spark.shuffle.compress", true) // Whether to compress RDD partitions that are stored serialized - val compressRdds = conf.get("spark.rdd.compress", "false").toBoolean + val compressRdds = conf.getBoolean("spark.rdd.compress", false) val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val hostPort = Utils.localHostPort(conf) - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -159,7 +157,7 @@ private[spark] class BlockManager( /** * Reregister with the master and report all blocks to it. This will be called by the heart beat - * thread if our heartbeat to the block amnager indicates that we were not registered. + * thread if our heartbeat to the block manager indicates that we were not registered. * * Note that this method must be called without any BlockInfo locks held. 
*/ @@ -412,7 +410,7 @@ private[spark] class BlockManager( logDebug("The value of block " + blockId + " is null") } logDebug("Block " + blockId + " not found") - return None + None } /** @@ -443,7 +441,7 @@ private[spark] class BlockManager( : BlockFetcherIterator = { val iter = - if (conf.get("spark.shuffle.use.netty", "false").toBoolean) { + if (conf.getBoolean("spark.shuffle.use.netty", false)) { new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) @@ -469,7 +467,7 @@ private[spark] class BlockManager( def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) - val syncWrites = conf.get("spark.shuffle.sync", "false").toBoolean + val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) } @@ -864,15 +862,15 @@ private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.get("spark.storage.memoryFraction", "0.66").toDouble + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) (Runtime.getRuntime.maxMemory * memoryFraction).toLong } def getHeartBeatFrequency(conf: SparkConf): Long = - conf.get("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 + conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) / 4 def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = - conf.get("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + conf.getBoolean("spark.test.disableBlockManagerHeartBeat", false) /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 51a29ed8ef81ad2818b86f5fad6bbc6a874c2378..c54e4f2664753c8a8f3f79b0a9e4dd2c98604612 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -30,8 +30,8 @@ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster(var driverActor : ActorRef, conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.get("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = conf.get("spark.akka.retry.wait", "3000").toInt + val AKKA_RETRY_ATTEMPTS: Int = conf.getInt("spark.akka.num.retries", 3) + val AKKA_RETRY_INTERVAL_MS: Int = conf.getInt("spark.akka.retry.wait", 3000) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 58452d96574c9b2d78dc148b14d91e0a2b50769e..2c1a4e2f5d3a18ce86faa39831a6403d8ea422e6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -348,14 +348,19 @@ object BlockManagerMasterActor { if (storageLevel.isValid) { // isValid means it is either stored in-memory or on-disk. 
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + // But the memSize here indicates the data size in or dropped from memory, + // and the diskSize here indicates the data size in or dropped to disk. + // They can be both larger than 0, when a block is dropped from memory to disk. + // Therefore, a safe way to set BlockStatus is to set its info in accurate modes. if (storageLevel.useMemory) { + _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0)) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize)) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 21f003609b14de5f511ebca2006afe72faf53d0b..42f52d7b26a04b677a66e3d14a8f4a0cd797b598 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -42,15 +42,15 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) logDebug("Parsed as a block message array") val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - return Some(new BlockMessageArray(responseMessages).toBufferMessage) + Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { case e: Exception => logError("Exception handling buffer message", e) - return None + None } } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - return None + None } } } @@ -61,7 +61,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) - return None + None } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) @@ -70,9 +70,9 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends if (buffer == null) { return None } - return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) + Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) } - case _ => return None + case _ => None } } @@ -93,7 +93,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - return buffer + buffer } } @@ -111,7 +111,7 @@ private[spark] object BlockManagerWorker extends Logging { val blockMessageArray = new BlockMessageArray(blockMessage) val resultMessage = connectionManager.sendMessageReliablySync( toConnManagerId, blockMessageArray.toBufferMessage) - return (resultMessage != None) + resultMessage != None } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -130,8 +130,8 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received"); return null + case None => logDebug("No response message received") } - return null + null } } diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index 80dcb5a2074d0a6c07c899db61564bd5c894657c..fbafcf79d28339af7cebfe34bc827ff83dcbee43 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -154,7 +154,7 @@ private[spark] class BlockMessage() { println() */ val finishTime = System.currentTimeMillis - return Message.createBufferMessage(buffers) + Message.createBufferMessage(buffers) } override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index a06f50a0ac89cf6f3ada7fad21d8beb59245ab82..59329361f320b147618cfbd4e5048f801e44a37e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -96,7 +96,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM println() println() */ - return Message.createBufferMessage(buffers) + Message.createBufferMessage(buffers) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 61e63c60d56e3b8e46bfa6fd1599863cfd0143be..369a277232b19ca28ce7a4aee1f5779afddb79e6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -181,4 +181,8 @@ class DiskBlockObjectWriter( // Only valid if called after close() override def timeWriting() = _timeWriting + + def bytesWritten: Long = { + lastValidPosition - initialPosition + } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 55dcb3742c9677829a900779686142a4f3e6559d..a8ef7fa8b63ebc69c35ef40cf55e47b747cc0da3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.File import java.text.SimpleDateFormat -import java.util.{Date, Random} +import java.util.{Date, Random, UUID} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.get("spark.diskStore.subDirectories", "64").toInt + private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64) // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid @@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Produces a unique block id and File suitable for intermediate results. 
*/ + def createTempBlock(): (TempBlockId, File) = { + var blockId = new TempBlockId(UUID.randomUUID()) + while (getFile(blockId).exists()) { + blockId = new TempBlockId(UUID.randomUUID()) + } + (blockId, getFile(blockId)) + } + private def createLocalDirs(): Array[File] = { logDebug("Creating local directories at root dirs '" + rootDirs + "'") val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 05f676c6e224998550d1108d40e392b3104a6e2d..27f057b9f22f4df42432acefb9efd8e4cab9e3aa 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -245,7 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return false } } - return true + true } override def contains(blockId: BlockId): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 39dc7bb19afeed7bd87a230f7a9b9ac23221675a..e2b24298a55e89b3410d8bace91e026795d8f8d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -64,9 +64,9 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - conf.get("spark.shuffle.consolidateFiles", "false").toBoolean + conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.get("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 /** * Contains all the state related to a particular shuffle. 
This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index b5596dffd3449afe8279f83fa528623ea9ba7a1e..0f84810d6be06cc106e44aa5219e902801d959f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -74,7 +74,7 @@ class StorageLevel private( if (deserialized_) { ret |= 1 } - return ret + ret } override def writeExternal(out: ObjectOutput) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index b7b87250b98ee86adf9adc3abe2b6cac733f9bf9..bcd282445050da5b81a8b6072abc12ca577fc0f5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler._ */ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { // How many stages to remember - val RETAINED_STAGES = sc.conf.get("spark.ui.retained_stages", "1000").toInt + val RETAINED_STAGES = sc.conf.getInt("spark.ui.retainedStages", 1000) val DEFAULT_POOL_NAME = "default" val stageIdToPool = new HashMap[Int, String]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8dcfeacb60fc35e108d8439e0d6601a62e96a09e..d1e58016beaac5a7a90d5ef90d78400c3fcad4f6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -171,7 +171,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ - <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++ + <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++ <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 463d85dfd54fdf79ddc15510b301c2a3ab8ff297..9ad6de3c6d8de79c758f1d0764b1a171bd56012c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgr {if (isFairScheduler) {<th>Pool Name</th>} else {}} <th>Description</th> <th>Submitted</th> - <th>Task Time</th> + <th>Duration</th> <th>Tasks: Succeeded/Total</th> <th>Shuffle Read</th> <th>Shuffle Write</th> diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 3f009a8998cbd7ed1afdecbba4b408b3abd3cee4..761d378c7fd8b5f264151f54a2a78aac435e6368 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -44,13 +44,13 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, conf: SparkConf): (ActorSystem, Int) = { - val akkaThreads = conf.get("spark.akka.threads", "4").toInt - val akkaBatchSize = 
conf.get("spark.akka.batchSize", "15").toInt + val akkaThreads = conf.getInt("spark.akka.threads", 4) + val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.get("spark.akka.timeout", "100").toInt + val akkaTimeout = conf.getInt("spark.akka.timeout", 100) - val akkaFrameSize = conf.get("spark.akka.frameSize", "10").toInt - val akkaLogLifecycleEvents = conf.get("spark.akka.logLifecycleEvents", "false").toBoolean + val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10) + val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" if (!akkaLogLifecycleEvents) { // As a workaround for Akka issue #3787, we coerce the "EndpointWriter" log to be silent. @@ -58,12 +58,12 @@ private[spark] object AkkaUtils { Option(Logger.getLogger("akka.remote.EndpointWriter")).map(l => l.setLevel(Level.FATAL)) } - val logAkkaConfig = if (conf.get("spark.akka.logAkkaConfig", "false").toBoolean) "on" else "off" + val logAkkaConfig = if (conf.getBoolean("spark.akka.logAkkaConfig", false)) "on" else "off" - val akkaHeartBeatPauses = conf.get("spark.akka.heartbeat.pauses", "600").toInt + val akkaHeartBeatPauses = conf.getInt("spark.akka.heartbeat.pauses", 600) val akkaFailureDetector = - conf.get("spark.akka.failure-detector.threshold", "300.0").toDouble - val akkaHeartBeatInterval = conf.get("spark.akka.heartbeat.interval", "1000").toInt + conf.getDouble("spark.akka.failure-detector.threshold", 300.0) + val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( @@ -103,7 +103,7 @@ private[spark] object AkkaUtils { /** Returns the default Spark timeout to use for Akka ask operations. */ def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.get("spark.akka.askTimeout", "30").toLong, "seconds") + Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds") } /** Returns the default Spark timeout to use for Akka remote actor lookup. */ diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 7108595e3e44518d38f48d5a76b529d2b0febdfe..1df6b87fb0730cf2aa4ecb05632fd1d49ba8d7b3 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -61,7 +61,7 @@ private[spark] object ClosureCleaner extends Logging { return f.getType :: Nil // Stop at the first $outer that is not a closure } } - return Nil + Nil } // Get a list of the outer objects for a given closure object. 
@@ -74,7 +74,7 @@ private[spark] object ClosureCleaner extends Logging { return f.get(obj) :: Nil // Stop at the first $outer that is not a closure } } - return Nil + Nil } private def getInnerClasses(obj: AnyRef): List[Class[_]] = { @@ -174,7 +174,7 @@ private[spark] object ClosureCleaner extends Logging { field.setAccessible(true) field.set(obj, outer) } - return obj + obj } } } @@ -182,7 +182,7 @@ private[spark] object ClosureCleaner extends Logging { private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { + new MethodVisitor(ASM4) { override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) { if (op == GETFIELD) { for (cl <- output.keys if cl.getName == owner.replace('/', '.')) { @@ -215,7 +215,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisi override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - return new MethodVisitor(ASM4) { + new MethodVisitor(ASM4) { override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { val argTypes = Type.getArgumentTypes(desc) diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index aa7f52cafbf37f5f3e45a286652b8308fee82216..ac07a55cb9101a1d867af68e97eac25451c7ca79 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -74,7 +74,7 @@ object MetadataCleanerType extends Enumeration { // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. object MetadataCleaner { def getDelaySeconds(conf: SparkConf) = { - conf.get("spark.cleaner.ttl", "3500").toInt + conf.getInt("spark.cleaner.ttl", -1) } def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index bddb3bb7350bc43af2b958f15ddc5f523de3eb4a..3cf94892e9680e14283148b8c70cd2d41048b5c0 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -108,7 +108,7 @@ private[spark] object SizeEstimator extends Logging { val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) // TODO: We could use reflection on the VMOption returned ? 
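One behavioral note on the MetadataCleaner hunk above: the default for spark.cleaner.ttl moves from "3500" to -1. A hedged reading, on the assumption that a non-positive TTL means periodic metadata cleanup is simply disabled unless the user sets it:

import org.apache.spark.SparkConf

object CleanerTtlExample {
  def main(args: Array[String]) {
    val conf = new SparkConf()
    // Assumption: values <= 0 are treated as "no automatic cleanup".
    val ttlSeconds = conf.getInt("spark.cleaner.ttl", -1)
    val cleanupEnabled = ttlSeconds > 0
    println("metadata cleanup enabled: " + cleanupEnabled)
  }
}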
- return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") + getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB @@ -141,7 +141,7 @@ private[spark] object SizeEstimator extends Logging { def dequeue(): AnyRef = { val elem = stack.last stack.trimEnd(1) - return elem + elem } } @@ -162,7 +162,7 @@ private[spark] object SizeEstimator extends Logging { while (!state.isFinished) { visitSingleObject(state.dequeue(), state) } - return state.size + state.size } private def visitSingleObject(obj: AnyRef, state: SearchState) { @@ -276,11 +276,11 @@ private[spark] object SizeEstimator extends Logging { // Create and cache a new ClassInfo val newInfo = new ClassInfo(shellSize, pointerFields) classInfos.put(cls, newInfo) - return newInfo + newInfo } private def alignSize(size: Long): Long = { val rem = size % ALIGN_SIZE - return if (rem == 0) size else (size + ALIGN_SIZE - rem) + if (rem == 0) size else (size + ALIGN_SIZE - rem) } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 181ae2fd45baf8a4f69ea1b94d8d2aa31d23abba..8e07a0f29addf7c1246c1850737830fc42dde5a3 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -26,16 +26,23 @@ import org.apache.spark.Logging /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion - * time stamp along with each key-value pair. Key-value pairs that are older than a particular - * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in - * replacement of scala.collection.mutable.HashMap. + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. 
+ * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be + * updated when it is accessed */ -class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { +class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends Map[A, B]() with Logging { val internalMap = new ConcurrentHashMap[A, (B, Long)]() def get(key: A): Option[B] = { val value = internalMap.get(key) - if (value != null) Some(value._1) else None + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, (value._1, currentTime)) + } + Option(value).map(_._1) } def iterator: Iterator[(A, B)] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5f1253100b33888834b65f2d6f238412b65dded8..caa9bf4c9280eff3e6ed396b38c0e2d896218498 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -26,37 +26,61 @@ import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source -import scala.reflect.ClassTag +import scala.reflect.{classTag, ClassTag} import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} +import org.apache.hadoop.io._ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging} +import org.apache.spark.{SparkConf, SparkException, Logging} /** * Various utility methods used by Spark. */ private[spark] object Utils extends Logging { + + /** + * We try to clone for most common types of writables and we call WritableUtils.clone otherwise + * intention is to optimize, for example for NullWritable there is no need and for Long, int and + * String creating a new object with value set would be faster. + */ + def cloneWritables[T: ClassTag](conf: Configuration): Writable => T = { + val cloneFunc = classTag[T] match { + case ClassTag(_: Text) => + (w: Writable) => new Text(w.asInstanceOf[Text].getBytes).asInstanceOf[T] + case ClassTag(_: LongWritable) => + (w: Writable) => new LongWritable(w.asInstanceOf[LongWritable].get).asInstanceOf[T] + case ClassTag(_: IntWritable) => + (w: Writable) => new IntWritable(w.asInstanceOf[IntWritable].get).asInstanceOf[T] + case ClassTag(_: NullWritable) => + (w: Writable) => w.asInstanceOf[T] // TODO: should we clone this ? + case _ => + (w: Writable) => WritableUtils.clone(w, conf).asInstanceOf[T] // slower way of cloning. 
+ } + cloneFunc + } + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() val oos = new ObjectOutputStream(bos) oos.writeObject(o) oos.close() - return bos.toByteArray + bos.toByteArray } /** Deserialize an object using Java serialization */ def deserialize[T](bytes: Array[Byte]): T = { val bis = new ByteArrayInputStream(bytes) val ois = new ObjectInputStream(bis) - return ois.readObject.asInstanceOf[T] + ois.readObject.asInstanceOf[T] } /** Deserialize an object using Java serialization and the given ClassLoader */ @@ -66,7 +90,7 @@ private[spark] object Utils extends Logging { override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, loader) } - return ois.readObject.asInstanceOf[T] + ois.readObject.asInstanceOf[T] } /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */ @@ -144,7 +168,7 @@ private[spark] object Utils extends Logging { i += 1 } } - return buf + buf } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() @@ -396,15 +420,6 @@ private[spark] object Utils extends Logging { InetAddress.getByName(address).getHostName } - def localHostPort(conf: SparkConf): String = { - val retval = conf.get("spark.hostPort", null) - if (retval == null) { - logErrorWithStack("spark.hostPort not set but invoking localHostPort") - return localHostName() - } - retval - } - def checkHost(host: String, message: String = "") { assert(host.indexOf(':') == -1, message) } @@ -413,14 +428,6 @@ private[spark] object Utils extends Logging { assert(hostPort.indexOf(':') != -1, message) } - def logErrorWithStack(msg: String) { - try { - throw new Exception - } catch { - case ex: Exception => logError(msg, ex) - } - } - // Typically, this will be of order of number of nodes in cluster // If not, we should change it to LRUCache or something. private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() @@ -428,7 +435,7 @@ private[spark] object Utils extends Logging { def parseHostPort(hostPort: String): (String, Int) = { { // Check cache first. - var cached = hostPortParseResults.get(hostPort) + val cached = hostPortParseResults.get(hostPort) if (cached != null) return cached } @@ -731,7 +738,7 @@ private[spark] object Utils extends Logging { } catch { case ise: IllegalStateException => return true } - return false + false } def isSpace(c: Char): Boolean = { @@ -748,7 +755,7 @@ private[spark] object Utils extends Logging { var inWord = false var inSingleQuote = false var inDoubleQuote = false - var curWord = new StringBuilder + val curWord = new StringBuilder def endWord() { buf += curWord.toString curWord.clear() @@ -794,7 +801,7 @@ private[spark] object Utils extends Logging { if (inWord || inDoubleQuote || inSingleQuote) { endWord() } - return buf + buf } /* Calculates 'x' modulo 'mod', takes to consideration sign of x, @@ -822,8 +829,7 @@ private[spark] object Utils extends Logging { /** Returns a copy of the system properties that is thread-safe to iterator over. 
*/ def getSystemProperties(): Map[String, String] = { - return System.getProperties().clone() - .asInstanceOf[java.util.Properties].toMap[String, String] + System.getProperties.clone().asInstanceOf[java.util.Properties].toMap[String, String] } /** diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index fe710c58acc44b53a0b962033fe3f8ad6435572c..fcdf8486371a40e64ac73110aa77a0edfc8013f9 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.util.Random + class Vector(val elements: Array[Double]) extends Serializable { def length = elements.length @@ -25,7 +27,7 @@ class Vector(val elements: Array[Double]) extends Serializable { def + (other: Vector): Vector = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) + other(i)) + Vector(length, i => this(i) + other(i)) } def add(other: Vector) = this + other @@ -33,7 +35,7 @@ class Vector(val elements: Array[Double]) extends Serializable { def - (other: Vector): Vector = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) - other(i)) + Vector(length, i => this(i) - other(i)) } def subtract(other: Vector) = this - other @@ -47,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += this(i) * other(i) i += 1 } - return ans + ans } /** @@ -67,7 +69,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += (this(i) + plus(i)) * other(i) i += 1 } - return ans + ans } def += (other: Vector): Vector = { @@ -102,7 +104,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ans += (this(i) - other(i)) * (this(i) - other(i)) i += 1 } - return ans + ans } def dist(other: Vector): Double = math.sqrt(squaredDist(other)) @@ -117,13 +119,19 @@ object Vector { def apply(length: Int, initializer: Int => Double): Vector = { val elements: Array[Double] = Array.tabulate(length)(initializer) - return new Vector(elements) + new Vector(elements) } def zeros(length: Int) = new Vector(new Array[Double](length)) def ones(length: Int) = Vector(length, _ => 1) + /** + * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers + * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided. 
+ */ + def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble()) + class Multiplier(num: Double) { def * (vec: Vector) = vec * num } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala index e9907e6c855aea4cc945f98c48b59293a441a1d8..08b31ac64f290561d5c5b21edb032e81315099f1 100644 --- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -91,4 +91,4 @@ private[spark] object XORShiftRandom { } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala similarity index 68% rename from core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala rename to core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 8bb4ee3bfa22e3ad1233d778e59feedc892115f7..b8c852b4ff5c78783e64af63f179671027eb90b5 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.util +package org.apache.spark.util.collection + +import java.util.{Arrays, Comparator} /** * A simple open hash table optimized for the append-only use case, where keys @@ -28,14 +30,15 @@ package org.apache.spark.util * TODO: Cache the hash values of each key? java.util.HashMap does that. */ private[spark] -class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, + V)] with Serializable { require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") require(initialCapacity >= 1, "Invalid initial capacity") private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 - private var growThreshold = LOAD_FACTOR * capacity + private var growThreshold = (LOAD_FACTOR * capacity).toInt // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var haveNullValue = false private var nullValue: V = null.asInstanceOf[V] + // Triggered by destructiveSortedIterator; the underlying data array may no longer be used + private var destroyed = false + private val destructionMessage = "Map state is invalid from destructive sorting!" + private val LOAD_FACTOR = 0.7 /** Get the value for a given key */ def apply(key: K): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { return nullValue @@ -67,11 +75,12 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi i += 1 } } - return null.asInstanceOf[V] + null.asInstanceOf[V] } /** Set the value for a key */ def update(key: K, value: V): Unit = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi * for key, if any, or null otherwise. Returns the newly updated value. 
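A brief usage sketch for the Vector.random factory added above, combined with the ones and + helpers already shown in the same file:

import org.apache.spark.util.Vector

object VectorRandomExample {
  def main(args: Array[String]) {
    val v = Vector.random(3)           // three doubles between 0.0 and 1.0
    val shifted = v + Vector.ones(3)   // element-wise addition, as defined above
    println(shifted.elements.mkString(", "))
  }
}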
*/ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Iterator method from Iterable */ - override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { - var pos = -1 - - /** Get the next value we should return from next(), or null if we're finished iterating */ - def nextValue(): (K, V) = { - if (pos == -1) { // Treat position -1 as looking at the null value - if (haveNullValue) { - return (null.asInstanceOf[K], nullValue) + override def iterator: Iterator[(K, V)] = { + assert(!destroyed, destructionMessage) + new Iterator[(K, V)] { + var pos = -1 + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def nextValue(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + return (null.asInstanceOf[K], nullValue) + } + pos += 1 } - pos += 1 - } - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + } + pos += 1 } - pos += 1 + null } - null - } - override def hasNext: Boolean = nextValue() != null + override def hasNext: Boolean = nextValue() != null - override def next(): (K, V) = { - val value = nextValue() - if (value == null) { - throw new NoSuchElementException("End of iterator") + override def next(): (K, V) = { + val value = nextValue() + if (value == null) { + throw new NoSuchElementException("End of iterator") + } + pos += 1 + value } - pos += 1 - value } } @@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Double the table's size and re-hash everything */ - private def growTable() { + protected def growTable() { val newCapacity = capacity * 2 if (newCapacity >= (1 << 30)) { // We can't make the table this big because we want an array of 2x @@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi data = newData capacity = newCapacity mask = newMask - growThreshold = LOAD_FACTOR * newCapacity + growThreshold = (LOAD_FACTOR * newCapacity).toInt } private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) if (highBit == n) n else highBit << 1 } + + /** + * Return an iterator of the map in sorted order. This provides a way to sort the map without + * using additional memory, at the expense of destroying the validity of the map. 
+ */ + def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = { + destroyed = true + // Pack KV pairs into the front of the underlying array + var keyIndex, newIndex = 0 + while (keyIndex < capacity) { + if (data(2 * keyIndex) != null) { + data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1)) + newIndex += 1 + } + keyIndex += 1 + } + assert(curSize == newIndex + (if (haveNullValue) 1 else 0)) + + // Sort by the given ordering + val rawOrdering = new Comparator[AnyRef] { + def compare(x: AnyRef, y: AnyRef): Int = { + cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)]) + } + } + Arrays.sort(data, 0, newIndex, rawOrdering) + + new Iterator[(K, V)] { + var i = 0 + var nullValueReady = haveNullValue + def hasNext: Boolean = (i < newIndex || nullValueReady) + def next(): (K, V) = { + if (nullValueReady) { + nullValueReady = false + (null.asInstanceOf[K], nullValue) + } else { + val item = data(i).asInstanceOf[(K, V)] + i += 1 + item + } + } + } + } + + /** + * Return whether the next insert will cause the map to grow + */ + def atGrowThreshold: Boolean = curSize == growThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..e3bcd895aa28ffc90b5864f420f37048eecd02fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter} + +/** + * An append-only map that spills sorted content to disk when there is insufficient space for it + * to grow. + * + * This map takes two passes over the data: + * + * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary + * (2) Combiners are read from disk and merged together + * + * The setting of the spill threshold faces the following trade-off: If the spill threshold is + * too high, the in-memory map may occupy more memory than is available, resulting in OOM. + * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk + * writes. This may lead to a performance regression compared to the normal case of using the + * non-spilling AppendOnlyMap. 
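A hedged sketch of how destructiveSortedIterator might be driven, ordering entries by key hash the way the spill path does (same private[spark] caveat as above; the sample data is hypothetical). Once the iterator is obtained, the map must not be read or written again.

    import java.util.Comparator
    import org.apache.spark.util.collection.AppendOnlyMap

    val map = new AppendOnlyMap[String, Int]()
    Seq("x" -> 1, "y" -> 2, "z" -> 3).foreach { case (k, v) => map(k) = v }

    // Order pairs by key hash; equal hashes simply end up adjacent.
    val byKeyHash = new Comparator[(String, Int)] {
      def compare(a: (String, Int), b: (String, Int)): Int =
        a._1.hashCode().compareTo(b._1.hashCode())
    }

    // Sorts the underlying array in place; the map is invalid afterwards.
    map.destructiveSortedIterator(byKeyHash).foreach(println)
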
+ * + * Two parameters control the memory threshold: + * + * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing + * these maps as a fraction of the executor's total memory. Since each concurrently running + * task maintains one map, the actual threshold for each map is this quantity divided by the + * number of running tasks. + * + * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of + * this threshold, in case map size estimation is not sufficiently accurate. + */ + +private[spark] class ExternalAppendOnlyMap[K, V, C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer = SparkEnv.get.serializerManager.default, + diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager) + extends Iterable[(K, C)] with Serializable with Logging { + + import ExternalAppendOnlyMap._ + + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + private val spilledMaps = new ArrayBuffer[DiskMapIterator] + private val sparkConf = SparkEnv.get.conf + + // Collective memory threshold shared across all running tasks + private val maxMemoryThreshold = { + val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3) + val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + // Number of pairs in the in-memory map + private var numPairsInMemory = 0 + + // Number of in-memory pairs inserted before tracking the map's shuffle memory usage + private val trackMemoryThreshold = 1000 + + // How many times we have spilled so far + private var spillCount = 0 + + private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false) + private val comparator = new KCComparator[K, C] + private val ser = serializer.newInstance() + + /** + * Insert the given key and value into the map. + * + * If the underlying map is about to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. 
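To make the two fractions above concrete, an illustrative calculation only (the heap size and task count are invented for the example):

    // Collective spill threshold for an 8 GB executor with the default fractions.
    val maxHeapBytes   = 8L * 1024 * 1024 * 1024
    val memoryFraction = 0.3   // spark.shuffle.memoryFraction
    val safetyFraction = 0.8   // spark.shuffle.safetyFraction
    val collectiveThreshold = (maxHeapBytes * memoryFraction * safetyFraction).toLong  // ~1.9 GB

    // With 8 tasks shuffling concurrently, each map's rough share of the pool:
    val perMapBudget = collectiveThreshold / 8

In practice the share is claimed dynamically from the shared pool rather than divided up front, as the insert path below shows.
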
+ */ + def insert(key: K, value: V) { + val update: (Boolean, C) => C = (hadVal, oldVal) => { + if (hadVal) mergeValue(oldVal, value) else createCombiner(value) + } + if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { + val mapSize = currentMap.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // this map to grow and, if possible, allocate the required amount + shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId + val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) + + // Assume map growth factor is 2x + shouldSpill = availableMemory < mapSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = mapSize * 2 + } + } + // Do not synchronize spills + if (shouldSpill) { + spill(mapSize) + } + } + currentMap.changeValue(key, update) + numPairsInMemory += 1 + } + + /** + * Sort the existing contents of the in-memory map and spill them to a temporary file on disk + */ + private def spill(mapSize: Long) { + spillCount += 1 + logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + val (blockId, file) = diskBlockManager.createTempBlock() + val writer = + new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites) + try { + val it = currentMap.destructiveSortedIterator(comparator) + while (it.hasNext) { + val kv = it.next() + writer.write(kv) + } + writer.commit() + } finally { + // Partial failures cannot be tolerated; do not revert partial writes + writer.close() + } + currentMap = new SizeTrackingAppendOnlyMap[K, C] + spilledMaps.append(new DiskMapIterator(file)) + + // Reset the amount of shuffle memory used by this map in the global pool + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap(Thread.currentThread().getId) = 0 + } + numPairsInMemory = 0 + } + + /** + * Return an iterator that merges the in-memory map with the spilled maps. + * If no spill has occurred, simply return the in-memory map's iterator. + */ + override def iterator: Iterator[(K, C)] = { + if (spilledMaps.isEmpty) { + currentMap.iterator + } else { + new ExternalIterator() + } + } + + /** + * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps + */ + private class ExternalIterator extends Iterator[(K, C)] { + + // A fixed-size queue that maintains a buffer for each stream we are currently merging + val mergeHeap = new mutable.PriorityQueue[StreamBuffer] + + // Input streams are derived both from the in-memory map and spilled maps on disk + // The in-memory map is sorted in place, while the spilled maps are already in sorted order + val sortedMap = currentMap.destructiveSortedIterator(comparator) + val inputStreams = Seq(sortedMap) ++ spilledMaps + + inputStreams.foreach { it => + val kcPairs = getMorePairs(it) + mergeHeap.enqueue(StreamBuffer(it, kcPairs)) + } + + /** + * Fetch from the given iterator until a key of different hash is retrieved. In the + * event of key hash collisions, this ensures no pairs are hidden from being merged. + * Assume the given iterator is in sorted order. 
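The check-and-claim step inside insert above can be summarized by the following simplified model; this is a sketch for illustration, not Spark API, and it assumes the same 2x growth estimate used in the real code.

    import scala.collection.mutable

    // threadId -> bytes currently claimed for that thread's in-memory map
    val shuffleMemoryMap = mutable.HashMap[Long, Long]()

    def shouldSpill(threadId: Long, mapSize: Long, maxMemory: Long): Boolean =
      shuffleMemoryMap.synchronized {
        val usedByOthers = shuffleMemoryMap.filterKeys(_ != threadId).values.sum
        val available = maxMemory - usedByOthers
        if (available >= mapSize * 2) {
          shuffleMemoryMap(threadId) = mapSize * 2  // claim room for the map to double
          false
        } else {
          true                                      // caller should spill to disk
        }
      }
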
+ */ + def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = { + val kcPairs = new ArrayBuffer[(K, C)] + if (it.hasNext) { + var kc = it.next() + kcPairs += kc + val minHash = kc._1.hashCode() + while (it.hasNext && kc._1.hashCode() == minHash) { + kc = it.next() + kcPairs += kc + } + } + kcPairs + } + + /** + * If the given buffer contains a value for the given key, merge that value into + * baseCombiner and remove the corresponding (K, C) pair from the buffer + */ + def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { + var i = 0 + while (i < buffer.pairs.size) { + val (k, c) = buffer.pairs(i) + if (k == key) { + buffer.pairs.remove(i) + return mergeCombiners(baseCombiner, c) + } + i += 1 + } + baseCombiner + } + + /** + * Return true if there exists an input stream that still has unvisited pairs + */ + override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty) + + /** + * Select a key with the minimum hash, then combine all values with the same key from all input streams. + */ + override def next(): (K, C) = { + // Select a key from the StreamBuffer that holds the lowest key hash + val minBuffer = mergeHeap.dequeue() + val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) + if (minPairs.length == 0) { + // Should only happen when no other stream buffers have any pairs left + throw new NoSuchElementException + } + var (minKey, minCombiner) = minPairs.remove(0) + assert(minKey.hashCode() == minHash) + + // For all other streams that may have this key (i.e. have the same minimum key hash), + // merge in the corresponding value (if any) from that stream + val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer) + while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) { + val newBuffer = mergeHeap.dequeue() + minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer) + mergedBuffers += newBuffer + } + + // Repopulate each visited stream buffer and add it back to the merge heap + mergedBuffers.foreach { buffer => + if (buffer.pairs.length == 0) { + buffer.pairs ++= getMorePairs(buffer.iterator) + } + mergeHeap.enqueue(buffer) + } + + (minKey, minCombiner) + } + + /** + * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. + * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to + * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * + * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. 
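The ordering just described relies on mutable.PriorityQueue, which dequeues its maximum element, so the comparison is negated to surface the buffer with the smallest key hash first (as the compareTo below does). A minimal, self-contained sketch with a hypothetical Buf class:

    import scala.collection.mutable

    case class Buf(minKeyHash: Int) extends Comparable[Buf] {
      // Negated so the queue's "max" is the buffer with the smallest key hash.
      override def compareTo(other: Buf): Int = -minKeyHash.compareTo(other.minKeyHash)
    }

    val heap = new mutable.PriorityQueue[Buf]()
    heap ++= Seq(Buf(42), Buf(7), Buf(19))
    assert(heap.dequeue() == Buf(7))  // lowest hash comes out first
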
+ */ + case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)]) + extends Comparable[StreamBuffer] { + + def minKeyHash: Int = { + if (pairs.length > 0){ + // pairs are already sorted by key hash + pairs(0)._1.hashCode() + } else { + Int.MaxValue + } + } + + override def compareTo(other: StreamBuffer): Int = { + // minus sign because mutable.PriorityQueue dequeues the max, not the min + -minKeyHash.compareTo(other.minKeyHash) + } + } + } + + /** + * An iterator that returns (K, C) pairs in sorted order from an on-disk map + */ + private class DiskMapIterator(file: File) extends Iterator[(K, C)] { + val fileStream = new FileInputStream(file) + val bufferedStream = new FastBufferedInputStream(fileStream) + val deserializeStream = ser.deserializeStream(bufferedStream) + var nextItem: (K, C) = null + var eof = false + + def readNextItem(): (K, C) = { + if (!eof) { + try { + return deserializeStream.readObject().asInstanceOf[(K, C)] + } catch { + case e: EOFException => + eof = true + cleanup() + } + } + null + } + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + } + nextItem != null + } + + override def next(): (K, C) = { + val item = if (nextItem == null) readNextItem() else nextItem + if (item == null) { + throw new NoSuchElementException + } + nextItem = null + item + } + + // TODO: Ensure this gets called even if the iterator isn't drained. + def cleanup() { + deserializeStream.close() + file.delete() + } + } +} + +private[spark] object ExternalAppendOnlyMap { + private class KCComparator[K, C] extends Comparator[(K, C)] { + def compare(kc1: (K, C), kc2: (K, C)): Int = { + kc1._1.hashCode().compareTo(kc2._1.hashCode()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..204330dad48b929670d83a13d4ddda5051c6f114 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.util.SizeEstimator +import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample + +/** + * Append-only map that keeps track of its estimated size in bytes. + * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, + * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds). 
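As a worked numeric example of the extrapolation described above (all numbers invented purely for illustration):

    case class Sample(size: Long, numUpdates: Long)

    // Last two samples taken by the tracker.
    val previous = Sample(size = 4000, numUpdates = 100)
    val latest   = Sample(size = 5200, numUpdates = 120)

    // Each update added roughly 60 bytes between the two samples ...
    val bytesPerUpdate = (latest.size - previous.size).toDouble /
      (latest.numUpdates - previous.numUpdates)

    // ... so at update 150 the map is estimated at 5200 + 60 * 30 = 7000 bytes.
    val numUpdates = 150
    val estimatedSize = (latest.size + bytesPerUpdate * (numUpdates - latest.numUpdates)).toLong
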
+ */ +private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] { + + /** + * Controls the base of the exponential which governs the rate of sampling. + * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements. + */ + private val SAMPLE_GROWTH_RATE = 1.1 + + /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */ + private val samples = new ArrayBuffer[Sample]() + + /** Total number of insertions and updates into the map since the last resetSamples(). */ + private var numUpdates: Long = _ + + /** The value of 'numUpdates' at which we will take our next sample. */ + private var nextSampleNum: Long = _ + + /** The average number of bytes per update between our last two samples. */ + private var bytesPerUpdate: Double = _ + + resetSamples() + + /** Called after the map grows in size, as this can be a dramatic change for small objects. */ + def resetSamples() { + numUpdates = 1 + nextSampleNum = 1 + samples.clear() + takeSample() + } + + override def update(key: K, value: V): Unit = { + super.update(key, value) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + } + + override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val newValue = super.changeValue(key, updateFunc) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + newValue + } + + /** Takes a new sample of the current map's size. */ + def takeSample() { + samples += Sample(SizeEstimator.estimate(this), numUpdates) + // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change. + bytesPerUpdate = math.max(0, samples.toSeq.reverse match { + case latest :: previous :: tail => + (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates) + case _ => + 0 + }) + nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong + } + + override protected def growTable() { + super.growTable() + resetSamples() + } + + /** Estimates the current size of the map in bytes. O(1) time. */ + def estimateSize(): Long = { + assert(samples.nonEmpty) + val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates) + (samples.last.size + extrapolatedDelta).toLong + } +} + +private object SizeTrackingAppendOnlyMap { + case class Sample(size: Long, numUpdates: Long) +} diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 8dd5786da6ff5c93f3bee3b3287baf89c7551220..3ac706110e287dfb6bd15978af9c53b6013b7ac2 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -53,7 +53,6 @@ object LocalSparkContext { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. 
*/ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index afc1beff989c4d47a5181e8b22b6fe4385f14d3d..930c2523caf8c0adc2ced32532b5e54bff090fbf 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -99,7 +99,6 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext - System.setProperty("spark.hostPort", hostname + ":" + boundPort) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 331fa3a642967518eb98c9d46db442415e0a9bf3..d05bbd6ff7e6f0a84a47a3e1c10f834ef5e8b658 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -25,8 +25,8 @@ import net.liftweb.json.JsonAST.JValue import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} +import org.apache.spark.deploy.worker.{ExecutorRunner, DriverRunner} class JsonProtocolSuite extends FunSuite { test("writeApplicationInfo") { @@ -50,11 +50,13 @@ class JsonProtocolSuite extends FunSuite { } test("writeMasterState") { - val workers = Array[WorkerInfo](createWorkerInfo(), createWorkerInfo()) - val activeApps = Array[ApplicationInfo](createAppInfo()) + val workers = Array(createWorkerInfo(), createWorkerInfo()) + val activeApps = Array(createAppInfo()) val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array(createDriverInfo()) + val completedDrivers = Array(createDriverInfo()) val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, - RecoveryState.ALIVE) + activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) } @@ -62,26 +64,44 @@ class JsonProtocolSuite extends FunSuite { test("writeWorkerState") { val executors = List[ExecutorRunner]() val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) + val drivers = List(createDriverRunner()) + val finishedDrivers = List(createDriverRunner(), createDriverRunner()) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, - finishedExecutors, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") + finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) assertValidJson(output) } - def createAppDesc() : ApplicationDescription = { + def createAppDesc(): ApplicationDescription = { val cmd = new Command("mainClass", List("arg1", "arg2"), Map()) new ApplicationDescription("name", Some(4), 1234, cmd, "sparkHome", "appUiUrl") } + def createAppInfo() : ApplicationInfo = { new ApplicationInfo( 3, "id", createAppDesc(), new 
Date(123456789), null, "appUriStr", Int.MaxValue) } - def createWorkerInfo() : WorkerInfo = { + + def createDriverCommand() = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")) + ) + + def createDriverDesc() = new DriverDescription("hdfs://some-dir/some.jar", 100, 3, + false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") } - def createExecutorRunner() : ExecutorRunner = { + def createExecutorRunner(): ExecutorRunner = { new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", - new File("sparkHome"), new File("workDir"), ExecutorState.RUNNING) + new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING) + } + def createDriverRunner(): DriverRunner = { + new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(), + null, "akka://worker") } def assertValidJson(json: JValue) { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..45dbcaffae94f335d73ffc5d756b9c92e9a384d7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -0,0 +1,131 @@ +package org.apache.spark.deploy.worker + +import java.io.File + +import scala.collection.JavaConversions._ + +import org.mockito.Mockito._ +import org.mockito.Matchers._ +import org.scalatest.FunSuite + +import org.apache.spark.deploy.{Command, DriverDescription} +import org.mockito.stubbing.Answer +import org.mockito.invocation.InvocationOnMock + +class DriverRunnerTest extends FunSuite { + private def createDriverRunner() = { + val command = new Command("mainClass", Seq(), Map()) + val driverDescription = new DriverDescription("jarUrl", 512, 1, true, command) + new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), driverDescription, + null, "akka://1.2.3.4/worker/") + } + + private def createProcessBuilderAndProcess(): (ProcessBuilderLike, Process) = { + val processBuilder = mock(classOf[ProcessBuilderLike]) + when(processBuilder.command).thenReturn(Seq("mocked", "command")) + val process = mock(classOf[Process]) + when(processBuilder.start()).thenReturn(process) + (processBuilder, process) + } + + test("Process succeeds instantly") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + // One failure then a successful run + when(process.waitFor()).thenReturn(0) + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Process failing several times and then succeeding") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + // fail, fail, fail, success + when(process.waitFor()).thenReturn(-1).thenReturn(-1).thenReturn(-1).thenReturn(0) + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(4)).waitFor() + verify(sleeper, times(3)).sleep(anyInt()) + verify(sleeper, 
times(1)).sleep(1) + verify(sleeper, times(1)).sleep(2) + verify(sleeper, times(1)).sleep(4) + } + + test("Process doesn't restart if not supervised") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + when(process.waitFor()).thenReturn(-1) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = false) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Process doesn't restart if killed") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val (processBuilder, process) = createProcessBuilderAndProcess() + when(process.waitFor()).thenAnswer(new Answer[Int] { + def answer(invocation: InvocationOnMock): Int = { + runner.kill() + -1 + } + }) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(process, times(1)).waitFor() + verify(sleeper, times(0)).sleep(anyInt()) + } + + test("Reset of backoff counter") { + val runner = createDriverRunner() + + val sleeper = mock(classOf[Sleeper]) + runner.setSleeper(sleeper) + + val clock = mock(classOf[Clock]) + runner.setClock(clock) + + val (processBuilder, process) = createProcessBuilderAndProcess() + + when(process.waitFor()) + .thenReturn(-1) // fail 1 + .thenReturn(-1) // fail 2 + .thenReturn(-1) // fail 3 + .thenReturn(-1) // fail 4 + .thenReturn(0) // success + when(clock.currentTimeMillis()) + .thenReturn(0).thenReturn(1000) // fail 1 (short) + .thenReturn(1000).thenReturn(2000) // fail 2 (short) + .thenReturn(2000).thenReturn(10000) // fail 3 (long) + .thenReturn(10000).thenReturn(11000) // fail 4 (short) + .thenReturn(11000).thenReturn(21000) // success (long) + + runner.runCommandWithRetry(processBuilder, p => (), supervise = true) + + verify(sleeper, times(4)).sleep(anyInt()) + // Expected sequence of sleeps is 1,2,1,2 + verify(sleeper, times(2)).sleep(1) + verify(sleeper, times(2)).sleep(2) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index be93074b7b3b01d8c335cf773f9987a595d657d7..a79ee690d39ff37eb6f77100471f1ee9d76040fb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -31,8 +31,8 @@ class ExecutorRunnerTest extends FunSuite { sparkHome, "appUiUrl") val appId = "12345-worker321-9876" val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome), - f("ooga"), ExecutorState.RUNNING) + f("ooga"), "blah", ExecutorState.RUNNING) - assert(er.buildCommandSeq().last === appId) + assert(er.getCommandSeq.last === appId) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..94d88d307a163ee08fd2167ad572deb027a6c28b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -0,0 +1,32 @@ +package org.apache.spark.deploy.worker + + +import akka.testkit.TestActorRef +import org.scalatest.FunSuite +import akka.remote.DisassociatedEvent +import akka.actor.{ActorSystem, AddressFromURIString, Props} + +class WorkerWatcherSuite extends FunSuite { + test("WorkerWatcher shuts down on 
valid disassociation") { + val actorSystem = ActorSystem("test") + val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) + val workerWatcher = actorRef.underlyingActor + workerWatcher.setTesting(testing = true) + actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) + assert(actorRef.underlyingActor.isShutDown) + } + + test("WorkerWatcher stays alive on invalid disassociation") { + val actorSystem = ActorSystem("test") + val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) + val workerWatcher = actorRef.underlyingActor + workerWatcher.setTesting(testing = true) + actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) + assert(!actorRef.underlyingActor.isShutDown) + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 7bf2020fe378eae92a89355d20005efc9c822ec9..235d31709af2b69d3965a8fa262cc1febee6e01e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -64,7 +64,7 @@ class FakeTaskSetManager( } override def getSchedulableByName(name: String): Schedulable = { - return null + null } override def executorLost(executorId: String, host: String): Unit = { @@ -79,13 +79,14 @@ class FakeTaskSetManager( { if (tasksSuccessful + runningTasks < numTasks) { increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + Some(new TaskDescription(0, execId, "task 0:0", 0, null)) + } else { + None } - return None } override def checkSpeculatableTasks(): Boolean = { - return true + true } def taskFinished() { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2aa259daf38b30b5e8670c70e5a4e95bc0cb6cf9..f0236ef1e975b4c0c28959d7c50a63a535509087 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -122,7 +122,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont locations: Seq[Seq[String]] = Nil ): MyRDD = { val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { + val newRDD = new MyRDD(sc, dependencies) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") override def getPartitions = (0 to maxPartition).map(i => new Partition { @@ -135,6 +135,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Nil override def toString: String = "DAGSchedulerSuiteRDD " + id } + newRDD } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 5cc48ee00a8990bbe5a97ba7812e2a6c0af9cdfc..29102913c719c12a9067782beb8a85a526ec91e2 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -42,12 +42,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage) } type MyRDD = RDD[(Int, Int)] - def makeRdd( - numPartitions: Int, - dependencies: List[Dependency[_]] - ): MyRDD = { + def makeRdd(numPartitions: Int, dependencies: List[Dependency[_]]): MyRDD = { val maxPartition = numPartitions - 1 - return new MyRDD(sc, dependencies) { + new MyRDD(sc, dependencies) { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") override def getPartitions = (0 to maxPartition).map(i => new Partition { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1eec6726f48bc0900c64c97b0bb9edaaa641705f..c9f6cc5d079b5909459ec4fb843c57c9d0354aef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -83,7 +83,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { private val conf = new SparkConf - val LOCALITY_WAIT = conf.get("spark.locality.wait", "3000").toLong + val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 test("TaskSet with no preferences") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f60ce270c7387f06b08afcd1e68168ca53037ce8..18aa587662d245b9444ddbb9d68c2c6bc1c1046f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -53,7 +53,6 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) - conf.set("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) @@ -65,13 +64,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT conf.set("spark.storage.disableBlockManagerHeartBeat", "true") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - // Set some value ... 
- conf.set("spark.hostPort", Utils.localHostName() + ":" + 1111) } after { System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") if (store != null) { store.stop() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 0ed366fb707b7d960e504a240eadc51cd654311a..de4871d0433ef35d5ba20fd982dc4c5ff7185874 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -61,8 +61,8 @@ class NonSerializable {} object TestObject { def run(): Int = { var nonSer = new NonSerializable - var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => + val x = 5 + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + x).reduce(_ + _) } @@ -76,7 +76,7 @@ class TestClass extends Serializable { def run(): Int = { var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + getX).reduce(_ + _) } @@ -88,7 +88,7 @@ class TestClassWithoutDefaultConstructor(x: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + getX).reduce(_ + _) } @@ -103,7 +103,7 @@ class TestClassWithoutFieldAccess { def run(): Int = { var nonSer2 = new NonSerializable var x = 5 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) nums.map(_ + x).reduce(_ + _) } @@ -115,7 +115,7 @@ object TestObjectWithNesting { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) var y = 1 for (i <- 1 to 4) { @@ -134,7 +134,7 @@ class TestClassWithNesting(val y: Int) extends Serializable { def run(): Int = { var nonSer = new NonSerializable var answer = 0 - return withSpark(new SparkContext("local", "test")) { sc => + withSpark(new SparkContext("local", "test")) { sc => val nums = sc.parallelize(Array(1, 2, 3, 4)) for (i <- 1 to 4) { var nonSer2 = new NonSerializable diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..93f0c6a8e64089138e2e1ef3433942a3e0eb0982 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass +import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap} + +class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll { + val NORMAL_ERROR = 0.20 + val HIGH_ERROR = 0.30 + + test("fixed size insertions") { + testWith[Int, Long](10000, i => (i, i.toLong)) + testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) + testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass())) + } + + test("variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[Int, String](10000, i => (i, randString(0, 10))) + testWith[Int, String](10000, i => (i, randString(0, 100))) + testWith[Int, String](10000, i => (i, randString(90, 100))) + } + + test("updates") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[String, Int](10000, i => (randString(0, 10000), i)) + } + + def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) { + val map = new SizeTrackingAppendOnlyMap[K, V]() + for (i <- 0 until numElements) { + val (k, v) = makeElement(i) + map(k) = v + expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { + val betterEstimatedSize = SizeEstimator.estimate(obj) + assert(betterEstimatedSize * (1 - error) < estimatedSize, + s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") + assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, + s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") + } +} + +object SizeTrackingAppendOnlyMapSuite { + // Speed test, for reproducibility of results. + // These could be highly non-deterministic in general, however. 
+ // Results: + // AppendOnlyMap: 31 ms + // SizeTracker: 54 ms + // SizeEstimator: 1500 ms + def main(args: Array[String]) { + val numElements = 100000 + + val baseTimes = for (i <- 0 until 10) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + } + } + + val sampledTimes = for (i <- 0 until 10) yield time { + val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + map.estimateSize() + } + } + + val unsampledTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + SizeEstimator.estimate(map) + } + } + + println("Base: " + baseTimes) + println("SizeTracker (sampled): " + sampledTimes) + println("SizeEstimator (unsampled): " + unsampledTimes) + } + + def time(f: => Unit): Long = { + val start = System.currentTimeMillis() + f + System.currentTimeMillis() - start + } + + private class LargeDummyClass { + val arr = new Array[Int](100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..7006571ef0ef6d36503ba5b176212f56455e344d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.FunSuite + +/** + * Tests org.apache.spark.util.Vector functionality + */ +class VectorSuite extends FunSuite { + + def verifyVector(vector: Vector, expectedLength: Int) = { + assert(vector.length == expectedLength) + assert(vector.elements.min > 0.0) + assert(vector.elements.max < 1.0) + } + + test("random with default random number generator") { + val vector100 = Vector.random(100) + verifyVector(vector100, 100) + } + + test("random with given random number generator") { + val vector100 = Vector.random(100, new Random(100)) + verifyVector(vector100, 100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala index b78367b6cac028b217abea0fcbb79c08ab270045..f1d7b61b31e635ba816fdabbcda780363a03ae49 100644 --- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -73,4 +73,4 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala similarity index 75% rename from core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala rename to core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index 7177919a58157c89400ae2fe88eaa256b08b945e..f44442f1a5328017b0205e09b19f1b92a1a5412e 100644 --- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -15,11 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.util +package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import java.util.Comparator class AppendOnlyMapSuite extends FunSuite { test("initialization") { @@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite { assert(map("" + i) === "" + i) } } + + test("destructive sort") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + map.update(null, "happy new year!") + + try { + map.apply("1") + map.update("1", "2013") + map.changeValue("1", (hadValue, oldValue) => "2014") + map.iterator + } catch { + case e: IllegalStateException => fail() + } + + val it = map.destructiveSortedIterator(new Comparator[(String, String)] { + def compare(kv1: (String, String), kv2: (String, String)): Int = { + val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue + val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue + x.compareTo(y) + } + }) + + // Should be sorted by key + assert(it.hasNext) + var previous = it.next() + assert(previous == (null, "happy new year!")) + previous = it.next() + assert(previous == ("1", "2014")) + while (it.hasNext) { + val kv = it.next() + assert(kv._1.toInt > previous._1.toInt) + previous = kv + } + + // All subsequent calls to apply, update, changeValue and iterator should throw exception + intercept[AssertionError] { map.apply("1") } + intercept[AssertionError] { map.update("1", "2013") } + intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") } + intercept[AssertionError] { map.iterator } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ef957bb0e5d17d5911976645a9c93e2b49db08f5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -0,0 +1,230 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + override def beforeEach() { + val conf = new SparkConf(false) + conf.set("spark.shuffle.externalSorting", "true") + sc = new SparkContext("local", "test", conf) + } + + val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i) + val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => { + buffer += i + } + val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] = + (buf1, buf2) => { + buf1 ++= buf2 + } + + test("simple insert") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + + // Single insert + map.insert(1, 10) + var it = map.iterator + assert(it.hasNext) + val kv = it.next() + assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10)) + assert(!it.hasNext) + + // Multiple insert + map.insert(2, 20) + map.insert(3, 30) + it = map.iterator + assert(it.hasNext) + assert(it.toSet == Set[(Int, ArrayBuffer[Int])]( + (1, ArrayBuffer[Int](10)), + (2, ArrayBuffer[Int](20)), + (3, ArrayBuffer[Int](30)))) + } + + test("insert with collision") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) 
+ + map.insert(1, 10) + map.insert(2, 20) + map.insert(3, 30) + map.insert(1, 100) + map.insert(2, 200) + map.insert(1, 1000) + val it = map.iterator + assert(it.hasNext) + val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](10, 100, 1000)), + (2, Set[Int](20, 200)), + (3, Set[Int](30)))) + } + + test("ordering") { + val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map1.insert(1, 10) + map1.insert(2, 20) + map1.insert(3, 30) + + val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map2.insert(2, 20) + map2.insert(3, 30) + map2.insert(1, 10) + + val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map3.insert(3, 30) + map3.insert(1, 10) + map3.insert(2, 20) + + val it1 = map1.iterator + val it2 = map2.iterator + val it3 = map3.iterator + + var kv1 = it1.next() + var kv2 = it2.next() + var kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + } + + test("null keys and values") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map.insert(1, 5) + map.insert(2, 6) + map.insert(3, 7) + assert(map.size === 3) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)) + )) + + // Null keys + val nullInt = null.asInstanceOf[Int] + map.insert(nullInt, 8) + assert(map.size === 4) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)), + (nullInt, Seq[Int](8)) + )) + + // Null values + map.insert(4, nullInt) + map.insert(nullInt, nullInt) + assert(map.size === 5) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](5)), + (2, Set[Int](6)), + (3, Set[Int](7)), + (4, Set[Int](nullInt)), + (nullInt, Set[Int](nullInt, 8)) + )) + } + + test("simple aggregator") { + // reduceByKey + val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) + val result1 = rdd.reduceByKey(_+_).collect() + assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5))) + + // groupByKey + val result2 = rdd.groupByKey().collect() + assert(result2.toSet == Set[(Int, Seq[Int])] + ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1)))) + } + + test("simple cogroup") { + val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) + val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) + val result = rdd1.cogroup(rdd2).collect() + + result.foreach { case (i, (seq1, seq2)) => + i match { + case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4)) + case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3)) + case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]()) + case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]()) + case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]()) + } + } + } + + test("spilling") { + // TODO: Figure out correct memory parameters to 
actually induce spilling + // System.setProperty("spark.shuffle.buffer.mb", "1") + // System.setProperty("spark.shuffle.buffer.fraction", "0.05") + + // reduceByKey - should spill exactly 6 times + val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max(_, _)).collect() + assert(resultA.length == 5000) + resultA.foreach { case(k, v) => + k match { + case 0 => assert(v == 1) + case 2500 => assert(v == 5001) + case 4999 => assert(v == 9999) + case _ => + } + } + + // groupByKey - should spill exactly 11 times + val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 2500) + resultB.foreach { case(i, seq) => + i match { + case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) + case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003)) + case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999)) + case _ => + } + } + + // cogroup - should spill exactly 7 times + val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 1000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900)) + case 500 => + assert(seq1.toSet == Set[Int](500)) + assert(seq2.toSet == Set[Int]()) + case 999 => + assert(seq1.toSet == Set[Int](999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + } + + // TODO: Test memory allocation for multiple concurrently running tasks +} diff --git a/docs/configuration.md b/docs/configuration.md index 1d6c3d16333c5ebc86a348113ca98a265bbc94c4..40a57c4bc6a20cf6d5d104405a11b9ea8299d952 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -104,13 +104,24 @@ Apart from these, the following properties are also available, and may be useful </tr> <tr> <td>spark.storage.memoryFraction</td> - <td>0.66</td> + <td>0.6</td> <td> Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old" - generation of objects in the JVM, which by default is given 2/3 of the heap, but you can increase + generation of objects in the JVM, which by default is given 0.6 of the heap, but you can increase it if you configure your own old generation size. </td> </tr> +<tr> + <td>spark.shuffle.memoryFraction</td> + <td>0.3</td> + <td> + Fraction of Java heap to use for aggregation and cogroups during shuffles, if + <code>spark.shuffle.externalSorting</code> is enabled. At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills are often, consider increasing this value at the expense of + <code>spark.storage.memoryFraction</code>. + </td> +</tr> <tr> <td>spark.mesos.coarse</td> <td>false</td> @@ -130,7 +141,7 @@ Apart from these, the following properties are also available, and may be useful </td> </tr> <tr> - <td>spark.ui.retained_stages</td> + <td>spark.ui.retainedStages</td> <td>1000</td> <td> How many stages the Spark UI remembers before garbage collecting. @@ -376,6 +387,14 @@ Apart from these, the following properties are also available, and may be useful If set to "true", consolidates intermediate files created during a shuffle. Creating fewer files can improve filesystem performance for shuffles with large numbers of reduce tasks. 
It is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option might degrade performance on machines with many (>8) cores due to filesystem limitations. </td> </tr> +<tr> + <td>spark.shuffle.externalSorting</td> + <td>true</td> + <td> + If set to "true", limits the amount of memory used during reduces by spilling data out to disk. This spilling + threshold is specified by <code>spark.shuffle.memoryFraction</code>. + </td> +</tr> <tr> <td>spark.speculation</td> <td>false</td> diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index dc187b3efec9b7b7b8f6075275a4faf9da40b6fe..c4236f83124b213a9293e2f1dece69d5edc06dd3 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -99,8 +99,9 @@ $ MASTER=local[4] ./bin/pyspark ## IPython -It is also possible to launch PySpark in [IPython](http://ipython.org), the enhanced Python interpreter. -To do this, set the `IPYTHON` variable to `1` when running `bin/pyspark`: +It is also possible to launch PySpark in [IPython](http://ipython.org), the +enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To +use IPython, set the `IPYTHON` variable to `1` when running `bin/pyspark`: {% highlight bash %} $ IPYTHON=1 ./bin/pyspark diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index b20627010798a9fe8461a661c281f12847593eab..3bd62646bab060eba316406fa9a1c7a4983a31b1 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -101,7 +101,19 @@ With this mode, your application is actually run on the remote machine where the With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR -In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh +Configuration in yarn-client mode: + +In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options. + +* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required) +* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2) +* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1). +* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) +* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +* `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark) +* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default') +* `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job. +* `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job. For example: @@ -114,7 +126,6 @@ For example: SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ MASTER=yarn-client ./bin/spark-shell -You can also send extra files to yarn cluster for worker to use by exporting SPARK_YARN_DIST_FILES=file1,file2... etc. 
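The spark.shuffle.externalSorting and spark.shuffle.memoryFraction settings documented in the configuration table above can also be set programmatically. A brief sketch with example values only (a real application would tune these against spark.storage.memoryFraction):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")                          // placeholder master
      .setAppName("shuffle-tuning-example")
      .set("spark.shuffle.externalSorting", "true")   // allow spilling during reduces
      .set("spark.shuffle.memoryFraction", "0.4")     // example: raise the spill threshold
      .set("spark.storage.memoryFraction", "0.5")     // ...at the expense of the cache
    val sc = new SparkContext(conf)
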
# Building Spark for Hadoop/YARN 2.2.x diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f47d41f966e3bc8b39a3b1dc93b1bba9d464758a..2a186261b754ab8a29841d16913d7d69e7d1dda8 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -10,11 +10,7 @@ In addition to running on the Mesos or YARN cluster managers, Spark also provide # Installing Spark Standalone to a Cluster -The easiest way to deploy Spark is by running the `./make-distribution.sh` script to create a binary distribution. -This distribution can be deployed to any machine with the Java runtime installed; there is no need to install Scala. - -The recommended procedure is to deploy and start the master on one node first, get the master spark URL, -then modify `conf/spark-env.sh` in the `dist/` directory before deploying to all the other nodes. +To install Spark Standalone mode, you simply place a compiled version of Spark on each node of the cluster. You can obtain pre-built versions of Spark with each release or [build it yourself](index.html#building). # Starting a Cluster Manually @@ -150,6 +146,38 @@ automatically set MASTER from the `SPARK_MASTER_IP` and `SPARK_MASTER_PORT` vari You can also pass an option `-c <numCores>` to control the number of cores that spark-shell uses on the cluster. +# Launching Applications Inside the Cluster + +You may also run your application entirely inside of the cluster by submitting your application driver using the submission client. The syntax for submitting applications is as follows: + + + ./spark-class org.apache.spark.deploy.Client launch + [client-options] \ + <cluster-url> <application-jar-url> <main-class> \ + [application-options] + + cluster-url: The URL of the master node. + application-jar-url: Path to a bundled jar including your application and all dependencies. Currently, the URL must be globally visible inside of your cluster, for instance, an `hdfs://` path or a `file://` path that is present on all nodes. + main-class: The entry point for your application. + + Client Options: + --memory <count> (amount of memory, in MB, allocated for your driver program) + --cores <count> (number of cores allocated for your driver program) + --supervise (whether to automatically restart your driver on application or node failure) + --verbose (prints increased logging output) + +Keep in mind that your driver program will be executed on a remote worker machine. You can control the execution environment in the following ways: + + * _Environment variables_: These will be captured from the environment in which you launch the client and applied when launching the driver program. + * _Java options_: You can add Java options by setting `SPARK_JAVA_OPTS` in the environment in which you launch the submission client. + * _Dependencies_: You'll still need to call `sc.addJar` inside of your program to make your bundled application jar visible on all worker nodes. + +Once you submit a driver program, it will appear in the cluster management UI at port 8080 and +be assigned an identifier. If you'd like to terminate the program prematurely, you can do so using +the same client: + + ./spark-class org.apache.spark.deploy.Client kill <driverId> + # Resource Scheduling The standalone cluster mode currently only supports a simple FIFO scheduler across applications.
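To make the "Dependencies" note in the spark-standalone.md section above concrete, here is a minimal, hypothetical sketch (not part of the patch) of a driver program that could be bundled into the jar passed as <application-jar-url> and launched with the submission client shown above. The object name, the argument handling, and the toy computation are illustrative assumptions; the only point it demonstrates is calling `sc.addJar` on the bundled jar so the code is visible on all worker nodes.

```scala
import org.apache.spark.SparkContext

// Hypothetical driver suitable for in-cluster submission via
// `./spark-class org.apache.spark.deploy.Client launch ...` as documented above.
object InClusterDriverExample {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: InClusterDriverExample <master-url> <application-jar-path>")
      System.exit(1)
    }
    val masterUrl = args(0)
    val appJar = args(1)

    val sc = new SparkContext(masterUrl, "InClusterDriverExample")
    // Per the "Dependencies" bullet above: make the bundled application jar
    // available to the worker nodes that will run the tasks.
    sc.addJar(appJar)

    val total = sc.parallelize(1 to 1000).map(_ * 2).reduce(_ + _)
    println("Computed total: " + total)
    sc.stop()
  }
}
```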
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 1c9ece62707818cfc403103269a81c2448dc3552..4e8a680a75d07cc8ff54dfe9cc0bb2f9feb361af 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -167,7 +167,7 @@ Spark Streaming features windowed computations, which allow you to apply transfo </tr> </table> -A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.PairDStreamFunctions). +A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). ## Output Operations When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: @@ -175,7 +175,7 @@ When an output operator is called, it triggers the computation of a stream. Curr <table class="table"> <tr><th style="width:30%">Operator</th><th>Meaning</th></tr> <tr> - <td> <b>foreach</b>(<i>func</i>) </td> + <td> <b>foreachRDD</b>(<i>func</i>) </td> <td> The fundamental output operator. Applies a function, <i>func</i>, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system. </td> </tr> @@ -375,7 +375,7 @@ There are two failure behaviors based on which input sources are used. 1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure. 1. _Using any input source that receives data through a network_ - For network-based data sources like Kafka and Flume, the received input data is replicated in memory between nodes of the cluster (default replication factor is 2). So if a worker node fails, then the system can recompute the lost from the the left over copy of the input data. However, if the worker node where a network receiver was running fails, then a tiny bit of data may be lost, that is, the data received by the system but not yet replicated to other node(s). The receiver will be started on a different node and it will continue to receive data. -Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreach`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations. +Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. 
That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreachRDD`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations. ## Failure of the Driver Node A system that is required to operate 24/7 needs to be able tolerate the failure of the driver node as well. Spark Streaming does this by saving the state of the DStream computation periodically to a HDFS file, that can be used to restart the streaming computation in the event of a failure of the driver node. This checkpointing is enabled by setting a HDFS directory for checkpointing using `ssc.checkpoint(<checkpoint directory>)` as described [earlier](#rdd-checkpointing-within-dstreams). To elaborate, the following state is periodically saved to a file. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index d82a1e1490cc0e7125030a7b93a10bfd1ddca208..e7cb5ab3ff9b02be14f76d94aec63aca91eb9663 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -185,7 +185,11 @@ def get_spark_ami(opts): "hi1.4xlarge": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", - "cr1.8xlarge": "hvm" + "cr1.8xlarge": "hvm", + "i2.xlarge": "hvm", + "i2.2xlarge": "hvm", + "i2.4xlarge": "hvm", + "i2.8xlarge": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -478,7 +482,11 @@ def get_num_disks(instance_type): "cr1.8xlarge": 2, "hi1.4xlarge": 2, "m3.xlarge": 0, - "m3.2xlarge": 0 + "m3.2xlarge": 0, + "i2.xlarge": 1, + "i2.2xlarge": 2, + "i2.4xlarge": 4, + "i2.8xlarge": 8 } if instance_type in disks_by_instance: return disks_by_instance[instance_type] diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java index b11cfa667eb9238671b1126357e5d4e5642b8b51..7b5a243e26414ef3d77212f360707ba2c01f460a 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java @@ -47,6 +47,8 @@ public final class JavaFlumeEventCount { System.exit(1); } + StreamingExamples.setStreamingLogLevels(); + String master = args[0]; String host = args[1]; int port = Integer.parseInt(args[2]); diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java index 16b8a948e6154ad7527041efdfb546a2f859d11f..04f62ee2041451db8ad249cd655762d0d5fde503 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java @@ -59,6 +59,8 @@ public final class JavaKafkaWordCount { System.exit(1); } + StreamingExamples.setStreamingLogLevels(); + // Create the context with a 1 second batch size JavaStreamingContext jssc = new JavaStreamingContext(args[0], "KafkaWordCount", new Duration(2000), System.getenv("SPARK_HOME"), diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java 
b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java index 1e2efd359cff297129b6aa35ac42cb19b40072f7..349d826ab5df76d02de32d70202ac93b6502bee3 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -38,7 +38,7 @@ import java.util.regex.Pattern; * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` * and then run the example - * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` + * `$ ./run org.apache.spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999` */ public final class JavaNetworkWordCount { private static final Pattern SPACE = Pattern.compile(" "); @@ -48,18 +48,20 @@ public final class JavaNetworkWordCount { public static void main(String[] args) { if (args.length < 3) { - System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" + + System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" + "In local mode, <master> should be 'local[n]' with n > 1"); System.exit(1); } + StreamingExamples.setStreamingLogLevels(); + // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount", new Duration(1000), System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class)); // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. generated by 'nc') JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() { @Override @@ -82,6 +84,5 @@ public final class JavaNetworkWordCount { wordCounts.print(); ssc.start(); - } } diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java index e05551ab833010df77cb761a17709583967fa84a..7ef9c6c8f4aaf0cad16304289bda46a4a114fc03 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java @@ -41,6 +41,8 @@ public final class JavaQueueStream { System.exit(1); } + StreamingExamples.setStreamingLogLevels(); + // Create the context JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000), System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaQueueStream.class)); diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..65251e93190f01dccf25d79373cf9d989078d732 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import scala.collection.JavaConversions._ + +/** Prints out environmental information, sleeps, and then exits. Made to + * test driver submission in the standalone scheduler. */ +object DriverSubmissionTest { + def main(args: Array[String]) { + if (args.size < 1) { + println("Usage: DriverSubmissionTest <seconds-to-sleep>") + System.exit(0) + } + val numSecondsToSleep = args(0).toInt + + val env = System.getenv() + val properties = System.getProperties() + + println("Environment variables containing SPARK_TEST:") + env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + + println("System properties containing spark.test:") + properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + + for (i <- 1 until numSecondsToSleep) { + println(s"Alive for $i out of $numSecondsToSleep seconds") + Thread.sleep(1000) + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 83db8b9e26411cb23f08892be5f59a5664a23813..c8ecbb8e41a8689a2d10421cae74f2ceb45f7f70 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -43,7 +43,7 @@ object LocalALS { def generateR(): DoubleMatrix2D = { val mh = factory2D.random(M, F) val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) + algebra.mult(mh, algebra.transpose(uh)) } def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], @@ -56,7 +56,7 @@ object LocalALS { //println("R: " + r) blas.daxpy(-1, targetR, r) val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) + sqrt(sumSqs / (M * U)) } def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], @@ -80,7 +80,7 @@ object LocalALS { val ch = new CholeskyDecomposition(XtX) val Xty2D = factory2D.make(Xty.toArray, F) val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) + solved2D.viewColumn(0) } def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], @@ -104,7 +104,7 @@ object LocalALS { val ch = new CholeskyDecomposition(XtX) val Xty2D = factory2D.make(Xty.toArray, F) val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) + solved2D.viewColumn(0) } def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index fb130ea1988f76ff9456e5f2d2db58b1e53377cc..9ab5f5a48620be26b439c5bf8dc4f1fc918b9cce 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -28,7 +28,7 @@ object LocalFileLR { def parsePoint(line: String): DataPoint = { val nums = line.split(' ').map(_.toDouble) - return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) + 
DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) } def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index f90ea35cd447c035848fbf124d726d94645ade74..a730464ea158ef5219c25f3e9feda6f2d1d591c0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -55,7 +55,7 @@ object LocalKMeans { } } - return bestIndex + bestIndex } def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 30c86d83e688c5755f5c00c34ce4cd701ea7d160..17bafc2218a31e51551fc29a2bf2fcc41a35863c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -44,7 +44,7 @@ object SparkALS { def generateR(): DoubleMatrix2D = { val mh = factory2D.random(M, F) val uh = factory2D.random(U, F) - return algebra.mult(mh, algebra.transpose(uh)) + algebra.mult(mh, algebra.transpose(uh)) } def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], @@ -57,7 +57,7 @@ object SparkALS { //println("R: " + r) blas.daxpy(-1, targetR, r) val sumSqs = r.aggregate(Functions.plus, Functions.square) - return sqrt(sumSqs / (M * U)) + sqrt(sumSqs / (M * U)) } def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], @@ -83,7 +83,7 @@ object SparkALS { val ch = new CholeskyDecomposition(XtX) val Xty2D = factory2D.make(Xty.toArray, F) val solved2D = ch.solve(Xty2D) - return solved2D.viewColumn(0) + solved2D.viewColumn(0) } def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index ff72532db1f17adedd5e4a7c7e76aa95223c18ae..39819064edbaa814159c6b3abdcfeececa3b7c33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -43,7 +43,7 @@ object SparkHdfsLR { while (i < D) { x(i) = tok.nextToken.toDouble; i += 1 } - return DataPoint(new Vector(x), y) + DataPoint(new Vector(x), y) } def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 8c99025eaa6da472335f74ca8371d3d60d8a73c5..9fe24652358f3a11994a40135f128a4aebc48ad0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -30,7 +30,7 @@ object SparkKMeans { val rand = new Random(42) def parseVector(line: String): Vector = { - return new Vector(line.split(' ').map(_.toDouble)) + new Vector(line.split(' ').map(_.toDouble)) } def closestPoint(p: Vector, centers: Array[Vector]): Int = { @@ -46,7 +46,7 @@ object SparkKMeans { } } - return bestIndex + bestIndex } def main(args: Array[String]) { @@ -61,15 +61,15 @@ object SparkKMeans { val K = args(2).toInt val convergeDist = args(3).toDouble - var kPoints = data.takeSample(false, K, 42).toArray + val kPoints = data.takeSample(withReplacement = false, K, 42).toArray var tempDist = 1.0 while(tempDist > convergeDist) { - var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) + val closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - var 
pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} + val pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} - var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() + val newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() tempDist = 0.0 for (i <- 0 until K) { diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index 4e0058cd707777c32ce3baa4d68e7d6e0d9d1e09..57e1b1f806e82ecfb88b3fa88d0360e77ea28bd4 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -18,17 +18,13 @@ package org.apache.spark.streaming.examples import scala.collection.mutable.LinkedList -import scala.util.Random import scala.reflect.ClassTag +import scala.util.Random -import akka.actor.Actor -import akka.actor.ActorRef -import akka.actor.Props -import akka.actor.actorRef2Scala +import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} import org.apache.spark.SparkConf -import org.apache.spark.streaming.Seconds -import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions import org.apache.spark.streaming.receivers.Receiver import org.apache.spark.util.AkkaUtils @@ -147,6 +143,8 @@ object ActorWordCount { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val Seq(master, host, port) = args.toSeq // Create the context and set the batch size diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala index ae3709b3d97f5561f857b7d6fab8ef0436888cbc..a59be7899dd37de2e50cfcd49e625992af589f2c 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala @@ -17,10 +17,10 @@ package org.apache.spark.streaming.examples -import org.apache.spark.util.IntParam import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ +import org.apache.spark.util.IntParam /** * Produces a count of events received from Flume. 
@@ -44,6 +44,8 @@ object FlumeEventCount { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val Array(master, host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala index ea6ea674196a1effb94ab7138998d4076cb41799..704b315ef8b2214ece9bb44ab1518d1838c9dfd5 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming.examples import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ - /** * Counts words in new text files created in the given directory * Usage: HdfsWordCount <master> <directory> @@ -38,6 +37,8 @@ object HdfsWordCount { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + // Create the context val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala index 31a94bd224a45f8f68f177777c7d37b9ff423de7..4a3d81c09a122aadbc5bc9672694a8be78261ce4 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala @@ -23,8 +23,8 @@ import kafka.producer._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.streaming.util.RawTextHelper._ import org.apache.spark.streaming.kafka._ +import org.apache.spark.streaming.util.RawTextHelper._ /** * Consumes messages from one or more topics in Kafka and does wordcount. 
@@ -40,12 +40,13 @@ import org.apache.spark.streaming.kafka._ */ object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 5) { System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>") System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val Array(master, zkQuorum, group, topics, numThreads) = args val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2), diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala index 325290b66f4decbe55ed025a865816d09ecbcc5c..78b49fdcf1eb3a84bd8d4b6814b978c039e9f689 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala @@ -17,12 +17,8 @@ package org.apache.spark.streaming.examples -import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence +import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic} import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.eclipse.paho.client.mqttv3.MqttException -import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -43,6 +39,8 @@ object MQTTPublisher { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val Seq(brokerUrl, topic) = args.toSeq try { diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala index 6a32c75373a7efcc3b3787c8cc3179dacaa169c2..25f7013307fef4d59a93baecca19a97fb197f0e7 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala @@ -21,7 +21,8 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ /** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Counts words in text encoded with UTF8 received from the network every second. + * * Usage: NetworkWordCount <master> <hostname> <port> * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. @@ -39,12 +40,14 @@ object NetworkWordCount { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + // Create the context with a 1 second batch size val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. 
generated by 'nc') val lines = ssc.socketTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala index 9d640e716bca978b74c76bb5cab3ffef6ad37ebf..4d4968ba6ae3e0d2d90685bc2841d6895c139501 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.examples +import scala.collection.mutable.SynchronizedQueue + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ -import scala.collection.mutable.SynchronizedQueue - object QueueStream { def main(args: Array[String]) { @@ -30,7 +30,9 @@ object QueueStream { System.err.println("Usage: QueueStream <master>") System.exit(1) } - + + StreamingExamples.setStreamingLogLevels() + // Create the context val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala index c0706d07249824cc740968d141e2d874c9c185c6..99b79c3949a4ed4c6fd166501070d3907e2aff6f 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala @@ -17,11 +17,10 @@ package org.apache.spark.streaming.examples -import org.apache.spark.util.IntParam import org.apache.spark.storage.StorageLevel - import org.apache.spark.streaming._ import org.apache.spark.streaming.util.RawTextHelper +import org.apache.spark.util.IntParam /** * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited @@ -45,6 +44,8 @@ object RawNetworkGrep { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args // Create the context @@ -57,7 +58,7 @@ object RawNetworkGrep { val rawStreams = (1 to numStreams).map(_ => ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) - union.filter(_.contains("the")).count().foreach(r => + union.filter(_.contains("the")).count().foreachRDD(r => println("Grep count: " + r.collect().mkString)) ssc.start() } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..8c5d0bd56845bc8d2262af4f339c0e85a0660b6e --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.util.IntParam +import java.io.File +import org.apache.spark.rdd.RDD +import com.google.common.io.Files +import java.nio.charset.Charset + +/** + * Counts words in text encoded with UTF8 received from the network every second. + * + * Usage: NetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. + * <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data + * <output-file> file to which the word counts will be appended + * + * In local mode, <master> should be 'local[n]' with n > 1 + * <checkpoint-directory> and <output-file> must be absolute paths + * + * + * To run this on your local machine, you need to first run a Netcat server + * + * `$ nc -lk 9999` + * + * and run the example as + * + * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount \ + * local[2] localhost 9999 ~/checkpoint/ ~/out` + * + * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + * the checkpoint data. + * + * To run this example in a local standalone cluster with automatic driver recovery, + * + * `$ ./spark-class org.apache.spark.deploy.Client -s launch <cluster-url> <path-to-examples-jar> \ + * org.apache.spark.streaming.examples.RecoverableNetworkWordCount <cluster-url> \ + * localhost 9999 ~/checkpoint ~/out` + * + * <path-to-examples-jar> would typically be <spark-dir>/examples/target/scala-XX/spark-examples....jar + * + * Refer to the online documentation for more details. + */ + +object RecoverableNetworkWordCount { + + def createContext(master: String, ip: String, port: Int, outputPath: String) = { + + // If you do not see this printed, that means the StreamingContext has been loaded + // from the new checkpoint + println("Creating new context") + val outputFile = new File(outputPath) + if (outputFile.exists()) outputFile.delete() + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1), + System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. 
generated by 'nc') + val lines = ssc.socketTextStream(ip, port) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") + println(counts) + println("Appending to " + outputFile.getAbsolutePath) + Files.append(counts + "\n", outputFile, Charset.defaultCharset()) + }) + ssc + } + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("You arguments were " + args.mkString("[", ", ", "]")) + System.err.println( + """ + |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + | <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + | <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. + | <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data + | <output-file> file to which the word counts will be appended + | + |In local mode, <master> should be 'local[n]' with n > 1 + |Both <checkpoint-directory> and <output-file> must be absolute paths + """.stripMargin + ) + System.exit(1) + } + val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args + val ssc = StreamingContext.getOrCreate(checkpointDirectory, + () => { + createContext(master, ip, port, outputPath) + }) + ssc.start() + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala index 002db57d592b2a88f491ca1813d95cd9e1a14ca6..1183eba84686bd5712c7668e58af05add060691d 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -39,6 +39,8 @@ object StatefulNetworkWordCount { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val updateFunc = (values: Seq[Int], state: Option[Int]) => { val currentCount = values.foldLeft(0)(_ + _) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala new file mode 100644 index 0000000000000000000000000000000000000000..d41d84a980dc73b8f4bcacf56d2a4de2af399c05 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala @@ -0,0 +1,21 @@ +package org.apache.spark.streaming.examples + +import org.apache.spark.Logging + +import org.apache.log4j.{Level, Logger} + +/** Utility functions for Spark Streaming examples. */ +object StreamingExamples extends Logging { + + /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + def setStreamingLogLevels() { + val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + // We first log something to initialize Spark's default logging, then we override the + // logging level. + logInfo("Setting log level to [WARN] for streaming example." 
+ + " To override add a custom log4j.properties to the classpath.") + Logger.getRootLogger.setLevel(Level.WARN) + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala index 3ccdc908e23c43d0c3424a5b6719950e7c658702..483c4d311810fb81cb789415d67f11d2840be6f4 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.examples -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.storage.StorageLevel import com.twitter.algebird._ -import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.SparkContext._ +import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.twitter._ /** @@ -51,6 +51,8 @@ object TwitterAlgebirdCMS { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + // CMS parameters val DELTA = 1E-3 val EPS = 0.01 @@ -79,7 +81,7 @@ object TwitterAlgebirdCMS { val exactTopUsers = users.map(id => (id, 1)) .reduceByKey((a, b) => a + b) - approxTopUsers.foreach(rdd => { + approxTopUsers.foreachRDD(rdd => { if (rdd.count() != 0) { val partial = rdd.first() val partialTopK = partial.heavyHitters.map(id => @@ -94,7 +96,7 @@ object TwitterAlgebirdCMS { } }) - exactTopUsers.foreach(rdd => { + exactTopUsers.foreachRDD(rdd => { if (rdd.count() != 0) { val partialMap = rdd.collect().toMap val partialTopK = rdd.map( diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala index c7e83e76b00570e721ee1fe965d119178b14528d..94c2bf29ac4333c188ac4e2749fab7d4a3eac348 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.examples -import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.storage.StorageLevel -import com.twitter.algebird.HyperLogLog._ import com.twitter.algebird.HyperLogLogMonoid +import com.twitter.algebird.HyperLogLog._ + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.twitter._ /** @@ -44,6 +45,8 @@ object TwitterAlgebirdHLL { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ val BIT_SIZE = 12 val (master, filters) = (args.head, args.tail) @@ -64,7 +67,7 @@ object TwitterAlgebirdHLL { val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) - approxUsers.foreach(rdd => { + approxUsers.foreachRDD(rdd => { if (rdd.count() != 0) { val partial = rdd.first() globalHll += partial @@ -73,7 +76,7 @@ object TwitterAlgebirdHLL { } }) - exactUsers.foreach(rdd => { + exactUsers.foreachRDD(rdd => { if (rdd.count() != 0) { val partial = rdd.first() userSet ++= partial diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala 
b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala index e2b0418d55d2b14b08d50aa1500a6806306a7f5c..8a70d4a978cd43c5a8507373704e9a99d09e8e7f 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala @@ -36,6 +36,8 @@ object TwitterPopularTags { System.exit(1) } + StreamingExamples.setStreamingLogLevels() + val (master, filters) = (args.head, args.tail) val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2), @@ -54,13 +56,13 @@ object TwitterPopularTags { // Print popular hashtags - topCounts60.foreach(rdd => { + topCounts60.foreachRDD(rdd => { val topList = rdd.take(5) println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} }) - topCounts10.foreach(rdd => { + topCounts10.foreachRDD(rdd => { val topList = rdd.take(5) println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala index 03902ec353babfbc6d2991d6af01d8e6ed19d09e..12d2a1084f9002bef4957faf1ea126229219db7e 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -76,6 +76,7 @@ object ZeroMQWordCount { "In local mode, <master> should be 'local[n]' with n > 1") System.exit(1) } + StreamingExamples.setStreamingLogLevels() val Seq(master, url, topic) = args.toSeq // Create the context and set the batch size diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala index 4fe57de4a4058fe3ebf13332c5fe10be491bf8d9..a2600989ca1a62bda36526215829e21cbab7460d 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -65,7 +65,7 @@ object PageViewGenerator { return item } } - return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 + inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0 } def getNextClickEvent() : String = { diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala index 807af199f4fd0813b2961784f1dca9ed88d9a69a..bb44bc3d06ef3918ff8b80352ae6e14b62c7c646 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala @@ -17,9 +17,10 @@ package org.apache.spark.streaming.examples.clickstream +import org.apache.spark.SparkContext._ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.SparkContext._ +import org.apache.spark.streaming.examples.StreamingExamples /** Analyses a streaming dataset of web page views. 
This class demonstrates several types of * operators available in Spark streaming. @@ -36,6 +37,7 @@ object PageViewStream { " errorRatePerZipCode, activeUserCount, popularUsersSeen") System.exit(1) } + StreamingExamples.setStreamingLogLevels() val metric = args(0) val host = args(1) val port = args(2).toInt @@ -89,7 +91,7 @@ object PageViewStream { case "popularUsersSeen" => // Look for users in our existing dataset and print it out if we have a match pageViews.map(view => (view.userID, 1)) - .foreach((rdd, time) => rdd.join(userList) + .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) .foreach(u => println("Saw user %s at time %s".format(u, time)))) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 834b775d4fd2b09ea1c2a24dce48af963c57871b..d53b66dd4677141518ca671d77957039ad06543c 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,8 +18,9 @@ package org.apache.spark.streaming.flume import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, DStream} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream} +import org.apache.spark.streaming.dstream.DStream object FlumeUtils { /** @@ -42,6 +43,7 @@ object FlumeUtils { /** * Creates a input stream from a Flume source. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param hostname Hostname of the slave machine to which the flume data will be sent * @param port Port of the slave machine to which the flume data will be sent */ diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index f782e0e126d4528e67305955bdd7fe39a397e188..23b2fead657e6ff318cc718891af54c210efbb8d 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -45,9 +45,9 @@ <scope>test</scope> </dependency> <dependency> - <groupId>com.sksamuel.kafka</groupId> + <groupId>org.apache.kafka</groupId> <artifactId>kafka_${scala.binary.version}</artifactId> - <version>0.8.0-beta1</version> + <version>0.8.0</version> <exclusions> <exclusion> <groupId>com.sun.jmx</groupId> diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index c2d851f94311d48fe7ce19ca6a1d14c49d5c48e4..37c03be4e77ade881b58e4a76b3f0de7ac682cd8 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -26,8 +26,9 @@ import java.util.{Map => JMap} import kafka.serializer.{Decoder, StringDecoder} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, DStream} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaPairDStream} +import org.apache.spark.streaming.dstream.DStream object KafkaUtils { @@ -77,6 +78,7 @@ object KafkaUtils { /** * Create an input stream that pulls messages form a Kafka Broker. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..) 
* @param groupId The group id for this consumer @@ -127,7 +129,7 @@ object KafkaUtils { * see http://kafka.apache.org/08/configuration.html * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread - * @param storageLevel RDD storage level. Defaults to MEMORY_AND_DISK_2. + * @param storageLevel RDD storage level. */ def createStream[K, V, U <: Decoder[_], T <: Decoder[_]]( jssc: JavaStreamingContext, diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index c8987a3ee06bc74921fdb26dfe30ece76185963b..41e813d48c7b8d66e3bc1f2c9f40dc7f2b22f845 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -80,7 +80,7 @@ class MQTTReceiver(brokerUrl: String, var peristance: MqttClientPersistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance - var client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", peristance) + var client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) // Connect to MqttBroker client.connect() diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 0e6c25dbee8fbe5e63f0648c890143e9a4ae4509..3636e46bb82576013e6362c1faa84c974959f12e 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -18,9 +18,10 @@ package org.apache.spark.streaming.mqtt import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, DStream} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream} import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.DStream object MQTTUtils { /** @@ -43,6 +44,7 @@ object MQTTUtils { /** * Create an input stream that receives messages pushed by a MQTT publisher. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. 
* @param jssc JavaStreamingContext object * @param brokerUrl Url of remote MQTT publisher * @param topic Topic name to subscribe to diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index fcc159e85a85bf918f587dc0c626bca6bebe8231..73e7ce6e968c6fdc1575562beb969a3a5b616b79 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.storage.StorageLevel class MQTTStreamSuite extends TestSuiteBase { - test("MQTT input stream") { + test("mqtt input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val brokerUrl = "abc" val topic = "def" diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala index 5e506ffabcfc4a796bc0cf1a10247ad6023d7757..b8bae7b6d33855fa20f8aeee2391260026044a93 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala @@ -20,8 +20,9 @@ package org.apache.spark.streaming.twitter import twitter4j.Status import twitter4j.auth.Authorization import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, DStream} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.DStream object TwitterUtils { /** @@ -50,6 +51,7 @@ object TwitterUtils { * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and * twitter4j.oauth.accessTokenSecret. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object */ def createStream(jssc: JavaStreamingContext): JavaDStream[Status] = { @@ -61,6 +63,7 @@ object TwitterUtils { * OAuth authentication; this requires the system properties twitter4j.oauth.consumerKey, * twitter4j.oauth.consumerSecret, twitter4j.oauth.accessToken and * twitter4j.oauth.accessTokenSecret. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param filters Set of filter strings to get only those tweets that match them */ @@ -87,6 +90,7 @@ object TwitterUtils { /** * Create a input stream that returns tweets received from Twitter. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param jssc JavaStreamingContext object * @param twitterAuth Twitter4J Authorization */ @@ -96,6 +100,7 @@ object TwitterUtils { /** * Create a input stream that returns tweets received from Twitter. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. 
* @param jssc JavaStreamingContext object * @param twitterAuth Twitter4J Authorization * @param filters Set of filter strings to get only those tweets that match them diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index a0a8fe617b134c7c39ae8df77690f12a1fa42797..ccc38784ef671185cc564ea9354e9c11af2f8407 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -23,7 +23,7 @@ import twitter4j.auth.{NullAuthorization, Authorization} class TwitterStreamSuite extends TestSuiteBase { - test("kafka input stream") { + test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val filters = Seq("filter1", "filter2") val authorization: Authorization = NullAuthorization.getInstance() diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 546d9df3b5df9b0a8232a74356dba607151c9ad8..7a14b3d2bf27859c346f4490cac61acd83e68a69 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -25,8 +25,9 @@ import akka.zeromq.Subscribe import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.receivers.ReceiverSupervisorStrategy -import org.apache.spark.streaming.{StreamingContext, DStream} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaStreamingContext, JavaDStream} +import org.apache.spark.streaming.dstream.DStream object ZeroMQUtils { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2d8623392eb4e42cfaaf8fbfd329b1646b8c3322..c972a71349e2555c558b17249b972de73a143eaa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -48,7 +48,7 @@ class PythonMLLibAPI extends Serializable { val db = bb.asDoubleBuffer() val ans = new Array[Double](length.toInt) db.get(ans) - return ans + ans } private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { @@ -60,7 +60,7 @@ class PythonMLLibAPI extends Serializable { bb.putLong(len) val db = bb.asDoubleBuffer() db.put(doubles) - return bytes + bytes } private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { @@ -86,7 +86,7 @@ class PythonMLLibAPI extends Serializable { ans(i) = new Array[Double](cols.toInt) db.get(ans(i)) } - return ans + ans } private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { @@ -102,11 +102,10 @@ class PythonMLLibAPI extends Serializable { bb.putLong(rows) bb.putLong(cols) val db = bb.asDoubleBuffer() - var i = 0 for (i <- 0 until rows) { db.put(doubles(i)) } - return bytes + bytes } private def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, @@ -121,7 +120,7 @@ class PythonMLLibAPI extends Serializable { val ret = new java.util.LinkedList[java.lang.Object]() 
ret.add(serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) - return ret + ret } /** @@ -130,7 +129,7 @@ class PythonMLLibAPI extends Serializable { def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => + trainRegressionModel((data, initialWeights) => LinearRegressionWithSGD.train(data, numIterations, stepSize, miniBatchFraction, initialWeights), dataBytesJRDD, initialWeightsBA) @@ -142,7 +141,7 @@ class PythonMLLibAPI extends Serializable { def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => + trainRegressionModel((data, initialWeights) => LassoWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), dataBytesJRDD, initialWeightsBA) @@ -154,7 +153,7 @@ class PythonMLLibAPI extends Serializable { def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => + trainRegressionModel((data, initialWeights) => RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), dataBytesJRDD, initialWeightsBA) @@ -166,7 +165,7 @@ class PythonMLLibAPI extends Serializable { def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => + trainRegressionModel((data, initialWeights) => SVMWithSGD.train(data, numIterations, stepSize, regParam, miniBatchFraction, initialWeights), dataBytesJRDD, initialWeightsBA) @@ -178,7 +177,7 @@ class PythonMLLibAPI extends Serializable { def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, stepSize: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - return trainRegressionModel((data, initialWeights) => + trainRegressionModel((data, initialWeights) => LogisticRegressionWithSGD.train(data, numIterations, stepSize, miniBatchFraction, initialWeights), dataBytesJRDD, initialWeightsBA) @@ -194,7 +193,7 @@ class PythonMLLibAPI extends Serializable { val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(serializeDoubleMatrix(model.clusterCenters)) - return ret + ret } /** Unpack a Rating object from an array of bytes */ @@ -204,7 +203,7 @@ class PythonMLLibAPI extends Serializable { val user = bb.getInt() val product = bb.getInt() val rating = bb.getDouble() - return new Rating(user, product, rating) + new Rating(user, product, rating) } /** Unpack a tuple of Ints from an array of bytes */ @@ -245,7 +244,7 @@ class PythonMLLibAPI extends Serializable { def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) - return 
ALS.train(ratings, rank, iterations, lambda, blocks) + ALS.train(ratings, rank, iterations, lambda, blocks) } /** @@ -257,6 +256,6 @@ class PythonMLLibAPI extends Serializable { def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) - return ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) + ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8b27ecf82c06d7686873da088bbf3b3de84dffe0..89ee07063dd89cb098889d75c833a379ab814064 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -22,7 +22,7 @@ import scala.util.Random import scala.util.Sorting import org.apache.spark.broadcast.Broadcast -import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext} +import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.RDD import org.apache.spark.serializer.KryoRegistrator @@ -578,12 +578,13 @@ object ALS { val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false val alpha = if (args.length >= 8) args(7).toDouble else 1 val blocks = if (args.length == 9) args(8).toInt else -1 - val sc = new SparkContext(master, "ALS") - sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName) - sc.conf.set("spark.kryo.referenceTracking", "false") - sc.conf.set("spark.kryoserializer.buffer.mb", "8") - sc.conf.set("spark.locality.wait", "10000") + val conf = new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", classOf[ALSRegistrator].getName) + .set("spark.kryo.referenceTracking", "false") + .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.locality.wait", "10000") + val sc = new SparkContext(master, "ALS", conf) val ratings = sc.textFile(ratingsFile).map { line => val fields = line.split(',') diff --git a/pom.xml b/pom.xml index 68dbde7c8b657409d9911e5274bbe8e5a6320a71..b25d9d7ef891dfcbc6d2c5d17586a363a5e9a198 100644 --- a/pom.xml +++ b/pom.xml @@ -258,6 +258,17 @@ </exclusion> </exclusions> </dependency> + <dependency> + <groupId>${akka.group}</groupId> + <artifactId>akka-testkit_${scala.binary.version}</artifactId> + <version>${akka.version}</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> + </dependency> <dependency> <groupId>it.unimi.dsi</groupId> <artifactId>fastutil</artifactId> @@ -346,6 +357,12 @@ <version>1.9.1</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + <version>1.8.5</version> + </dependency> <dependency> <groupId>commons-io</groupId> <artifactId>commons-io</artifactId> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c2b1c0c35cc5b5eafc3999ed9115bf74fdc7ccf4..d508603e244bebcd57002b714b18b70eb4f798e7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -252,6 +252,7 @@ object SparkBuild extends Build { 
"org.ow2.asm" % "asm" % "4.0", "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-testkit" % "2.2.3-shaded-protobuf" % "test", "net.liftweb" %% "lift-json" % "2.5.1" excludeAll(excludeNetty), "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", @@ -398,7 +399,7 @@ object SparkBuild extends Build { name := "spark-streaming-kafka", libraryDependencies ++= Seq( "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty), - "com.sksamuel.kafka" %% "kafka" % "0.8.0-beta1" + "org.apache.kafka" %% "kafka" % "0.8.0" exclude("com.sun.jdmk", "jmxtools") exclude("com.sun.jmx", "jmxri") exclude("net.sf.jopt-simple", "jopt-simple") diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index daaa2a0305113527d66d1d69cb54c791933c91ae..8aad27366524afa9820a8a518dd5c7831b27e7a2 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -35,7 +35,6 @@ class ReplSuite extends FunSuite { } // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") return out.toString } @@ -75,7 +74,6 @@ class ReplSuite extends FunSuite { interp.sparkContext.stop() System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } test("simple foreach with accumulator") { diff --git a/sbt/sbt b/sbt/sbt index 7f47d90cf11bbcef036fd4c5d863fcf8fd4e2052..62ead8a69dbf69d21ee4b5a5e3fa9dcc86006451 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -25,37 +25,26 @@ URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/s URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar JAR=sbt/sbt-launch-${SBT_VERSION}.jar -printf "Checking for system sbt [" -if hash sbt 2>/dev/null; then - printf "FOUND]\n" - # Use System SBT - sbt "$@" -else - printf "NOT FOUND]\n" - # Download sbt or use already downloaded - if [ ! -d .sbtlib ]; then - mkdir .sbtlib - fi - if [ ! -f ${JAR} ]; then - # Download - printf "Attempting to fetch sbt\n" - if hash curl 2>/dev/null; then - curl --progress-bar ${URL1} > ${JAR} || curl --progress-bar ${URL2} > ${JAR} - elif hash wget 2>/dev/null; then - wget --progress=bar ${URL1} -O ${JAR} || wget --progress=bar ${URL2} -O ${JAR} - else - printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" - exit -1 - fi - fi - if [ ! -f ${JAR} ]; then - # We failed to download - printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" +# Download sbt launch jar if it hasn't been downloaded yet +if [ ! -f ${JAR} ]; then + # Download + printf "Attempting to fetch sbt\n" + if hash curl 2>/dev/null; then + curl --progress-bar ${URL1} > ${JAR} || curl --progress-bar ${URL2} > ${JAR} + elif hash wget 2>/dev/null; then + wget --progress=bar ${URL1} -O ${JAR} || wget --progress=bar ${URL2} -O ${JAR} + else + printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 fi - printf "Launching sbt from ${JAR}\n" - java \ - -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \ - -jar ${JAR} \ - "$@" fi +if [ ! 
-f ${JAR} ]; then + # We failed to download + printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" + exit -1 +fi +printf "Launching sbt from ${JAR}\n" +java \ + -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \ + -jar ${JAR} \ + "$@" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ca0115f90e49eb48b8dbaea9529ad63f1bafe01d..5046a1d53fa41a992a43c4605b7be84f1c66cc6d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -24,10 +24,10 @@ import java.util.concurrent.RejectedExecutionException import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.util.MetadataCleaner -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] @@ -40,10 +40,14 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.getPendingTimes() + val pendingTimes = ssc.scheduler.getPendingTimes().toArray val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.conf + // These should be unset when a checkpoint is deserialized, + // otherwise the SparkContext won't initialize correctly. + sparkConf.remove("spark.driver.host").remove("spark.driver.port") + def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -53,59 +57,119 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } +private[streaming] +object Checkpoint extends Logging { + val PREFIX = "checkpoint-" + val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r + + /** Get the checkpoint file for the given checkpoint time */ + def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) + } + + /** Get the checkpoint backup file for the given checkpoint time */ + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") + } + + /** Get checkpoint files present in the give directory, ordered by oldest-first */ + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { + val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + (time1 < time2) || (time1 == time2 && bk1) + } + + val path = new Path(checkpointDir) + if (fs.exists(path)) { + val statuses = fs.listStatus(path) + if (statuses != null) { + val paths = statuses.map(_.getPath) + val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty) + filtered.sortWith(sortFunc) + } else { + logWarning("Listing " + path + " returned null") + Seq.empty + } + } else { + logInfo("Checkpoint directory " + path + " does not exist") + Seq.empty + } + } +} + /** * Convenience class to handle the writing of graph checkpoint to 
file */ private[streaming] -class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Configuration) - extends Logging -{ - val file = new Path(checkpointDir, "graph") +class CheckpointWriter( + jobGenerator: JobGenerator, + conf: SparkConf, + checkpointDir: String, + hadoopConf: Configuration + ) extends Logging { val MAX_ATTEMPTS = 3 val executor = Executors.newFixedThreadPool(1) val compressionCodec = CompressionCodec.createCodec(conf) - // The file to which we actually write - and then "move" to file - val writeFile = new Path(file.getParent, file.getName + ".next") - // The file to which existing checkpoint is backed up (i.e. "moved") - val bakFile = new Path(file.getParent, file.getName + ".bk") - private var stopped = false private var fs_ : FileSystem = _ - // Removed code which validates whether there is only one CheckpointWriter per path 'file' since - // I did not notice any errors - reintroduce it ? class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() + val tempFile = new Path(checkpointDir, "temp") + val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) + val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 try { - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - // This is inherently thread unsafe, so alleviating it by writing to '.new' and - // then moving it to the final file - val fos = fs.create(writeFile) + logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'") + + // Write checkpoint to temp file + fs.delete(tempFile, true) // just in case it exists + val fos = fs.create(tempFile) fos.write(bytes) fos.close() - if (fs.exists(file) && fs.rename(file, bakFile)) { - logDebug("Moved existing checkpoint file to " + bakFile) + + // If the checkpoint file exists, back it up + // If the backup exists as well, just delete it, otherwise rename will fail + if (fs.exists(checkpointFile)) { + fs.delete(backupFile, true) // just in case it exists + if (!fs.rename(checkpointFile, backupFile)) { + logWarning("Could not rename " + checkpointFile + " to " + backupFile) + } + } + + // Rename temp file to the final checkpoint file + if (!fs.rename(tempFile, checkpointFile)) { + logWarning("Could not rename " + tempFile + " to " + checkpointFile) + } + + // Delete old checkpoint files + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + if (allCheckpointFiles.size > 4) { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + logInfo("Deleting " + file) + fs.delete(file, true) + }) } - // paranoia - fs.delete(file, false) - fs.rename(writeFile, file) + // All done, print success val finishTime = System.currentTimeMillis() - logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + - "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") + jobGenerator.onCheckpointCompletion(checkpointTime) return } catch { case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe) reset() } 
} - logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") + logWarning("Could not write checkpoint for time " + checkpointTime + " to file " + checkpointFile + "'") } } @@ -118,6 +182,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi bos.close() try { executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -140,7 +205,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi } private def fs = synchronized { - if (fs_ == null) fs_ = file.getFileSystem(hadoopConf) + if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf) fs_ } @@ -153,43 +218,46 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi private[streaming] object CheckpointReader extends Logging { - def read(conf: SparkConf, path: String): Checkpoint = { - val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), - new Path(path), new Path(path + ".bk")) + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { + val checkpointPath = new Path(checkpointDir) + def fs = checkpointPath.getFileSystem(hadoopConf) + + // Try to find the checkpoint files + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + if (checkpointFiles.isEmpty) { + return None + } + // Try to read the checkpoint files in the order + logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) val compressionCodec = CompressionCodec.createCodec(conf) - - attempts.foreach(file => { - if (fs.exists(file)) { - logInfo("Attempting to load checkpoint from file '" + file + "'") - try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - logInfo("Checkpoint successfully loaded from file '" + file + "'") - logInfo("Checkpoint was generated at time " + cp.checkpointTime) - return cp - } catch { - case e: Exception => - logError("Error loading checkpoint from file '" + file + "'", e) - } - } else { - logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") + checkpointFiles.foreach(file => { + logInfo("Attempting to load checkpoint from file " + file) + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. 
This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + logInfo("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint was generated at time " + cp.checkpointTime) + return Some(cp) + } catch { + case e: Exception => + logWarning("Error reading checkpoint from file " + file, e) } - }) - throw new Exception("Could not read checkpoint from path '" + path + "'") + + // If none of checkpoint files could be read, then throw exception + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } @@ -203,6 +271,6 @@ class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoade } catch { case e: Exception => } - return super.resolveClass(desc) + super.resolveClass(desc) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala new file mode 100644 index 0000000000000000000000000000000000000000..1f5dacb543db863ce7d00629cafc1ae651c52b80 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -0,0 +1,28 @@ +package org.apache.spark.streaming + +private[streaming] class ContextWaiter { + private var error: Throwable = null + private var stopped: Boolean = false + + def notifyError(e: Throwable) = synchronized { + error = e + notifyAll() + } + + def notifyStop() = synchronized { + notifyAll() + } + + def waitForStopOrError(timeout: Long = -1) = synchronized { + // If already had error, then throw it + if (error != null) { + throw error + } + + // If not already stopped, then wait + if (!stopped) { + if (timeout < 0) wait() else wait(timeout) + if (error != null) throw error + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala deleted file mode 100644 index 3fd5d52403c14be29b043decfd02a42876c49db1..0000000000000000000000000000000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.conf.Configuration - -import collection.mutable.HashMap -import org.apache.spark.Logging - -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - - -private[streaming] -class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) - extends Serializable with Logging { - protected val data = new HashMap[Time, AnyRef]() - - @transient private var fileSystem : FileSystem = null - @transient private var lastCheckpointFiles: HashMap[Time, String] = null - - protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] - - /** - * Updates the checkpoint data of the DStream. This gets called every time - * the graph checkpoint is initiated. Default implementation records the - * checkpoint files to which the generate RDDs of the DStream has been saved. - */ - def update() { - - // Get the checkpointed RDDs from the generated RDDs - val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) - .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - lastCheckpointFiles = checkpointFiles.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newCheckpointFiles.size > 0) { - checkpointFiles.clear() - checkpointFiles ++= newCheckpointFiles - } - - // TODO: remove this, this is just for debugging - newCheckpointFiles.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } - } - } - - /** - * Cleanup old checkpoint data. This gets called every time the graph - * checkpoint is initiated, but after `update` is called. Default - * implementation, cleans up old checkpoint files. - */ - def cleanup() { - // If there is at least on checkpoint file in the current checkpoint files, - // then delete the old checkpoint files. - if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { - (lastCheckpointFiles -- checkpointFiles.keySet).foreach { - case (time, file) => { - try { - val path = new Path(file) - if (fileSystem == null) { - fileSystem = path.getFileSystem(new Configuration()) - } - fileSystem.delete(path, true) - logInfo("Deleted checkpoint file '" + file + "' for time " + time) - } catch { - case e: Exception => - logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) - } - } - } - } - } - - /** - * Restore the checkpoint data. This gets called once when the DStream graph - * (along with its DStreams) are being restored from a graph checkpoint file. - * Default implementation restores the RDDs from their checkpoint files. 
- */ - def restore() { - // Create RDDs from the checkpoint data - checkpointFiles.foreach { - case(time, file) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") - dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) - } - } - } - - override def toString() = { - "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" - } -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index a09b891956efe2348043c85cf6f11c22596e3be4..8faa79f8c7e9d8b48bff46497cfb0a2550e15ee5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming -import dstream.InputDStream +import scala.collection.mutable.ArrayBuffer import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import collection.mutable.ArrayBuffer import org.apache.spark.Logging import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.streaming.dstream.{DStream, NetworkInputDStream, InputDStream} final private[streaming] class DStreamGraph extends Serializable with Logging { @@ -78,7 +78,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def remember(duration: Duration) { this.synchronized { if (rememberDuration != null) { - throw new Exception("Batch duration already set as " + batchDuration + + throw new Exception("Remember duration already set as " + batchDuration + ". cannot set it again.") } rememberDuration = duration @@ -103,37 +103,51 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getNetworkInputStreams() = this.synchronized { + inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]]) + .map(_.asInstanceOf[NetworkInputDStream[_]]) + .toArray + } + def generateJobs(time: Time): Seq[Job] = { - this.synchronized { - logInfo("Generating jobs for time " + time) - val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) - logInfo("Generated " + jobs.length + " jobs for time " + time) - jobs + logDebug("Generating jobs for time " + time) + val jobs = this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } + logDebug("Generated " + jobs.length + " jobs for time " + time) + jobs } - def clearOldMetadata(time: Time) { + def clearMetadata(time: Time) { + logDebug("Clearing metadata for time " + time) this.synchronized { - logInfo("Clearing old metadata for time " + time) - outputStreams.foreach(_.clearOldMetadata(time)) - logInfo("Cleared old metadata for time " + time) + outputStreams.foreach(_.clearMetadata(time)) } + logDebug("Cleared old metadata for time " + time) } def updateCheckpointData(time: Time) { + logInfo("Updating checkpoint data for time " + time) this.synchronized { - logInfo("Updating checkpoint data for time " + time) outputStreams.foreach(_.updateCheckpointData(time)) - logInfo("Updated checkpoint data for time " + time) } + logInfo("Updated checkpoint data for time " + time) + } + + def clearCheckpointData(time: Time) { + logInfo("Clearing checkpoint data for time " + time) + this.synchronized { + outputStreams.foreach(_.clearCheckpointData(time)) + } + logInfo("Cleared checkpoint data for time " + time) } def 
restoreCheckpointData() { + logInfo("Restoring checkpoint data") this.synchronized { - logInfo("Restoring checkpoint data") outputStreams.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") } + logInfo("Restored checkpoint data") } def validate() { @@ -146,8 +160,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { + logDebug("DStreamGraph.writeObject used") this.synchronized { - logDebug("DStreamGraph.writeObject used") checkpointInProgress = true oos.defaultWriteObject() checkpointInProgress = false @@ -156,8 +170,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { + logDebug("DStreamGraph.readObject used") this.synchronized { - logDebug("DStreamGraph.readObject used") checkpointInProgress = true ois.defaultReadObject() checkpointInProgress = false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 693cb7fc30fa044374fd08c0128f5a1199f0dfb0..7b279334034ab9adfa5ab67aa61df75e27a1ac2f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -39,13 +39,14 @@ import org.apache.spark.util.MetadataCleaner import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receivers._ import org.apache.spark.streaming.scheduler._ +import org.apache.hadoop.conf.Configuration /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic * information (such as, cluster URL and job name) to internally create a SparkContext, it provides * methods used to create DStream from various input sources. */ -class StreamingContext private ( +class StreamingContext private[streaming] ( sc_ : SparkContext, cp_ : Checkpoint, batchDur_ : Duration @@ -88,30 +89,21 @@ class StreamingContext private ( /** * Re-create a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. 
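With this change the recovery constructor takes the checkpoint directory itself (plus an optional Hadoop configuration) rather than a path to 'graph'/'graph.bk'. A rough sketch with a placeholder HDFS path; note that the underlying CheckpointReader.read(...).get fails if no checkpoint files exist, so the new getOrCreate (further below) is the safer entry point when recovery is optional:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.streaming.StreamingContext

    val checkpointDir = "hdfs:///user/example/streaming-checkpoint"  // placeholder path
    val recovered = new StreamingContext(checkpointDir)              // default Hadoop Configuration
    // or, if the checkpoint file system needs explicit settings:
    // val recovered = new StreamingContext(checkpointDir, new Configuration())
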
+ * @param path Path to the directory that was specified as the checkpoint directory + * @param hadoopConf Optional, configuration object if necessary for reading from + * HDFS compatible filesystems */ - def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null) + def this(path: String, hadoopConf: Configuration = new Configuration) = + this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") } - private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf) + private[streaming] val isCheckpointPresent = (cp_ != null) - if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) { - MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds) - } - - if (MetadataCleaner.getDelaySeconds(conf_) < 0) { - throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " - + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") - } - - protected[streaming] val isCheckpointPresent = (cp_ != null) - - protected[streaming] val sc: SparkContext = { + private[streaming] val sc: SparkContext = { if (isCheckpointPresent) { new SparkContext(cp_.sparkConf) } else { @@ -119,11 +111,16 @@ class StreamingContext private ( } } - protected[streaming] val conf = sc.conf + if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { + throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") + } + + private[streaming] val conf = sc.conf - protected[streaming] val env = SparkEnv.get + private[streaming] val env = SparkEnv.get - protected[streaming] val graph: DStreamGraph = { + private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -136,10 +133,9 @@ class StreamingContext private ( } } - protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var networkInputTracker: NetworkInputTracker = null + private val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var checkpointDir: String = { + private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(cp_.checkpointDir) cp_.checkpointDir @@ -148,11 +144,13 @@ class StreamingContext private ( } } - protected[streaming] val checkpointDuration: Duration = { + private[streaming] val checkpointDuration: Duration = { if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration } - protected[streaming] val scheduler = new JobScheduler(this) + private[streaming] val scheduler = new JobScheduler(this) + + private[streaming] val waiter = new ContextWaiter /** * Return the associated Spark context */ @@ -170,9 +168,10 @@ class StreamingContext private ( } /** - * Set the context to periodically checkpoint the DStream operations for master - * fault-tolerance. The graph will be checkpointed every batch interval. - * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + * Set the context to periodically checkpoint the DStream operations for driver + * fault-tolerance. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. 
+ * Note that this must be a fault-tolerant file system like HDFS for */ def checkpoint(directory: String) { if (directory != null) { @@ -187,11 +186,11 @@ class StreamingContext private ( } } - protected[streaming] def initialCheckpoint: Checkpoint = { + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } - protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream with any arbitrary user implemented network receiver. @@ -221,7 +220,7 @@ class StreamingContext private ( def actorStream[T: ClassTag]( props: Props, name: String, - storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy ): DStream[T] = { networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) @@ -273,6 +272,7 @@ class StreamingContext private ( * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) * @tparam T Type of the objects in the received blocks */ def rawSocketStream[T: ClassTag]( @@ -412,7 +412,7 @@ class StreamingContext private ( scheduler.listenerBus.addListener(streamingListener) } - protected def validate() { + private def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -426,81 +426,115 @@ class StreamingContext private ( /** * Start the execution of the streams. */ - def start() { + def start() = synchronized { validate() + scheduler.start() + } - // Get the network input streams - val networkInputStreams = graph.getInputStreams().filter(s => s match { - case n: NetworkInputDStream[_] => true - case _ => false - }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray - - // Start the network input tracker (must start before receivers) - if (networkInputStreams.length > 0) { - networkInputTracker = new NetworkInputTracker(this, networkInputStreams) - networkInputTracker.start() - } - Thread.sleep(1000) + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + */ + def awaitTermination() { + waiter.waitForStopOrError() + } - // Start the scheduler - scheduler.start() + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) { + waiter.waitForStopOrError(timeout) } /** * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not */ - def stop() { - try { - if (scheduler != null) scheduler.stop() - if (networkInputTracker != null) networkInputTracker.stop() - sc.stop() - logInfo("StreamingContext stopped successfully") - } catch { - case e: Exception => logWarning("Error while stopping", e) - } + def stop(stopSparkContext: Boolean = true) = synchronized { + scheduler.stop() + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() + if (stopSparkContext) sc.stop() } } +/** + * StreamingContext object contains a number of utility functions related to the + * StreamingContext class. 
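Since start() no longer sleeps or wires up the NetworkInputTracker inline, and awaitTermination()/stop(stopSparkContext) give the driver an explicit lifecycle, a minimal usage sketch looks like the following (master, app name and socket source are placeholders):

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val ssc = new StreamingContext("local[2]", "LifecycleExample", Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)  // placeholder source
    lines.count().print()

    ssc.start()
    sys.addShutdownHook {
      ssc.stop(stopSparkContext = false)  // stop the streams but keep the SparkContext
    }
    ssc.awaitTermination()  // blocks, rethrowing any error raised by the streaming jobs
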
+ */ + +object StreamingContext extends Logging { -object StreamingContext { + private[streaming] val DEFAULT_CLEANER_TTL = 3600 implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc`. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext + * @param hadoopConf Optional Hadoop configuration if necessary for reading from the + * file system + * @param createOnError Optional, whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: () => StreamingContext, + hadoopConf: Configuration = new Configuration(), + createOnError: Boolean = false + ): StreamingContext = { + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } + checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls) - protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { + private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. - val sc = new SparkContext(conf) - if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) + if (MetadataCleaner.getDelaySeconds(conf) < 0) { + MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL) } + val sc = new SparkContext(conf) sc } - protected[streaming] def createNewSparkContext( + private[streaming] def createNewSparkContext( master: String, appName: String, sparkHome: String, jars: Seq[String], - environment: Map[String, String]): SparkContext = - { - val sc = new SparkContext(master, appName, sparkHome, jars, environment) - // Set the default cleaner delay to an hour if not already set. - // This should be sufficient for even 1 second batch intervals. 
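StreamingContext.getOrCreate is what makes driver fault-tolerance practical end to end: the same program either recovers from existing checkpoint data or builds a fresh context. A sketch under assumed placeholder values:

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val checkpointDir = "hdfs:///user/example/streaming-checkpoint"  // placeholder path

    def createContext(): StreamingContext = {
      val ssc = new StreamingContext("local[2]", "RecoverableApp", Seconds(1))
      ssc.checkpoint(checkpointDir)
      val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
      words.countByValue().print()
      ssc
    }

    // Recovers from checkpointDir if checkpoint files exist, otherwise calls createContext()
    val context = StreamingContext.getOrCreate(checkpointDir, createContext _)
    context.start()  // then awaitTermination() as usual
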
- if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) - } - sc + environment: Map[String, String] + ): SparkContext = { + val conf = SparkContext.updatedConf( + new SparkConf(), master, appName, sparkHome, jars, environment) + createNewSparkContext(conf) } - protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index d29033df3223f31f6c1ee807ad996ae5773c77c2..c92854ccd9a28030dae44d72bfc7b0376c14d60c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.{Duration, Time, DStream} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.api.java.JavaRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.DStream /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 64f38ce1c0283278c358c44572d79881652b3b37..1ec4492bcab9b3c7cc4125dcc64f1d30672b1b68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -30,6 +30,7 @@ import org.apache.spark.api.java.function.{Function3 => JFunction3, _} import java.util import org.apache.spark.rdd.RDD import JavaDStream._ +import org.apache.spark.streaming.dstream.DStream trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] extends Serializable { @@ -243,17 +244,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 0.9.0, replaced by foreachRDD */ + @Deprecated def foreach(foreachFunc: JFunction[R, Void]) { - dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd))) + foreachRDD(foreachFunc) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of release 0.9.0, replaced by foreachRDD */ + @Deprecated def foreach(foreachFunc: JFunction2[R, Time, Void]) { - dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) + foreachRDD(foreachFunc) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. 
+ */ + def foreachRDD(foreachFunc: JFunction[R, Void]) { + dstream.foreachRDD(rdd => foreachFunc.call(wrapRDD(rdd))) + } + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: JFunction2[R, Time, Void]) { + dstream.foreachRDD((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 6c3467d4056d721a90e0efc55e903862af12526e..6bb985ca540fff6143e645c5f142498d7a43c492 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -35,6 +35,7 @@ import org.apache.spark.storage.StorageLevel import com.google.common.base.Optional import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PairRDDFunctions +import org.apache.spark.streaming.dstream.DStream class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifest: ClassTag[K], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7068f32517407f20f6243b117260b0d69efce756..a2f0b88cb094f7c8fc8e47c3210d54c7ea3354b8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -35,6 +35,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener +import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.dstream.DStream /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -128,10 +130,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Re-creates a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. + * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this (new StreamingContext(path)) + def this(path: String) = this(new StreamingContext(path)) + + /** + * Re-creates a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + * + */ + def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf)) /** The underlying SparkContext */ val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) @@ -143,7 +151,6 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects - * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel) : JavaDStream[String] = { @@ -153,7 +160,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create a input stream from network source hostname:port. 
Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited - * lines. + * lines. Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data */ @@ -294,6 +301,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream with any arbitrary user implemented actor receiver. + * Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2. * @param props Props object defining creation of the actor * @param name Name of the actor * @@ -471,20 +479,116 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Starts the execution of the streams. + * Start the execution of the streams. */ def start() = ssc.start() /** - * Sstops the execution of the streams. + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + */ + def awaitTermination() = ssc.awaitTermination() + + /** + * Wait for the execution to stop. Any exceptions that occurs during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout) + + /** + * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ def stop() = ssc.stop() + + /** + * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not + */ + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) } +/** + * JavaStreamingContext object contains a number of utility functions. + */ object JavaStreamingContext { + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray } + +/** + * Factory interface for creating a new JavaStreamingContext + */ +trait JavaStreamingContextFactory { + def create(): JavaStreamingContext +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala similarity index 89% rename from streaming/src/main/scala/org/apache/spark/streaming/DStream.scala rename to streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 00671ba5206f95382c021c220c081cea484656ab..a7c4cca7eacab61ade04b4636d89c765b89890b6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -15,21 +15,23 @@ * limitations under the License. */ -package org.apache.spark.streaming +package org.apache.spark.streaming.dstream -import StreamingContext._ -import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.MetadataCleaner +import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.MetadataCleaner +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.streaming.Duration /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -41,7 +43,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} * by a parent DStream. * * This class contains the basic operations available on all DStreams, such as `map`, `filter` and - * `window`. In addition, [[org.apache.spark.streaming.PairDStreamFunctions]] contains operations available + * `window`. 
In addition, [[org.apache.spark.streaming.dstream.PairDStreamFunctions]] contains operations available * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through * implicit conversions when `spark.streaming.StreamingContext._` is imported. @@ -53,7 +55,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} */ abstract class DStream[T: ClassTag] ( - @transient protected[streaming] var ssc: StreamingContext + @transient private[streaming] var ssc: StreamingContext ) extends Serializable with Logging { // ======================================================================= @@ -73,31 +75,31 @@ abstract class DStream[T: ClassTag] ( // Methods and fields available on all DStreams // ======================================================================= - // RDDs generated, marked as protected[streaming] so that testsuites can access it + // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected[streaming] var zeroTime: Time = null + private[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected[streaming] var rememberDuration: Duration = null + private[streaming] var rememberDuration: Duration = null // Storage level of the RDDs in the stream - protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE + private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint details - protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointDuration: Duration = null - protected[streaming] val checkpointData = new DStreamCheckpointData(this) + private[streaming] val mustCheckpoint = false + private[streaming] var checkpointDuration: Duration = null + private[streaming] val checkpointData = new DStreamCheckpointData(this) // Reference to whole DStream graph - protected[streaming] var graph: DStreamGraph = null + private[streaming] var graph: DStreamGraph = null - protected[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - protected[streaming] def parentRememberDuration = rememberDuration + private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ def context = ssc @@ -137,7 +139,7 @@ abstract class DStream[T: ClassTag] ( * the validity of future times is calculated. This method also recursively initializes * its parent DStreams. 
*/ - protected[streaming] def initialize(time: Time) { + private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { throw new Exception("ZeroTime is already initialized to " + zeroTime + ", cannot initialize it again to " + time) @@ -163,7 +165,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.initialize(zeroTime)) } - protected[streaming] def validate() { + private[streaming] def validate() { assert(rememberDuration != null, "Remember duration is set to null") assert( @@ -227,7 +229,7 @@ abstract class DStream[T: ClassTag] ( logInfo("Initialized and validated " + this) } - protected[streaming] def setContext(s: StreamingContext) { + private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { throw new Exception("Context is already set in " + this + ", cannot set it again") } @@ -236,7 +238,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setContext(ssc)) } - protected[streaming] def setGraph(g: DStreamGraph) { + private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { throw new Exception("Graph is already set in " + this + ", cannot set it again") } @@ -244,7 +246,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def remember(duration: Duration) { + private[streaming] def remember(duration: Duration) { if (duration != null && duration > rememberDuration) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) @@ -253,14 +255,14 @@ abstract class DStream[T: ClassTag] ( } /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ - protected def isTimeValid(time: Time): Boolean = { + private[streaming] def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) false } else { - logInfo("Time " + time + " is valid") + logDebug("Time " + time + " is valid") true } } @@ -269,7 +271,7 @@ abstract class DStream[T: ClassTag] ( * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal * method that should not be called directly. */ - protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { @@ -310,7 +312,7 @@ abstract class DStream[T: ClassTag] ( * that materializes the corresponding RDD. Subclasses of DStream may override this * to generate their own jobs. */ - protected[streaming] def generateJob(time: Time): Option[Job] = { + private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { @@ -329,19 +331,18 @@ abstract class DStream[T: ClassTag] ( * implementation clears the old generated RDDs. Subclasses of DStream may override * this to clear their own metadata along with the generated RDDs. 
*/ - protected[streaming] def clearOldMetadata(time: Time) { - var numForgotten = 0 + private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) generatedRDDs --= oldRDDs.keys - logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " + + logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) - dependencies.foreach(_.clearOldMetadata(time)) + dependencies.foreach(_.clearMetadata(time)) } /* Adds metadata to the Stream while it is running. - * This methd should be overwritten by sublcasses of InputDStream. + * This method should be overwritten by sublcasses of InputDStream. */ - protected[streaming] def addMetadata(metadata: Any) { + private[streaming] def addMetadata(metadata: Any) { if (metadata != null) { logInfo("Dropping Metadata: " + metadata.toString) } @@ -354,21 +355,27 @@ abstract class DStream[T: ClassTag] ( * checkpointData. Subclasses of DStream (especially those of InputDStream) may override * this method to save custom checkpoint data. */ - protected[streaming] def updateCheckpointData(currentTime: Time) { - logInfo("Updating checkpoint data for time " + currentTime) - checkpointData.update() + private[streaming] def updateCheckpointData(currentTime: Time) { + logDebug("Updating checkpoint data for time " + currentTime) + checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - checkpointData.cleanup() logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } + private[streaming] def clearCheckpointData(time: Time) { + logDebug("Clearing checkpoint data") + checkpointData.cleanup(time) + dependencies.foreach(_.clearCheckpointData(time)) + logDebug("Cleared checkpoint data") + } + /** * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method * that should not be called directly. This is a default implementation that recreates RDDs * from the checkpoint file names stored in checkpointData. Subclasses of DStream that * override the updateCheckpointData() method would also need to override this method. */ - protected[streaming] def restoreCheckpointData() { + private[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data logInfo("Restoring checkpoint data") checkpointData.restore() @@ -482,15 +489,29 @@ abstract class DStream[T: ClassTag] ( * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: RDD[T] => Unit) { - this.foreach((r: RDD[T], t: Time) => foreachFunc(r)) + @deprecated("use foreachRDD", "0.9.0") + def foreach(foreachFunc: RDD[T] => Unit) = this.foreachRDD(foreachFunc) + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + @deprecated("use foreachRDD", "0.9.0") + def foreach(foreachFunc: (RDD[T], Time) => Unit) = this.foreachRDD(foreachFunc) + + /** + * Apply a function to each RDD in this DStream. This is an output operator, so + * 'this' DStream will be registered as an output stream and therefore materialized. + */ + def foreachRDD(foreachFunc: RDD[T] => Unit) { + this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) } /** * Apply a function to each RDD in this DStream. 
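The foreach overloads above are kept only as deprecated forwarders to foreachRDD, so existing applications keep compiling with a warning. A small migration sketch follows; the printCounts helper and the DStream it receives are hypothetical, not part of this patch.

  import org.apache.spark.rdd.RDD
  import org.apache.spark.streaming.dstream.DStream

  // `counts` stands in for any keyed DStream built elsewhere in the application
  def printCounts(counts: DStream[(String, Long)]) {
    // Before this patch (deprecated since 0.9.0, still compiles with a warning):
    // counts.foreach((rdd: RDD[(String, Long)]) => rdd.take(10).foreach(println))

    // After this patch, the same output operation under its new name:
    counts.foreachRDD((rdd: RDD[(String, Long)]) => rdd.take(10).foreach(println))
  }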
This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: (RDD[T], Time) => Unit) { + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { ssc.registerOutputStream(new ForEachDStream(this, context.sparkContext.clean(foreachFunc))) } @@ -679,7 +700,7 @@ abstract class DStream[T: ClassTag] ( /** * Return all the RDDs defined by the Interval object (both end times included) */ - protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } @@ -714,7 +735,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) } - this.foreach(saveFunc) + this.foreachRDD(saveFunc) } /** @@ -727,7 +748,7 @@ abstract class DStream[T: ClassTag] ( val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) } - this.foreach(saveFunc) + this.foreachRDD(saveFunc) } def register() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala new file mode 100644 index 0000000000000000000000000000000000000000..2da4127f47f142d882ba6e19c10b6c40b5f84306 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import scala.collection.mutable.HashMap +import scala.reflect.ClassTag +import java.io.{ObjectInputStream, IOException} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.FileSystem +import org.apache.spark.Logging +import org.apache.spark.streaming.Time + +private[streaming] +class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) + extends Serializable with Logging { + protected val data = new HashMap[Time, AnyRef]() + + // Mapping of the batch time to the checkpointed RDD file of that time + @transient private var timeToCheckpointFile = new HashMap[Time, String] + // Mapping of the batch time to the time of the oldest checkpointed RDD + // in that batch's checkpoint data + @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time] + + @transient private var fileSystem : FileSystem = null + protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]] + + /** + * Updates the checkpoint data of the DStream. This gets called every time + * the graph checkpoint is initiated. Default implementation records the + * checkpoint files to which the generate RDDs of the DStream has been saved. 
+ */ + def update(time: Time) { + + // Get the checkpointed RDDs from the generated RDDs + val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) + logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n")) + + // Add the checkpoint files to the data to be serialized + if (!checkpointFiles.isEmpty) { + currentCheckpointFiles.clear() + currentCheckpointFiles ++= checkpointFiles + // Add the current checkpoint files to the map of all checkpoint files + // This will be used to delete old checkpoint files + timeToCheckpointFile ++= currentCheckpointFiles + // Remember the time of the oldest checkpoint RDD in current state + timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering) + } + } + + /** + * Cleanup old checkpoint data. This gets called after a checkpoint of `time` has been + * written to the checkpoint directory. + */ + def cleanup(time: Time) { + // Get the time of the oldest checkpointed RDD that was written as part of the + // checkpoint of `time` + timeToOldestCheckpointFileTime.remove(time) match { + case Some(lastCheckpointFileTime) => + // Find all the checkpointed RDDs (i.e. files) that are older than `lastCheckpointFileTime` + // This is because checkpointed RDDs older than this are not going to be needed + // even after master fails, as the checkpoint data of `time` does not refer to those files + val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime) + logDebug("Files to delete:\n" + filesToDelete.mkString(",")) + filesToDelete.foreach { + case (time, file) => + try { + val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration) + } + fileSystem.delete(path, true) + timeToCheckpointFile -= time + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + fileSystem = null + } + } + case None => + logInfo("Nothing to delete") + } + } + + /** + * Restore the checkpoint data. This gets called once when the DStream graph + * (along with its DStreams) are being restored from a graph checkpoint file. + * Default implementation restores the RDDs from their checkpoint files. 
+ */ + def restore() { + // Create RDDs from the checkpoint data + currentCheckpointFiles.foreach { + case(time, file) => { + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") + dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) + } + } + } + + override def toString() = { + "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + timeToOldestCheckpointFileTime = new HashMap[Time, Time] + timeToCheckpointFile = new HashMap[Time, String] + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index fb9eda899672094c98ca43f86caaea7f38de5a03..37c46b26a50b54bdd1b213b2fb627aa0ea895709 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -23,10 +23,10 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD -import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time} +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.util.TimeStampedHashMap private[streaming] @@ -46,6 +46,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null @transient private[streaming] var files = new HashMap[Time, Array[String]] + @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true) + @transient private var lastNewFileFindingTime = 0L override def start() { if (newFilesOnly) { @@ -88,14 +90,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } /** Clear the old time-to-files mappings along with old RDDs */ - protected[streaming] override def clearOldMetadata(time: Time) { - super.clearOldMetadata(time) + protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) val oldFiles = files.filter(_._1 <= (time - rememberDuration)) files --= oldFiles.keys logInfo("Cleared " + oldFiles.size + " old files that were older than " + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + // Delete file mod times that weren't accessed in the last round of getting new files + fileModTimes.clearOldValues(lastNewFileFindingTime - 1) } /** @@ -104,8 +108,19 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = { logDebug("Trying to get new files for time " + currentTime) + lastNewFileFindingTime = System.currentTimeMillis val filter = new CustomPathFilter(currentTime) - val newFiles = fs.listStatus(path, filter).map(_.getPath.toString) + val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + logInfo("Finding new files took " 
+ timeTaken + " ms") + logDebug("# cached file times = " + fileModTimes.size) + if (timeTaken > slideDuration.milliseconds) { + logWarning( + "Time taken to find new files exceeds the batch size. " + + "Consider increasing the batch size or reduceing the number of " + + "files in the monitored directory." + ) + } (newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq) } @@ -122,16 +137,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas new UnionRDD(context.sparkContext, fileRDDs) } - private def path: Path = { + private def directoryPath: Path = { if (path_ == null) path_ = new Path(directory) path_ } private def fs: FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration()) fs_ } + private def getFileModTime(path: Path) = { + // Get file mod time from cache or fetch it from the file system + fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) + } + private def reset() { fs_ = null } @@ -142,6 +162,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[(K,V)]] () files = new HashMap[Time, Array[String]] + fileModTimes = new TimeStampedHashMap[String, Long](true) } /** @@ -153,15 +174,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] - override def update() { + override def update(time: Time) { hadoopFiles.clear() hadoopFiles ++= files } - override def cleanup() { } + override def cleanup(time: Time) { } override def restore() { - hadoopFiles.foreach { + hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, f) => { // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + @@ -187,14 +208,13 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Latest file mod time seen in this round of fetching files and its corresponding files var latestModTime = 0L val latestModTimeFiles = new HashSet[String]() - def accept(path: Path): Boolean = { try { if (!filter(path)) { // Reject file if it does not satisfy filter logDebug("Rejected by filter " + path) return false } - val modTime = fs.getFileStatus(path).getModificationTime() + val modTime = getFileModTime(path) logDebug("Mod time for " + path + " is " + modTime) if (modTime < prevModTime) { logDebug("Mod time less than last mod time") @@ -219,7 +239,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas reset() return false } - return true + true } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index db2e0a4ceef0366ca0deefd5650df1ed0f336d94..c81534ae584ea05e1fc14800544dba1ffc326668 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 244dc3ee4fa143c8bde0bc08045c67545dc6929c..658623455498ceb038915e791ac04c502f2f1909 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 336c4b7a92dc6c3754eb16436c34fa6ba4d2ea18..c7bb2833eabb8878cecc6cbc9beb50b4b6ad4227 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 364abcde68c95125d887a6ed0b40ad52611b63eb..905bc723f69a9e27dcccedbe6399240c32f5168e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index 23136f44fa3103d76bfe13a6a4d9ba21706db9c1..a9bb51f05404833487d7e9f8579df09345b0884e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index f01e67fe13096ca3b5db44e2b74c52fc573ec0a1..a1075ad304ef6c320cb2a05c04355446de516c5a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Time, Duration, StreamingContext, DStream} +import org.apache.spark.streaming.{Time, Duration, StreamingContext} import scala.reflect.ClassTag @@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) * This ensures that 
InputDStream.compute() is called strictly on increasing * times. */ - override protected def isTimeValid(time: Time): Boolean = { + override private[streaming] def isTimeValid(time: Time): Boolean = { if (!super.isTimeValid(time)) { false // Time not valid } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 8a04060e5b6c11360fbcec5d02777aee7cf0753f..3d8ee29df1e821fce5e48a152643071714a25d7c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 0ce364fd4632829d3b7f80945e27633626d5e346..7aea1f945d9db60d059aa09b7f94474854e2e5d6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index c0b7491d096cd64bc37d7b2d5ce97ba00feded48..02704a8d1c2e0757d4ca1ff6cdc8838ded58e9b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 27d474c0a0459a3aa556ba8e7bea3bb967396dbc..0f1f6fc2cec4a846404dc90e3fc3e9a56f1dd312 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte // then this returns an empty RDD. 
This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) @@ -175,7 +175,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging private class NetworkReceiverActor extends Actor { logInfo("Attempting to register with tracker") val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.get("spark.driver.port", "7077").toInt + val port = env.conf.getInt("spark.driver.port", 7077) val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) val tracker = env.actorSystem.actorSelection(url) val timeout = 5.seconds @@ -212,7 +212,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null) val clock = new SystemClock() - val blockInterval = env.conf.get("spark.streaming.blockInterval", "200").toLong + val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala similarity index 99% rename from streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala rename to streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 56dbcbda23a7029bda765437ce235430136fa4df..6b3e48382e0c403aaad5ecd25d1ce20baae7db37 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -15,7 +15,7 @@ * limitations under the License. 
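The receiver changes above read configuration through the typed SparkConf getters, so spark.streaming.blockInterval is now fetched with getLong and defaults to 200 milliseconds between block boundaries. A hedged sketch of overriding it when the conf is built; the application name and the 100 ms value are illustrative.

  import org.apache.spark.SparkConf

  // Conf values are stored as strings; the receiver reads this key via getLong(..., 200)
  val conf = new SparkConf()
    .setAppName("ReceiverTuningExample")           // hypothetical app name
    .set("spark.streaming.blockInterval", "100")   // cut blocks every 100 ms instead of the 200 ms default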
*/ -package org.apache.spark.streaming +package org.apache.spark.streaming.dstream import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream._ @@ -33,6 +33,7 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.{Time, Duration} class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) extends Serializable { @@ -582,7 +583,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreach(saveFunc) + self.foreachRDD(saveFunc) } /** @@ -612,7 +613,7 @@ extends Serializable { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf) } - self.foreach(saveFunc) + self.foreachRDD(saveFunc) } private def getKeyClass() = implicitly[ClassTag[K]].runtimeClass diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index db56345ca84fb3749cdce96b3968d103cbbcc9f0..7a6b1ea35eb13163e15f0a77ec19f67a56bc0b9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Duration, Interval, Time, DStream} +import org.apache.spark.streaming.{Duration, Interval, Time} import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 84e69f277b22e97d2ac8303bc5d784faa8b96b06..880a89bc368956ce575f9a5489a5272cb6da6784 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.Partitioner import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index e0ff3ccba4e7dd153c7af6e8bba5c7d9aee9bc88..9d8889b6553566c1249400d0dae09ce1e3146e00 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, Time, DStream} +import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag @@ -65,7 +65,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val cogroupedRDD = 
parentRDD.cogroup(prevStateRDD, partitioner) val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime) - return Some(stateRDD) + Some(stateRDD) } case None => { // If parent RDD does not exist @@ -76,7 +76,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( updateFuncLocal(i) } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) - return Some(stateRDD) + Some(stateRDD) } } } @@ -98,11 +98,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val groupedRDD = parentRDD.groupByKey(partitioner) val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) //logDebug("Generating state RDD for time " + validTime + " (first)") - return Some(sessionRDD) + Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! //logDebug("Not generating state RDD (no previous state, no parent)") - return None + None } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index aeea060df7161fe33206a40e488a780bb1dedd9d..7cd4554282ca18a725eec959703948bb7b496436 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import scala.reflect.ClassTag private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index 0d84ec84f2c6335e98316a3263fbeae2d520d77d..4ecba03ab5d2f84691d75b26c3caffdb4e362b65 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -17,9 +17,8 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, DStream, Time} +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.rdd.RDD -import collection.mutable.ArrayBuffer import org.apache.spark.rdd.UnionRDD import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 89c43ff935bb415183abe74e943ed953e4531406..6301772468737afe993c30cd3a84989ce57e7258 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -32,13 +32,14 @@ class WindowedDStream[T: ClassTag]( extends DStream[T](parent.ssc) { if (!_windowDuration.isMultipleOf(parent.slideDuration)) - throw new Exception("The window duration of WindowedDStream (" + _slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") + throw new Exception("The window duration of windowed DStream (" + _slideDuration + ") " + + "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") if (!_slideDuration.isMultipleOf(parent.slideDuration)) - throw new Exception("The slide duration of WindowedDStream (" + 
_slideDuration + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") + throw new Exception("The slide duration of windowed DStream (" + _slideDuration + ") " + + "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") + // Persist parent level by default, as those RDDs are going to be obviously reused. parent.persist(StorageLevel.MEMORY_ONLY_SER) def windowDuration: Duration = _windowDuration @@ -49,6 +50,14 @@ class WindowedDStream[T: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration + override def persist(level: StorageLevel): DStream[T] = { + // Do not let this windowed DStream be persisted as windowed (union-ed) RDDs share underlying + // RDDs and persisting the windowed RDDs would store numerous copies of the underlying data. + // Instead control the persistence of the parent DStream. + parent.persist(level) + this + } + override def compute(validTime: Time): Option[RDD[T]] = { val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) val rddsInWindow = parent.slice(currentWindow) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index 7341bfbc99399b94a1143e12752c1120bf3fbdb3..7e0f6b2cdfc084446eb238c8d5d79e8617beea46 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.Time +import scala.util.Try /** * Class representing a Spark computation. It may contain multiple Spark jobs. 
@@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time private[streaming] class Job(val time: Time, func: () => _) { var id: String = _ + var result: Try[_] = null - def run(): Long = { - val startTime = System.currentTimeMillis - func() - val stopTime = System.currentTimeMillis - (stopTime - startTime) + def run() { + result = Try(func()) } def setId(number: Int) { @@ -38,4 +37,4 @@ class Job(val time: Time, func: () => _) { } override def toString = id -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 5f8be93a9851846aa14063d8a0fe6678c5cf067a..b5f11d344068d828330377bfa37c251e2f208736 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,17 +17,18 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{Props, Actor} -import org.apache.spark.SparkEnv -import org.apache.spark.Logging +import akka.actor.{ActorRef, ActorSystem, Props, Actor} +import org.apache.spark.{SparkException, SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} +import scala.util.{Failure, Success, Try} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent -private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** * This class generates jobs from DStreams as well as drives checkpointing and cleaning @@ -36,29 +37,38 @@ private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent private[streaming] class JobGenerator(jobScheduler: JobScheduler) extends Logging { - val ssc = jobScheduler.ssc - val graph = ssc.graph - val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) - } - })) + private val ssc = jobScheduler.ssc + private val graph = ssc.graph val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventProcessorActor ! GenerateJobs(new Time(longTime))) - lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => eventActor ! GenerateJobs(new Time(longTime))) + private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } + // eventActor is created when generator starts. 
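Job.run in the hunk above no longer returns the elapsed time; it records the outcome of the job function in a scala.util.Try, which the scheduler later matches as Success or Failure. A standalone illustration of that idiom; runCaptured and its arguments are illustrative names, not part of the scheduler.

  import scala.util.{Failure, Success, Try}

  // Capture a computation's outcome instead of letting the exception escape immediately
  def runCaptured(work: () => Unit): Try[Unit] = Try(work())

  runCaptured(() => println("batch work")) match {
    case Success(_) => println("job completed")
    case Failure(e) => println("job failed: " + e.getMessage)
  }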
+ // This not being null means the scheduler has been started and not stopped + private var eventActor: ActorRef = null + + /** Start generation of jobs */ def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobGenerator already started") + } + + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobGeneratorEvent => + logDebug("Got event of type " + event.getClass.getName) + processEvent(event) + } + }), "JobGenerator") if (ssc.isCheckpointPresent) { restart() } else { @@ -66,26 +76,35 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - def stop() { - timer.stop() - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** Stop generation of jobs */ + def stop() = synchronized { + if (eventActor != null) { + timer.stop() + ssc.env.actorSystem.stop(eventActor) + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() + logInfo("JobGenerator stopped") + } } /** * On batch completion, clear old metadata and checkpoint computation. */ - private[scheduler] def onBatchCompletion(time: Time) { - eventProcessorActor ! ClearOldMetadata(time) + def onBatchCompletion(time: Time) { + eventActor ! ClearMetadata(time) + } + + def onCheckpointCompletion(time: Time) { + eventActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { event match { case GenerateJobs(time) => generateJobs(time) - case ClearOldMetadata(time) => clearOldMetadata(time) + case ClearMetadata(time) => clearMetadata(time) case DoCheckpoint(time) => doCheckpoint(time) + case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -104,7 +123,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // or if the property is defined set it to that time if (clock.isInstanceOf[ManualClock]) { val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = ssc.sc.conf.get("spark.streaming.manualClock.jump", "0").toLong + val jumpTime = ssc.sc.conf.getLong("spark.streaming.manualClock.jump", 0) clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) } @@ -115,14 +134,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val checkpointTime = ssc.initialCheckpoint.checkpointTime val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time: " + downTimes.mkString(", ")) + logInfo("Batches during down time (" + downTimes.size + " batches): " + + downTimes.mkString(", ")) // Batches that were unprocessed before failure - val pendingTimes = ssc.initialCheckpoint.pendingTimes - logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) + logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + + pendingTimes.mkString(", ")) // Reschedule jobs for these times val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => jobScheduler.runJobs(time, graph.generateJobs(time)) ) @@ -135,15 +157,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the 
given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) - logInfo("\n-----------------------------------------------------\n") - jobScheduler.runJobs(time, graph.generateJobs(time)) - eventProcessorActor ! DoCheckpoint(time) + Try(graph.generateJobs(time)) match { + case Success(jobs) => jobScheduler.runJobs(time, jobs) + case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + } + eventActor ! DoCheckpoint(time) } /** Clear DStream metadata for the given `time`. */ - private def clearOldMetadata(time: Time) { - ssc.graph.clearOldMetadata(time) - eventProcessorActor ! DoCheckpoint(time) + private def clearMetadata(time: Time) { + ssc.graph.clearMetadata(time) + eventActor ! DoCheckpoint(time) + } + + /** Clear DStream checkpoint data for the given `time`. */ + private def clearCheckpointData(time: Time) { + ssc.graph.clearCheckpointData(time) } /** Perform checkpoint for the give `time`. */ @@ -155,4 +184,3 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 9304fc1a9338d6b7336a607ccbd09346fcc6ce9b..de675d3c7fb94b4983e1583ea782c4a29b8dce73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,36 +17,68 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import scala.util.{Failure, Success, Try} +import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import scala.collection.mutable.HashSet +import akka.actor.{ActorRef, Actor, Props} +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.streaming._ + +private[scheduler] sealed trait JobSchedulerEvent +private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent +private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent +private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent + /** * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate - * the jobs and runs them using a thread pool. Number of threads + * the jobs and runs them using a thread pool. */ private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - val jobSets = new ConcurrentHashMap[Time, JobSet] - val numConcurrentJobs = ssc.conf.get("spark.streaming.concurrentJobs", "1").toInt - val executor = Executors.newFixedThreadPool(numConcurrentJobs) - val generator = new JobGenerator(this) + private val jobSets = new ConcurrentHashMap[Time, JobSet] + private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) + private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobGenerator = new JobGenerator(this) + val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() - def clock = generator.clock + // These two are created only when scheduler starts. 
+ // eventActor not being null means the scheduler has been started and not stopped + var networkInputTracker: NetworkInputTracker = null + private var eventActor: ActorRef = null + + + def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobScheduler already started") + } - def start() { - generator.start() + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobSchedulerEvent => processEvent(event) + } + }), "JobScheduler") + listenerBus.start() + networkInputTracker = new NetworkInputTracker(ssc) + networkInputTracker.start() + Thread.sleep(1000) + jobGenerator.start() + logInfo("JobScheduler started") } - def stop() { - generator.stop() - executor.shutdown() - if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { - executor.shutdownNow() + def stop() = synchronized { + if (eventActor != null) { + jobGenerator.stop() + networkInputTracker.stop() + executor.shutdown() + if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + logInfo("JobScheduler stopped") } } @@ -61,46 +93,67 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } - def getPendingTimes(): Array[Time] = { - jobSets.keySet.toArray(new Array[Time](0)) + def getPendingTimes(): Seq[Time] = { + jobSets.keySet.toSeq + } + + def reportError(msg: String, e: Throwable) { + eventActor ! ErrorReported(msg, e) } - private def beforeJobStart(job: Job) { + private def processEvent(event: JobSchedulerEvent) { + try { + event match { + case JobStarted(job) => handleJobStart(job) + case JobCompleted(job) => handleJobCompletion(job) + case ErrorReported(m, e) => handleError(m, e) + } + } catch { + case e: Throwable => + reportError("Error in job scheduler", e) + } + } + + private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) if (!jobSet.hasStarted) { - listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo())) + listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.beforeJobStart(job) + jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) - SparkEnv.set(generator.ssc.env) + SparkEnv.set(ssc.env) } - private def afterJobEnd(job: Job) { - val jobSet = jobSets.get(job.time) - jobSet.afterJobStop(job) - logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) - if (jobSet.hasCompleted) { - jobSets.remove(jobSet.time) - generator.onBatchCompletion(jobSet.time) - logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( - jobSet.totalDelay / 1000.0, jobSet.time.toString, - jobSet.processingDelay / 1000.0 - )) - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) + private def handleJobCompletion(job: Job) { + job.result match { + case Success(_) => + val jobSet = jobSets.get(job.time) + jobSet.handleJobCompletion(job) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + jobGenerator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) + } + case Failure(e) => + reportError("Error running job " + job, e) } } - private[streaming] - class JobHandler(job: Job) extends Runnable { + private def handleError(msg: String, e: 
Throwable) { + logError(msg, e) + ssc.waiter.notifyError(e) + } + + private class JobHandler(job: Job) extends Runnable { def run() { - beforeJobStart(job) - try { - job.run() - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - afterJobEnd(job) + eventActor ! JobStarted(job) + job.run() + eventActor ! JobCompleted(job) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 57268674ead9dd22a9c77a941f0195544400999c..fcf303aee6cd73e34b3af85162c65cd36b63e349 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.streaming.Time /** Class representing a set of Jobs @@ -27,25 +27,25 @@ private[streaming] case class JobSet(time: Time, jobs: Seq[Job]) { private val incompleteJobs = new HashSet[Job]() - var submissionTime = System.currentTimeMillis() // when this jobset was submitted - var processingStartTime = -1L // when the first job of this jobset started processing - var processingEndTime = -1L // when the last job of this jobset finished processing + private val submissionTime = System.currentTimeMillis() // when this jobset was submitted + private var processingStartTime = -1L // when the first job of this jobset started processing + private var processingEndTime = -1L // when the last job of this jobset finished processing jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } incompleteJobs ++= jobs - def beforeJobStart(job: Job) { + def handleJobStart(job: Job) { if (processingStartTime < 0) processingStartTime = System.currentTimeMillis() } - def afterJobStop(job: Job) { + def handleJobCompletion(job: Job) { incompleteJobs -= job if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted() = (processingStartTime > 0) + def hasStarted = processingStartTime > 0 - def hasCompleted() = incompleteJobs.isEmpty + def hasCompleted = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. 
not including the time they wait in the streaming scheduler queue) @@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) { processingEndTime - time.milliseconds } - def toBatchInfo(): BatchInfo = { + def toBatchInfo: BatchInfo = { new BatchInfo( time, submissionTime, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 75f724464348c107f32498d1457b78c9b8f97c61..0d9733fa69a12a67f386fa736f379aebfe06619c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -19,8 +19,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.SparkContext._ import scala.collection.mutable.HashMap @@ -32,6 +31,7 @@ import akka.pattern.ask import akka.dispatch._ import org.apache.spark.storage.BlockId import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage @@ -39,33 +39,47 @@ private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], m private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage /** - * This class manages the execution of the receivers of NetworkInputDStreams. + * This class manages the execution of the receivers of NetworkInputDStreams. Instance of + * this class must be created after all input streams have been added and StreamingContext.start() + * has been called because it needs the final set of input streams at the time of instantiation. */ private[streaming] -class NetworkInputTracker( - @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) - extends Logging { +class NetworkInputTracker(ssc: StreamingContext) extends Logging { + val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[BlockId]] - val timeout = 5000.milliseconds + val timeout = AkkaUtils.askTimeout(ssc.conf) + + // actor is created when generator starts. + // This not being null means the tracker has been started and not stopped + var actor: ActorRef = null var currentTime: Time = null /** Start the actor and receiver execution thread. */ def start() { - ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") - receiverExecutor.start() + if (actor != null) { + throw new SparkException("NetworkInputTracker already started") + } + + if (!networkInputStreams.isEmpty) { + actor = ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + logInfo("NetworkInputTracker started") + } } /** Stop the receiver execution thread. 
*/ def stop() { - // TODO: stop the actor as well - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + if (!networkInputStreams.isEmpty && actor != null) { + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + ssc.env.actorSystem.stop(actor) + logInfo("NetworkInputTracker stopped") + } } /** Return all the blocks received from a receiver. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 36225e190cd7917502f23debf6a7b9c77b14743e..461ea3506477f8fcfcb6e6cadb1141bf3bf15473 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -24,9 +24,10 @@ import org.apache.spark.util.Distribution sealed trait StreamingListenerEvent case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent - case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +/** An event used in the listener to shutdown the listener daemon thread. */ +private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent /** * A listener interface for receiving information about an ongoing streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 110a20f282f110879ad7836399f2f9e3784a1ac1..3063cf10a39f32fa1f9545c1ab2cdbb350db2529 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -31,7 +31,7 @@ private[spark] class StreamingListenerBus() extends Logging { private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) private var queueFullErrorMessageLogged = false - new Thread("StreamingListenerBus") { + val listenerThread = new Thread("StreamingListenerBus") { setDaemon(true) override def run() { while (true) { @@ -41,11 +41,18 @@ private[spark] class StreamingListenerBus() extends Logging { listeners.foreach(_.onBatchStarted(batchStarted)) case batchCompleted: StreamingListenerBatchCompleted => listeners.foreach(_.onBatchCompleted(batchCompleted)) + case StreamingListenerShutdown => + // Get out of the while loop and shutdown the daemon thread + return case _ => } } } - }.start() + } + + def start() { + listenerThread.start() + } def addListener(listener: StreamingListener) { listeners += listener @@ -54,9 +61,9 @@ private[spark] class StreamingListenerBus() extends Logging { def post(event: StreamingListenerEvent) { val eventAdded = eventQueue.offer(event) if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") + logError("Dropping StreamingListenerEvent because no remaining room in event queue. 
" + + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + + "rate at which events are being started by the scheduler.") queueFullErrorMessageLogged = true } } @@ -68,7 +75,7 @@ private[spark] class StreamingListenerBus() extends Logging { */ def waitUntilEmpty(timeoutMillis: Int): Boolean = { val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty()) { + while (!eventQueue.isEmpty) { if (System.currentTimeMillis > finishTime) { return false } @@ -76,6 +83,8 @@ private[spark] class StreamingListenerBus() extends Logging { * add overhead in the general case. */ Thread.sleep(10) } - return true + true } + + def stop(): Unit = post(StreamingListenerShutdown) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index f67bb2f6ac51a5bf90fb50ba10ac296ed6e7285f..c3a849d2769a72c62230ff6a8bdda4fa32384193 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -66,7 +66,7 @@ class SystemClock() extends Clock { } Thread.sleep(sleepTime) } - return -1 + -1 } } @@ -96,6 +96,6 @@ class ManualClock() extends Clock { this.wait(100) } } - return currentTime() + currentTime() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala index 1559f7a9f7ac00a917cc742bbc6f9287270e9e20..be67af3a6466a6a65a875576a3fbfe93b820d7a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.util import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.ForEachDStream +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} import StreamingContext._ import scala.util.Random @@ -42,6 +42,7 @@ object MasterFailureTest extends Logging { @volatile var killed = false @volatile var killCount = 0 + @volatile var setupCalled = false def main(args: Array[String]) { if (args.size < 2) { @@ -131,8 +132,26 @@ object MasterFailureTest extends Logging { // Just making sure that the expected output does not have duplicates assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + // Setup the stream computation with the given operation - val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => { + setupStreams(batchDuration, operation, checkpointDir, testDir) + }) + + // Check if setupStream was called to create StreamingContext + // (and not created from checkpoint file) + assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate") // Start generating files in the a different thread val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) @@ -144,9 
+163,7 @@ object MasterFailureTest extends Logging { val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) - // Delete directories fileGeneratingThread.join() - val fs = checkpointDir.getFileSystem(new Configuration()) fs.delete(checkpointDir, true) fs.delete(testDir, true) logInfo("Finished test after " + killCount + " failures") @@ -159,32 +176,23 @@ object MasterFailureTest extends Logging { * files should be written for testing. */ private def setupStreams[T: ClassTag]( - directory: String, batchDuration: Duration, - operation: DStream[String] => DStream[T] - ): (StreamingContext, Path, Path) = { - // Reset all state - reset() - - // Create the directories for this test - val uuid = UUID.randomUUID().toString - val rootDir = new Path(directory, uuid) - val fs = rootDir.getFileSystem(new Configuration()) - val checkpointDir = new Path(rootDir, "checkpoint") - val testDir = new Path(rootDir, "test") - fs.mkdirs(checkpointDir) - fs.mkdirs(testDir) + operation: DStream[String] => DStream[T], + checkpointDir: Path, + testDir: Path + ): StreamingContext = { + // Mark that setup was called + setupCalled = true // Setup the streaming computation with the given operation System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") - var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) + val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) ssc.checkpoint(checkpointDir.toString) val inputStream = ssc.textFileStream(testDir.toString) val operatedStream = operation(inputStream) val outputStream = new TestOutputStream(operatedStream) ssc.registerOutputStream(outputStream) - (ssc, checkpointDir, testDir) + ssc } @@ -204,7 +212,7 @@ object MasterFailureTest extends Logging { var isTimedOut = false val mergedOutput = new ArrayBuffer[T]() val checkpointDir = ssc.checkpointDir - var batchDuration = ssc.graph.batchDuration + val batchDuration = ssc.graph.batchDuration while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer @@ -224,7 +232,6 @@ object MasterFailureTest extends Logging { // (iii) Its not timed out yet System.clearProperty("spark.streaming.clock") System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") ssc.start() val startTime = System.currentTimeMillis() while (!killed && !isLastOutputGenerated && !isTimedOut) { @@ -261,7 +268,10 @@ object MasterFailureTest extends Logging { ) Thread.sleep(sleepTime) // Recreate the streaming context from checkpoint - ssc = new StreamingContext(checkpointDir) + ssc = StreamingContext.getOrCreate(checkpointDir, () => { + throw new Exception("Trying to create new context when it " + + "should be reading from checkpoint file") + }) } } mergedOutput @@ -297,6 +307,7 @@ object MasterFailureTest extends Logging { private def reset() { killed = false killCount = 0 + setupCalled = false } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4e6ce6eabd7ba2fd3753b9f1b3ea72a17e6616a9..5b6c048a396200eb054c1f2408460c01ae8a39fc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -90,7 +90,7 @@ object RawTextHelper { } } } - return taken.toIterator + taken.toIterator } /** diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index d644240405caa478f6b838473a8d7f7475615942..559c2473851b30bf73f9d376126f846edf826c6d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -20,17 +20,7 @@ package org.apache.spark.streaming.util private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - private val minPollTime = 25L - - private val pollTime = { - if (period / 10.0 > minPollTime) { - (period / 10.0).toLong - } else { - minPollTime - } - } - - private val thread = new Thread() { + private val thread = new Thread("RecurringTimer") { override def run() { loop } } @@ -66,7 +56,6 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => callback(nextTime) nextTime += period } - } catch { case e: InterruptedException => } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 0d2145da9a1a687794ba79e77cea01b8ed030b9e..8b7d7709bf2c57f750267ebc3acbdf15e7813766 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -28,6 +28,7 @@ import java.util.*; import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.io.Files; +import com.google.common.collect.Sets; import org.apache.spark.SparkConf; import org.apache.spark.HashPartitioner; @@ -441,13 +442,13 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa new Tuple2<String, String>("new york", "islanders"))); - List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList( - Arrays.asList( + List<HashSet<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList( + Sets.newHashSet( new Tuple2<String, Tuple2<String, String>>("california", new Tuple2<String, String>("dodgers", "giants")), new Tuple2<String, Tuple2<String, String>>("new york", - new Tuple2<String, String>("yankees", "mets"))), - Arrays.asList( + new Tuple2<String, String>("yankees", "mets"))), + Sets.newHashSet( new Tuple2<String, Tuple2<String, String>>("california", new Tuple2<String, String>("sharks", "ducks")), new Tuple2<String, Tuple2<String, String>>("new york", @@ -482,8 +483,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa JavaTestUtils.attachTestOutputStream(joined); List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List<HashSet<Tuple2<String, Tuple2<String, String>>>> unorderedResult = Lists.newArrayList(); + for (List<Tuple2<String, Tuple2<String, String>>> res: result) { + unorderedResult.add(Sets.newHashSet(res)); + } - Assert.assertEquals(expected, result); + Assert.assertEquals(expected, unorderedResult); } @@ -1196,15 +1201,15 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa Arrays.asList("hello", "moon"), Arrays.asList("hello")); - List<List<Tuple2<String, Long>>> expected = Arrays.asList( - Arrays.asList( + List<HashSet<Tuple2<String, Long>>> expected = Arrays.asList( + Sets.newHashSet( new Tuple2<String, Long>("hello", 1L), new Tuple2<String, Long>("world", 1L)), - Arrays.asList( + Sets.newHashSet( new Tuple2<String, 
Long>("hello", 2L), new Tuple2<String, Long>("world", 1L), new Tuple2<String, Long>("moon", 1L)), - Arrays.asList( + Sets.newHashSet( new Tuple2<String, Long>("hello", 2L), new Tuple2<String, Long>("moon", 1L))); @@ -1214,8 +1219,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + List<HashSet<Tuple2<String, Long>>> unorderedResult = Lists.newArrayList(); + for (List<Tuple2<String, Long>> res: result) { + unorderedResult.add(Sets.newHashSet(res)); + } - Assert.assertEquals(expected, result); + Assert.assertEquals(expected, unorderedResult); } @Test diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java index 34bee568859f9736f66e2428305f3b07e1fa5457..849bbf1299182cdb38979c4fd60ba2432f2129af 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java +++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -28,7 +28,6 @@ public abstract class LocalJavaStreamingContext { @Before public void setUp() { System.clearProperty("spark.driver.port"); - System.clearProperty("spark.hostPort"); System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); ssc.checkpoint("checkpoint"); @@ -41,6 +40,5 @@ public abstract class LocalJavaStreamingContext { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port"); - System.clearProperty("spark.hostPort"); } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index ee6b433d1f1fa2d05a33991cad2d7cdf0b81a7c5..7037aae234208a3bc387829287513dabc4a9f6c9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkContext._ import util.ManualClock import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.streaming.dstream.DStream class BasicOperationsSuite extends TestSuiteBase { test("map") { @@ -375,15 +376,11 @@ class BasicOperationsSuite extends TestSuiteBase { } test("slice") { - val conf2 = new SparkConf() - .setMaster("local[2]") - .setAppName("BasicOperationsSuite") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1)) + val ssc = new StreamingContext(conf, Seconds(1)) val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) val stream = new TestInputStream[Int](ssc, input, 2) ssc.registerInputStream(stream) - stream.foreach(_ => {}) // Dummy output stream + stream.foreachRDD(_ => {}) // Dummy output stream ssc.start() Thread.sleep(2000) def getInputFromSlice(fromMillis: Long, toMillis: Long) = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8dc80ac2edfa060f8f572adfb705ad1ad63e7139..0c68c44ddb6da9960b3869b941b9de650134fbc0 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -26,8 +26,10 @@ import com.google.common.io.Files import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.hadoop.conf.Configuration import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.streaming.dstream.FileInputDStream +import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf /** * This test suites tests the checkpointing functionality of DStreams - @@ -84,9 +86,9 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() advanceTimeWithRealDelay(ssc, firstNumBatches) logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -95,7 +97,7 @@ class CheckpointSuite extends TestSuiteBase { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) + val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2)) advanceTimeWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) @@ -114,9 +116,9 @@ class CheckpointSuite extends TestSuiteBase { // is present in the checkpoint data or not ssc.start() advanceTimeWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") @@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase { ssc = null } + // This tests whether spark conf persists through checkpoints, and certain + // configs gets scrubbed + test("persistence of conf through checkpoints") { + val key = "spark.mykey" + val value = "myvalue" + System.setProperty(key, value) + ssc = new StreamingContext(master, framework, batchDuration) + val cp = new Checkpoint(ssc, Time(1000)) + assert(!cp.sparkConf.contains("spark.driver.host")) + assert(!cp.sparkConf.contains("spark.driver.port")) + assert(!cp.sparkConf.contains("spark.hostPort")) + assert(cp.sparkConf.get(key) === value) + ssc.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(!newCp.sparkConf.contains("spark.driver.host")) + assert(!newCp.sparkConf.contains("spark.driver.port")) + assert(!newCp.sparkConf.contains("spark.hostPort")) + assert(newCp.sparkConf.get(key) === value) + } + // This tests whether the systm can recover from a master failure with simple // 
non-stateful operations. This assumes as reliable, replayable input @@ -336,7 +358,6 @@ class CheckpointSuite extends TestSuiteBase { ) ssc = new StreamingContext(checkpointDir) System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") ssc.start() val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) // the first element will be re-processed data of the last batch before restart diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..f7f3346f81db58607903a92527816f55ae48a8fe --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkException, SparkConf, SparkContext} +import org.apache.spark.util.{Utils, MetadataCleaner} +import org.apache.spark.streaming.dstream.DStream + +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { + + val master = "local[2]" + val appName = this.getClass.getSimpleName + val batchDuration = Seconds(1) + val sparkHome = "someDir" + val envPair = "key" -> "value" + val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 + + var sc: SparkContext = null + var ssc: StreamingContext = null + + before { + System.clearProperty("spark.cleaner.ttl") + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + if (sc != null) { + sc.stop() + sc = null + } + } + + test("from no conf constructor") { + ssc = new StreamingContext(master, appName, batchDuration) + assert(ssc.sparkContext.conf.get("spark.master") === master) + assert(ssc.sparkContext.conf.get("spark.app.name") === appName) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home") { + ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil) + assert(ssc.conf.get("spark.home") === sparkHome) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home + env") { + ssc = new StreamingContext(master, appName, batchDuration, + sparkHome, Nil, Map(envPair)) + assert(ssc.conf.getExecutorEnv.exists(_ == envPair)) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf without ttl set") 
{ + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + ssc = new StreamingContext(myConf, batchDuration) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from existing SparkContext without ttl set") { + sc = new SparkContext(master, appName) + val exception = intercept[SparkException] { + ssc = new StreamingContext(sc, batchDuration) + } + assert(exception.getMessage.contains("ttl")) + } + + test("from existing SparkContext with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from checkpoint") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + val ssc1 = new StreamingContext(myConf, batchDuration) + val cp = new Checkpoint(ssc1, Time(1000)) + assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) + ssc1.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) + ssc = new StreamingContext(null, cp, null) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) + } + + test("start multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + + ssc.start() + intercept[SparkException] { + ssc.start() + } + } + + test("stop multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() + ssc.stop() + ssc = null + } + + test("stop only streaming context") { + ssc = new StreamingContext(master, appName, batchDuration) + sc = ssc.sparkContext + addInputStream(ssc).register + ssc.start() + ssc.stop(false) + ssc = null + assert(sc.makeRDD(1 to 100).collect().size === 100) + ssc = new StreamingContext(sc, batchDuration) + } + + test("awaitTermination") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register + + // test whether start() blocks indefinitely or not + failAfter(2000 millis) { + ssc.start() + } + + // test whether waitForStop() exits after give amount of time + failAfter(1000 millis) { + ssc.awaitTermination(500) + } + + // test whether waitForStop() does not exit if not time is given + val exception = intercept[Exception] { + failAfter(1000 millis) { + ssc.awaitTermination() + throw new Exception("Did not wait for stop") + } + } + assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + + // test whether wait exits if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + new Thread() { + override def run { + Thread.sleep(500) + ssc.stop() + } + }.start() + ssc.awaitTermination() + } + } + + test("awaitTermination with error in task") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => { throw new TestException("error in map task"); x}) + .foreachRDD(_.count) + + val exception = 
intercept[Exception] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("map task"), "Expected exception not thrown") + } + + test("awaitTermination with error in job generation") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + + inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register + val exception = intercept[TestException] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("transform"), "Expected exception not thrown") + } + + def addInputStream(s: StreamingContext): DStream[Int] = { + val input = (1 to 100).map(i => (1 to i)) + val inputStream = new TestInputStream(s, input, 1) + s.registerInputStream(inputStream) + inputStream + } +} + +class TestException(msg: String) extends Exception(msg) \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index fa6414209605405e2a70834409bb3851e10b6422..9e0f2c900e8ba699c802d202839b0edeb2525ec4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming import org.apache.spark.streaming.scheduler._ import scala.collection.mutable.ArrayBuffer import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.streaming.dstream.DStream class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index b20d02f99681e87879c31a59778a9fc24388b6ab..535e5bd1f1f2e7f8786c815a2f933274e88de2f1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.{InputDStream, ForEachDStream} +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.util.ManualClock import scala.collection.mutable.ArrayBuffer @@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val conf = new SparkConf() .setMaster(master) .setAppName(framework) - .set("spark.cleaner.ttl", "3600") + .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString) // Default before function for any streaming test suite. 
Override this // if you want to add your stuff to "before" (i.e., don't call before { } ) @@ -156,7 +156,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def afterFunction() { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") - System.clearProperty("spark.hostPort") } before(beforeFunction) @@ -273,10 +272,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val startTime = System.currentTimeMillis() while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - Thread.sleep(10) + ssc.awaitTermination(50) } val timeTaken = System.currentTimeMillis() - startTime - + logInfo("Output generated in " + timeTaken + " milliseconds") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index c39abfc21b3ba219a08c6bf6e514db052dbcd7f0..471c99fab4682513c6216e5e406e093597e766bc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.streaming import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.storage.StorageLevel class WindowOperationsSuite extends TestSuiteBase { @@ -143,6 +145,19 @@ class WindowOperationsSuite extends TestSuiteBase { Seconds(3) ) + test("window - persistence level") { + val input = Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)) + val ssc = new StreamingContext(conf, batchDuration) + val inputStream = new TestInputStream[Int](ssc, input, 1) + val windowStream1 = inputStream.window(batchDuration * 2) + assert(windowStream1.storageLevel === StorageLevel.NONE) + assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY_SER) + windowStream1.persist(StorageLevel.MEMORY_ONLY) + assert(windowStream1.storageLevel === StorageLevel.NONE) + assert(inputStream.storageLevel === StorageLevel.MEMORY_ONLY) + ssc.stop() + } + // Testing naive reduceByKeyAndWindow (without invertible function) testReduceByKeyAndWindow( diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index f670f65bf5b38ca6c56703cc955c2f2b9236a751..4886cd6ea8a64a1269dfd1f6fdb229d074b6f583 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -24,8 +24,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.api.java._ import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions} -import org.apache.spark.streaming.{PairDStreamFunctions, DStream, StreamingContext} +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{DStream, 
PairDStreamFunctions} private[spark] abstract class SparkType(val name: String) @@ -147,7 +148,7 @@ object JavaAPICompletenessChecker { } else { ParameterizedType(classOf[JavaRDD[_]].getName, parameters.map(applySubs)) } - case "org.apache.spark.streaming.DStream" => + case "org.apache.spark.streaming.dstream.DStream" => if (parameters(0).name == classOf[Tuple2[_, _]].getName) { val tupleParams = parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs) @@ -248,30 +249,29 @@ object JavaAPICompletenessChecker { "org.apache.spark.SparkContext.getSparkHome", "org.apache.spark.SparkContext.executorMemoryRequested", "org.apache.spark.SparkContext.getExecutorStorageStatus", - "org.apache.spark.streaming.DStream.generatedRDDs", - "org.apache.spark.streaming.DStream.zeroTime", - "org.apache.spark.streaming.DStream.rememberDuration", - "org.apache.spark.streaming.DStream.storageLevel", - "org.apache.spark.streaming.DStream.mustCheckpoint", - "org.apache.spark.streaming.DStream.checkpointDuration", - "org.apache.spark.streaming.DStream.checkpointData", - "org.apache.spark.streaming.DStream.graph", - "org.apache.spark.streaming.DStream.isInitialized", - "org.apache.spark.streaming.DStream.parentRememberDuration", - "org.apache.spark.streaming.DStream.initialize", - "org.apache.spark.streaming.DStream.validate", - "org.apache.spark.streaming.DStream.setContext", - "org.apache.spark.streaming.DStream.setGraph", - "org.apache.spark.streaming.DStream.remember", - "org.apache.spark.streaming.DStream.getOrCompute", - "org.apache.spark.streaming.DStream.generateJob", - "org.apache.spark.streaming.DStream.clearOldMetadata", - "org.apache.spark.streaming.DStream.addMetadata", - "org.apache.spark.streaming.DStream.updateCheckpointData", - "org.apache.spark.streaming.DStream.restoreCheckpointData", - "org.apache.spark.streaming.DStream.isTimeValid", + "org.apache.spark.streaming.dstream.DStream.generatedRDDs", + "org.apache.spark.streaming.dstream.DStream.zeroTime", + "org.apache.spark.streaming.dstream.DStream.rememberDuration", + "org.apache.spark.streaming.dstream.DStream.storageLevel", + "org.apache.spark.streaming.dstream.DStream.mustCheckpoint", + "org.apache.spark.streaming.dstream.DStream.checkpointDuration", + "org.apache.spark.streaming.dstream.DStream.checkpointData", + "org.apache.spark.streaming.dstream.DStream.graph", + "org.apache.spark.streaming.dstream.DStream.isInitialized", + "org.apache.spark.streaming.dstream.DStream.parentRememberDuration", + "org.apache.spark.streaming.dstream.DStream.initialize", + "org.apache.spark.streaming.dstream.DStream.validate", + "org.apache.spark.streaming.dstream.DStream.setContext", + "org.apache.spark.streaming.dstream.DStream.setGraph", + "org.apache.spark.streaming.dstream.DStream.remember", + "org.apache.spark.streaming.dstream.DStream.getOrCompute", + "org.apache.spark.streaming.dstream.DStream.generateJob", + "org.apache.spark.streaming.dstream.DStream.clearOldMetadata", + "org.apache.spark.streaming.dstream.DStream.addMetadata", + "org.apache.spark.streaming.dstream.DStream.updateCheckpointData", + "org.apache.spark.streaming.dstream.DStream.restoreCheckpointData", + "org.apache.spark.streaming.dstream.DStream.isTimeValid", "org.apache.spark.streaming.StreamingContext.nextNetworkInputStreamId", - "org.apache.spark.streaming.StreamingContext.networkInputTracker", "org.apache.spark.streaming.StreamingContext.checkpointDir", "org.apache.spark.streaming.StreamingContext.checkpointDuration", 
"org.apache.spark.streaming.StreamingContext.receiverJobThread", diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2bb11e54c549af037fc75a2dd92300e3588828b4..2e46d750c4a3801f5d4cc381dcd2f8dd759561a3 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -127,14 +127,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // local dirs, so lets check both. We assume one of the 2 is set. // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l } - localDirs - } + } private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 23781ea35c670803b4625bff4087866bdaf269ea..e56bc02897216ae350e4b299f55d055b428c253d 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -158,7 +158,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) appContext.setApplicationId(appId) appContext.setApplicationName(args.appName) - return appContext + appContext } /** See if two file systems are the same or not. */ @@ -193,7 +193,8 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) if (srcUri.getPort() != dstUri.getPort()) { return false } - return true + + true } /** Copy the file into HDFS if needed. */ @@ -299,7 +300,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) } UserGroupInformation.getCurrentUser().addCredentials(credentials) - return localResources + localResources } def setupLaunchEnv( diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index ddfec1a4ac6728e8dd0030a6b5fcf164a1dff42c..9fe4d64a0fca098fbaa4f35b53e438146e22e8a9 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -76,6 +76,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar def run() { + // Setup the directories so things go to yarn approved directories rather + // then user specified and /tmp. + System.setProperty("spark.local.dir", getLocalDirs()) + appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() @@ -103,10 +107,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. 
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // must be <= timeoutInterval/ 2. - // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. - // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. - val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) + reporterThread = launchReporterThread(interval) // Wait for the reporter thread to Finish. @@ -119,13 +125,27 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar System.exit(0) } + /** Get the Yarn approved local directories. */ + private def getLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l + } + } + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) val containerId = ConverterUtils.toContainerId(containerIdString) val appAttemptId = containerId.getApplicationAttemptId() logInfo("ApplicationAttemptId: " + appAttemptId) - return appAttemptId + appAttemptId } private def registerWithResourceManager(): AMRMProtocol = { @@ -133,7 +153,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar YarnConfiguration.RM_SCHEDULER_ADDRESS, YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) logInfo("Connecting to ResourceManager at " + rmAddress) - return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { @@ -147,7 +167,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar appMasterRequest.setRpcPort(0) // What do we provide here ? Might make sense to expose something sensible later ? 
appMasterRequest.setTrackingUrl("") - return resourceManager.registerApplicationMaster(appMasterRequest) + resourceManager.registerApplicationMaster(appMasterRequest) } private def waitForSparkMaster() { @@ -220,7 +240,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar t.setDaemon(true) t.start() logInfo("Started progress reporter thread - sleep time : " + sleepTime) - return t + t } private def sendProgress() { diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index 132630e5ef04cbb09bc400dacef37a1ea2b0d28d..d32cdcc879f7e5c8ef1e0cf9f3d9718de43b39cc 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -195,7 +195,7 @@ class WorkerRunnable( } logInfo("Prepared Local resources " + localResources) - return localResources + localResources } def prepareEnvironment: HashMap[String, String] = { @@ -207,7 +207,7 @@ class WorkerRunnable( Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } - return env + env } def connectToCM: ContainerManager = { @@ -226,8 +226,7 @@ class WorkerRunnable( val proxy = user .doAs(new PrivilegedExceptionAction[ContainerManager] { def run: ContainerManager = { - return rpc.getProxy(classOf[ContainerManager], - cmAddress, conf).asInstanceOf[ContainerManager] + rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager] } }) proxy diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 5f159b073f5372dc9e1f73347f88ae8aa48e0cbb..535abbfb7f638c15b76651ab81f0a802866529e7 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -143,7 +143,7 @@ class ClientDistributedCacheManager() extends Logging { if (isPublic(conf, uri, statCache)) { return LocalResourceVisibility.PUBLIC } - return LocalResourceVisibility.PRIVATE + LocalResourceVisibility.PRIVATE } /** @@ -161,7 +161,7 @@ class ClientDistributedCacheManager() extends Logging { if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { return false } - return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) + ancestorsHaveExecutePermissions(fs, current.getParent(), statCache) } /** @@ -183,7 +183,7 @@ class ClientDistributedCacheManager() extends Logging { } current = current.getParent() } - return true + true } /** @@ -203,7 +203,7 @@ class ClientDistributedCacheManager() extends Logging { if (otherAction.implies(action)) { return true } - return false + false } /** @@ -223,6 +223,6 @@ class ClientDistributedCacheManager() extends Logging { statCache.put(uri, newStat) newStat } - return stat + stat } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 4b1b5da048df4d801dacb24f7df18245e98735ac..22e55e0c60647978d4543d14f78444fdfa0d2e8d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ 
b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,6 +22,8 @@ import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.scheduler.TaskSchedulerImpl +import scala.collection.mutable.ArrayBuffer + private[spark] class YarnClientSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) @@ -31,45 +33,47 @@ private[spark] class YarnClientSchedulerBackend( var client: Client = null var appId: ApplicationId = null + private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) { + Option(System.getenv(optionalParam)) foreach { + optParam => { + arrayBuf += (optionName, optParam) + } + } + } + override def start() { super.start() - val defalutWorkerCores = "2" - val defalutWorkerMemory = "512m" - val defaultWorkerNumber = "1" - val userJar = System.getenv("SPARK_YARN_APP_JAR") - val distFiles = System.getenv("SPARK_YARN_DIST_FILES") - var workerCores = System.getenv("SPARK_WORKER_CORES") - var workerMemory = System.getenv("SPARK_WORKER_MEMORY") - var workerNumber = System.getenv("SPARK_WORKER_INSTANCES") - if (userJar == null) throw new SparkException("env SPARK_YARN_APP_JAR is not set") - if (workerCores == null) - workerCores = defalutWorkerCores - if (workerMemory == null) - workerMemory = defalutWorkerMemory - if (workerNumber == null) - workerNumber = defaultWorkerNumber - val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - val argsArray = Array[String]( + val argsArrayBuf = new ArrayBuffer[String]() + argsArrayBuf += ( "--class", "notused", "--jar", userJar, "--args", hostport, - "--worker-memory", workerMemory, - "--worker-cores", workerCores, - "--num-workers", workerNumber, - "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher", - "--files", distFiles + "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" ) - val args = new ClientArguments(argsArray, conf) + // process any optional arguments, use the defaults already defined in ClientArguments + // if things aren't specified + Map("--master-memory" -> "SPARK_MASTER_MEMORY", + "--num-workers" -> "SPARK_WORKER_INSTANCES", + "--worker-memory" -> "SPARK_WORKER_MEMORY", + "--worker-cores" -> "SPARK_WORKER_CORES", + "--queue" -> "SPARK_YARN_QUEUE", + "--name" -> "SPARK_YARN_APP_NAME", + "--files" -> "SPARK_YARN_DIST_FILES", + "--archives" -> "SPARK_YARN_DIST_ARCHIVES") + .foreach { case (optName, optParam) => addArg(optName, optParam, argsArrayBuf) } + + logDebug("ClientArguments called with: " + argsArrayBuf) + val args = new ClientArguments(argsArrayBuf.toArray, conf) client = new Client(args, conf) appId = client.runApp() waitForApp() diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 2941356bc55f9f85ca176c3ef0d23a1a08c6a8e5..458df4fa3cd9943fd67504b83566dcb1a1776cc7 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -42,7 +42,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { override def getVisibility(conf: 
Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { - return LocalResourceVisibility.PRIVATE + LocalResourceVisibility.PRIVATE } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 69ae14ce8385cfb3b978481d90b2681d02a2ae80..4b777d5fa7a283e78744923dce484b13b1cc1431 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -116,14 +116,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // local dirs, so lets check both. We assume one of the 2 is set. // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l } - localDirs - } + } private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index be323d77835a8892eb3e481eba83f31f7dc3e8b9..51d9adb9d4061679c5871517046a96379393735b 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -99,6 +99,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) appContext.setApplicationName(args.appName) appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(amContainer) + appContext.setApplicationType("SPARK") // Memory for the ApplicationMaster. val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] @@ -207,7 +208,8 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) if (srcUri.getPort() != dstUri.getPort()) { return false } - return true + + true } /** Copy the file into HDFS if needed. */ diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 49248a8516b9cef6941524500965354d9ac30d08..78353224fa4b8b51aea9ba6d56ab04c2e1a479cd 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -78,6 +78,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar def run() { + // Setup the directories so things go to yarn approved directories rather + // then user specified and /tmp. + System.setProperty("spark.local.dir", getLocalDirs()) + amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) amClient.start() @@ -94,10 +98,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // must be <= timeoutInterval/ 2. - // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. 
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. - val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L)) + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) + reporterThread = launchReporterThread(interval) // Wait for the reporter thread to Finish. @@ -110,6 +116,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar System.exit(0) } + /** Get the Yarn approved local directories. */ + private def getLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l + } + } + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
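
The StreamingListenerBus changes above replace an implicitly started consumer thread with an explicitly started one and stop it by posting a StreamingListenerShutdown sentinel instead of interrupting the thread. Below is a minimal standalone sketch of that bounded-queue-plus-sentinel pattern, using hypothetical names (SketchListenerBus, BusEvent, Message, Shutdown) rather than the actual Spark classes; it only illustrates the control flow the patch adopts, not the real listener API.

// Minimal sketch (not Spark code) of a bounded event queue drained by an
// explicitly started daemon thread that exits when it sees a shutdown sentinel.
import java.util.concurrent.LinkedBlockingQueue

sealed trait BusEvent
case class Message(body: String) extends BusEvent
case object Shutdown extends BusEvent

class SketchListenerBus(capacity: Int = 1000) {
  private val queue = new LinkedBlockingQueue[BusEvent](capacity)
  @volatile private var dropLogged = false

  // Held in a val (instead of being started inline) so callers can start it
  // explicitly and join on it; it only exits when it dequeues Shutdown.
  val listenerThread = new Thread("SketchListenerBus") {
    setDaemon(true)
    override def run(): Unit = {
      while (true) {
        queue.take() match {
          case Message(body) => println(s"handling: $body")
          case Shutdown      => return // leave the loop; the daemon thread ends here
        }
      }
    }
  }

  def start(): Unit = listenerThread.start()

  def post(event: BusEvent): Unit = {
    // offer() never blocks; if the queue is full the event is dropped and the
    // condition is logged only once, mirroring the patched error path.
    if (!queue.offer(event) && !dropLogged) {
      System.err.println("Dropping event: queue is full, a listener is too slow")
      dropLogged = true
    }
  }

  // Stopping is just posting the sentinel; events queued ahead of it still drain.
  def stop(): Unit = post(Shutdown)
}

object SketchListenerBusDemo extends App {
  val bus = new SketchListenerBus()
  bus.start()
  (1 to 3).foreach(i => bus.post(Message(s"batch $i completed")))
  bus.stop()
  bus.listenerThread.join(1000)
}

Posting a sentinel rather than interrupting the thread lets any events already in the queue be delivered before the daemon thread exits, which is why stop() in the patched StreamingListenerBus is implemented as a post of StreamingListenerShutdown.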