diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index e6c4a6d3794a03ff2ec2dbc3269abfc5b4f41ac4..c64da8804d1660fe3fbe3a46353f63d98112dc90 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -19,24 +19,30 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.storage.{BlockId, BlockManager}
+import scala.Some
 
 private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
   val index = idx
 }
 
 private[spark]
-class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds: Array[BlockId])
   extends RDD[T](sc, Nil) {
 
   @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+  @volatile private var _isValid = true
 
-  override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
-    new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
-  }).toArray
+  override def getPartitions: Array[Partition] = {
+    assertValid()
+    (0 until blockIds.size).map(i => {
+      new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+    }).toArray
+  }
 
   override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+    assertValid()
     val blockManager = SparkEnv.get.blockManager
     val blockId = split.asInstanceOf[BlockRDDPartition].blockId
     blockManager.get(blockId) match {
@@ -47,7 +53,36 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
+    assertValid()
     locations_(split.asInstanceOf[BlockRDDPartition].blockId)
   }
+
+  /**
+   * Remove the data blocks that this BlockRDD is made from. NOTE: This is an
+   * irreversible operation, as the data in the blocks cannot be recovered back
+   * once removed. Use it with caution.
+   */
+  private[spark] def removeBlocks() {
+    blockIds.foreach { blockId =>
+      sc.env.blockManager.master.removeBlock(blockId)
+    }
+    _isValid = false
+  }
+
+  /**
+   * Whether this BlockRDD is actually usable. This will be false if the data blocks have been
+   * removed using `this.removeBlocks`.
+   */
+  private[spark] def isValid: Boolean = {
+    _isValid
+  }
+
+  /** Check if this BlockRDD is valid. If not valid, exception is thrown. */
+  private[spark] def assertValid() {
+    if (!_isValid) {
+      throw new SparkException(
+        "Attempted to use %s after its blocks have been removed!".format(toString))
+    }
+  }
 }
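
The guard added to BlockRDD boils down to a volatile validity flag that removeBlocks() clears and every entry point checks via assertValid(). For illustration only, here is a minimal standalone sketch of the same pattern in plain Scala; CleanableResource, removeData and read are made-up names, not Spark API, and unlike the real removeBlocks() nothing here talks to the BlockManagerMaster.

// Hypothetical, self-contained sketch of the validity-flag pattern used in BlockRDD.
// None of these names exist in Spark; only the shape of the guard is the same.
class CleanableResource(initial: Seq[String]) {
  private var data: Seq[String] = initial
  @volatile private var _isValid = true

  /** Irreversibly drop the underlying data and mark this resource unusable. */
  def removeData(): Unit = {
    data = Nil
    _isValid = false
  }

  /** False once removeData() has been called. */
  def isValid: Boolean = _isValid

  /** Fail fast if the resource is used after its data has been removed. */
  private def assertValid(): Unit = {
    if (!_isValid) {
      throw new IllegalStateException(
        s"Attempted to use $this after its data has been removed!")
    }
  }

  def read(): Seq[String] = {
    assertValid()
    data
  }
}

object CleanableResourceDemo extends App {
  val r = new CleanableResource(Seq("a", "b", "c"))
  println(r.read())     // List(a, b, c)
  r.removeData()
  println(r.isValid)    // false
  // r.read() would now throw IllegalStateException, just as a cleaned-up
  // BlockRDD throws SparkException from assertValid().
}
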
diff --git a/docs/configuration.md b/docs/configuration.md
index e7e1dd56cf124f3b1bd2d33dd2b6a3d9c61c7af0..8d3442625b4753df29fd90031f260a7fbf8865e2 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -469,10 +469,13 @@ Apart from these, the following properties are also available, and may be useful
 </tr>
 <tr>
   <td>spark.streaming.unpersist</td>
-  <td>false</td>
+  <td>true</td>
   <td>
     Force RDDs generated and persisted by Spark Streaming to be automatically unpersisted from
-    Spark's memory. Setting this to true is likely to reduce Spark's RDD memory usage.
+    Spark's memory. The raw input data received by Spark Streaming is also automatically cleared.
+    Setting this to false will allow the raw data and persisted RDDs to be accessible outside the
+    streaming application, as they will not be cleared automatically. However, this comes at the
+    cost of higher memory usage in Spark.
   </td>
 </tr>
 <tr>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
index 6a6b00a778b48401f13531b38f5120f69e6bd452..37b3b28fa01cb22651994c78daf5c28ecb8e87ad 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala
@@ -68,5 +68,5 @@ case class Time(private val millis: Long) {
 }
 
 object Time {
-  val ordering = Ordering.by((time: Time) => time.millis)
+  implicit val ordering = Ordering.by((time: Time) => time.millis)
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index d393cc03cb33e960ec88ce1aeb76ecbdbd7e0c51..f69f69e0c44afb2231896b28df04290f74e18ab2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
 import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.MetadataCleaner
 import org.apache.spark.streaming._
@@ -340,13 +340,23 @@ abstract class DStream[T: ClassTag] (
    * this to clear their own metadata along with the generated RDDs.
    */
   private[streaming] def clearMetadata(time: Time) {
+    val unpersistData = ssc.conf.getBoolean("spark.streaming.unpersist", true)
     val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
     logDebug("Clearing references to old RDDs: [" +
       oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]")
     generatedRDDs --= oldRDDs.keys
-    if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) {
+    if (unpersistData) {
       logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", "))
-      oldRDDs.values.foreach(_.unpersist(false))
+      oldRDDs.values.foreach { rdd =>
+        rdd.unpersist(false)
+        // Explicitly remove blocks of BlockRDD
+        rdd match {
+          case b: BlockRDD[_] =>
+            logInfo("Removing blocks of RDD " + b + " of time " + time)
+            b.removeBlocks()
+          case _ =>
+        }
+      }
     }
     logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
       (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
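
With the default of spark.streaming.unpersist flipped to true, generated RDDs and their raw input blocks are now cleared as soon as they fall out of a DStream's remember window. Below is a rough sketch of how an application might opt back out when it needs those RDDs to remain accessible after their batch completes; the local master, host and port are placeholder values, not part of this patch.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object RetainStreamingRddsSketch {
  def main(args: Array[String]) {
    // Opt out of the automatic cleanup so that generated RDDs and the raw
    // input blocks stay around after their batch has been processed,
    // at the cost of higher memory usage in Spark.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("RetainStreamingRddsSketch")
      .set("spark.streaming.unpersist", "false")

    val ssc = new StreamingContext(conf, Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.count().print()

    ssc.start()
    ssc.awaitTermination()
  }
}
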
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 8aec27e39478a185ad45a8678239bff66734a7cc..4792ca1f8ae3e4ee224671e11970981ec5c108ad 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming
 
 import org.apache.spark.streaming.StreamingContext._
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.SparkContext._
 
 import util.ManualClock
@@ -27,6 +27,8 @@ import org.apache.spark.{SparkContext, SparkConf}
 import org.apache.spark.streaming.dstream.{WindowedDStream, DStream}
 import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
 import scala.reflect.ClassTag
+import org.apache.spark.storage.StorageLevel
+import scala.collection.mutable
 
 class BasicOperationsSuite extends TestSuiteBase {
   test("map") {
@@ -450,6 +452,78 @@ class BasicOperationsSuite extends TestSuiteBase {
     assert(!stateStream.generatedRDDs.contains(Time(4000)))
   }
 
+  test("rdd cleanup - input blocks and persisted RDDs") {
+    // Actually receive data over the network through a receiver to create BlockRDDs
+
+    // Start the server
+    val testServer = new TestServer()
+    testServer.start()
+
+    // Set up the streaming context and input streams
+    val ssc = new StreamingContext(conf, batchDuration)
+    val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+    val mappedStream = networkStream.map(_ + ".").persist()
+    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+    val outputStream = new TestOutputStream(mappedStream, outputBuffer)
+
+    outputStream.register()
+    ssc.start()
+
+    // Feed data to the server to send to the network receiver
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    val input = Seq(1, 2, 3, 4, 5, 6)
+
+    val blockRdds = new mutable.HashMap[Time, BlockRDD[_]]
+    val persistentRddIds = new mutable.HashMap[Time, Int]
+
+    def collectRddInfo() { // get all RDD info required for verification
+      networkStream.generatedRDDs.foreach { case (time, rdd) =>
+        blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]]
+      }
+      mappedStream.generatedRDDs.foreach { case (time, rdd) =>
+        persistentRddIds(time) = rdd.id
+      }
+    }
+
+    Thread.sleep(200)
+    for (i <- 0 until input.size) {
+      testServer.send(input(i).toString + "\n")
+      Thread.sleep(200)
+      clock.addToTime(batchDuration.milliseconds)
+      collectRddInfo()
+    }
+
+    Thread.sleep(200)
+    collectRddInfo()
+    logInfo("Stopping server")
+    testServer.stop()
+    logInfo("Stopping context")
+
+    // verify data has been received
+    assert(outputBuffer.size > 0)
+    assert(blockRdds.size > 0)
+    assert(persistentRddIds.size > 0)
+
+    import Time._
+
+    val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max)
+    val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min)
+    val latestBlockRdd = blockRdds(blockRdds.keySet.max)
+    val earliestBlockRdd = blockRdds(blockRdds.keySet.min)
+
+    // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted
+    assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId))
+    assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId))
+
+    // verify that the latest input blocks are present but the earliest blocks have been removed
+    assert(latestBlockRdd.isValid)
+    assert(latestBlockRdd.collect != null)
+    assert(!earliestBlockRdd.isValid)
+    earliestBlockRdd.blockIds.foreach { blockId =>
+      assert(!ssc.sparkContext.env.blockManager.master.contains(blockId))
+    }
+    ssc.stop()
+  }
+
   /** Test cleanup of RDDs in DStream metadata */
   def runCleanupTest[T: ClassTag](
       conf2: SparkConf,
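
The test above relies on the one-word Time.scala change: with ordering declared implicit, an Ordering[Time] sits in Time's companion object, so calls like keySet.max and keySet.min resolve it automatically. A tiny sketch of that behaviour (TimeOrderingSketch is a made-up name, assuming spark-streaming on the classpath):

import org.apache.spark.streaming.Time

// Small sketch: with Time.ordering now implicit, standard collection methods
// that need an Ordering[Time] pick it up from Time's companion object.
object TimeOrderingSketch extends App {
  val batchTimes = Seq(Time(1000), Time(3000), Time(2000))

  println(batchTimes.max)     // the latest of the three batch times
  println(batchTimes.min)     // the earliest of the three batch times
  println(batchTimes.sorted)  // ascending by milliseconds
}
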
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 3bad871b5c580cbc6cdacb9cf5d2a72e9cefef48..b55b7834c90c10ca254d1d1e6c6c2604950a58ec 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -42,8 +42,6 @@ import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
 
 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
 
-  val testPort = 9999
-
   test("socket input stream") {
     // Start the server
     val testServer = new TestServer()
@@ -288,17 +286,6 @@ class TestServer(portToBind: Int = 0) extends Logging {
 
   def port = serverSocket.getLocalPort
 }
 
-object TestServer {
-  def main(args: Array[String]) {
-    val s = new TestServer()
-    s.start()
-    while(true) {
-      Thread.sleep(1000)
-      s.send("hello")
-    }
-  }
-}
-
 /** This is an actor for testing actor input stream */
 class TestActor(port: Int) extends Actor with ActorHelper {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index 45304c76b09283447dfaaeaf16533dbed1d49fd4..ff3619a59042da464ccf1d2d3d699c9e76fbd756 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.FunSuite
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
+import scala.language.postfixOps
 
 /** Testsuite for testing the network receiver behavior */
 class NetworkReceiverSuite extends FunSuite with Timeouts {