diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 75e64c1bf401eb0c2934b7ceb06f6a38e0f41af2..94142d33369c773d6202a41a5e7148b471f41b18 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
   /**
-   * Value of the broadcast object. On driver, this is set directly by the constructor.
-   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
-   * blocks from the driver and/or other executors.
+   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+   * which builds this value by reading blocks from the driver and/or other executors.
+   *
+   * On the driver, if the value is required, it is read lazily from the block manager.
    */
-  @transient private var _value: T = obj
+  @transient private lazy val _value: T = readBroadcastBlock()
+
   /** The compression codec to use, or None if compression is disabled */
   @transient private var compressionCodec: Option[CompressionCodec] = _
   /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
   private val broadcastId = BroadcastBlockId(id)
 
   /** Total number of blocks this broadcast variable contains. */
-  private val numBlocks: Int = writeBlocks()
+  private val numBlocks: Int = writeBlocks(obj)
 
-  override protected def getValue() = _value
+  override protected def getValue() = {
+    _value
+  }
 
   /**
    * Divide the object into multiple blocks and put those blocks in the block manager.
-   *
+   * @param value the object to divide
    * @return number of blocks this broadcast variable is divided into
    */
-  private def writeBlocks(): Int = {
+  private def writeBlocks(value: T): Int = {
     // Store a copy of the broadcast variable in the driver so that tasks run on the driver
     // do not create a duplicate copy of the broadcast variable's value.
-    SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+    SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
       tellMaster = false)
     val blocks =
-      TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       SparkEnv.get.blockManager.putBytes(
         BroadcastBlockId(id, "piece" + i),
@@ -157,31 +161,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     out.defaultWriteObject()
   }
 
-  /** Used by the JVM when deserializing this object. */
-  private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
-    in.defaultReadObject()
+  private def readBroadcastBlock(): T = Utils.tryOrIOException {
     TorrentBroadcast.synchronized {
       setConf(SparkEnv.get.conf)
       SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          _value = x.asInstanceOf[T]
+          x.asInstanceOf[T]
 
         case None =>
           logInfo("Started reading broadcast variable " + id)
-          val start = System.nanoTime()
+          val startTimeMs = System.currentTimeMillis()
           val blocks = readBlocks()
-          val time = (System.nanoTime() - start) / 1e9
-          logInfo("Reading broadcast variable " + id + " took " + time + " s")
+          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
 
-          _value =
-            TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
+          val obj = TorrentBroadcast.unBlockifyObject[T](
+            blocks, SparkEnv.get.serializer, compressionCodec)
           // Store the merged copy in BlockManager so other tasks on this executor don't
           // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+            broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+          obj
       }
     }
   }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 46600301558543b109c3307133780950cbef8e7c..612eca308bf0b1f860ede0b79e6ba644b8b05dad 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -988,6 +988,21 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+   * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+   * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+   * see SPARK-4080 for more context.
+   */
+  def tryOrIOException[T](block: => T): T = {
+    try {
+      block
+    } catch {
+      case e: IOException => throw e
+      case NonFatal(t) => throw new IOException(t)
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def coreExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the "core" Spark API that we want to skip when
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 1014fd62d9a75b4a2928db439056ab640c3cc0ac..b0a70f012f1f36c20121d1f9802437d8507daa3b 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -21,11 +21,28 @@ import scala.util.Random
 
 import org.scalatest.{Assertions, FunSuite}
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv}
 import org.apache.spark.io.SnappyCompressionCodec
+import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage._
 
+// Dummy class that creates a broadcast variable but doesn't use it
+class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
+  @transient val list = List(1, 2, 3, 4)
+  val broadcast = rdd.context.broadcast(list)
+  val bid = broadcast.id
+
+  def doSomething() = {
+    rdd.map { x =>
+      val bm = SparkEnv.get.blockManager
+      // Check if broadcast block was fetched
+      val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined
+      (x, isFound)
+    }.collect().toSet
+  }
+}
+
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -105,6 +122,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("Test Lazy Broadcast variables with TorrentBroadcast") {
+    val numSlaves = 2
+    val conf = torrentConf.clone
+    sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+    val rdd = sc.parallelize(1 to numSlaves)
+
+    val results = new DummyBroadcastClass(rdd).doSomething()
+
+    assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
+  }
+
   test("Unpersisting HttpBroadcast on executors only in local mode") {
     testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
   }
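For readers skimming the patch, it combines two ideas: the broadcast value becomes a @transient lazy val that is reconstructed and cached on first access (instead of eagerly inside readObject), and the new generic Utils.tryOrIOException[T] rethrows non-fatal failures from that initializer as IOException so Java's serialization hooks report them properly (SPARK-4080). The following is a minimal, self-contained Scala sketch of the same pattern outside Spark; LocalStore, fetchRemote, and LazyBroadcast are illustrative stand-ins, not Spark APIs.

import java.io.IOException
import scala.collection.concurrent.TrieMap
import scala.util.control.NonFatal

// Illustrative stand-in for Spark's BlockManager (not a Spark API):
// a process-local cache keyed by broadcast id.
object LocalStore {
  private val cache = TrieMap.empty[Long, Any]
  def getLocal(id: Long): Option[Any] = cache.get(id)
  def put(id: Long, value: Any): Unit = cache.update(id, value)
}

// Minimal sketch of the lazy-read pattern this patch gives TorrentBroadcast.
class LazyBroadcast[T](val id: Long, fetchRemote: () => T) extends Serializable {

  // Reconstructed on first access instead of eagerly at deserialization time,
  // so a deserialized copy whose value is never used never fetches anything.
  @transient private lazy val _value: T = readValue()

  def value: T = _value

  // Same idea as the Utils.tryOrIOException[T] added by this patch: pass
  // IOExceptions through and wrap other non-fatal errors, because Java's
  // serialization machinery only reports IOExceptions properly (SPARK-4080).
  private def tryOrIOException[A](block: => A): A =
    try {
      block
    } catch {
      case e: IOException => throw e
      case NonFatal(t) => throw new IOException(t)
    }

  private def readValue(): T = tryOrIOException {
    LocalStore.getLocal(id) match {
      case Some(x) =>
        x.asInstanceOf[T] // already cached locally; no fetch needed
      case None =>
        val obj = fetchRemote() // e.g. assemble the torrent blocks
        LocalStore.put(id, obj) // cache so later readers skip the re-fetch
        obj
    }
  }
}

object LazyBroadcastDemo extends App {
  var fetches = 0
  val b = new LazyBroadcast[List[Int]](0L, () => { fetches += 1; List(1, 2, 3, 4) })
  println(fetches)  // 0: nothing is fetched until the value is actually read
  println(b.value)  // List(1, 2, 3, 4): first access triggers the fetch
  println(b.value)  // memoized by the lazy val; no second fetch
  println(fetches)  // 1
}

This is also what the new DummyBroadcastClass test checks end to end: a task that creates a broadcast variable but never reads it should find no broadcast block in its local block manager, hence the expected (x, false) results.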