diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 15fd30e65761de4454a3e69a966732da031ce748..87f5cf944ed85da85acb5db1da248e3b3d33701f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -20,6 +20,8 @@ package org.apache.spark.broadcast import java.io.Serializable import org.apache.spark.SparkException +import org.apache.spark.Logging +import org.apache.spark.util.Utils import scala.reflect.ClassTag @@ -52,7 +54,7 @@ import scala.reflect.ClassTag * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. */ -abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging { /** * Flag signifying whether the broadcast variable is valid @@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { */ @volatile private var _isValid = true + private var _destroySite = "" + /** Get the broadcasted value. */ def value: T = { assertValid() @@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { doUnpersist(blocking) } + + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + * This method blocks until destroy has completed + */ + def destroy() { + destroy(blocking = true) + } + /** * Destroy all data and metadata related to this broadcast variable. Use this with caution; * once a broadcast variable has been destroyed, it cannot be used again. + * @param blocking Whether to block until destroy has completed */ private[spark] def destroy(blocking: Boolean) { assertValid() _isValid = false + _destroySite = Utils.getCallSite().shortForm + logInfo("Destroying %s (from %s)".format(toString, _destroySite)) doDestroy(blocking) } @@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** Check if this broadcast is valid. If not valid, exception is thrown. */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + throw new SparkException( + "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite)) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d722ee5a97e9430a769e579c86074ada5bcabc9b..84ed5db8f0a535a13139a7573e8eb4924de08a79 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -992,7 +992,8 @@ private[spark] object Utils extends Logging { private def coreExclusionFunction(className: String): Boolean = { // A regular expression to match classes of the "core" Spark API that we want to skip when // finding the call site of a method. - val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + val SPARK_CORE_CLASS_REGEX = + """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r val SCALA_CLASS_REGEX = """^scala""".r val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index e096c8c3e9b461a5bd0486389dcaee4ef942d184..1014fd62d9a75b4a2928db439056ab640c3cc0ac 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.broadcast import scala.util.Random -import org.scalatest.FunSuite +import org.scalatest.{Assertions, FunSuite} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.apache.spark.io.SnappyCompressionCodec @@ -136,6 +136,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") { testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true) } + + test("Using broadcast after destroy prints callsite") { + sc = new SparkContext("local", "test") + testPackage.runCallSiteTest(sc) + } + /** * Verify the persistence of state associated with an HttpBroadcast in either local mode or * local-cluster mode (when distributed = true). @@ -311,3 +317,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { conf } } + +package object testPackage extends Assertions { + + def runCallSiteTest(sc: SparkContext) { + val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2) + val broadcast = sc.broadcast(rdd) + broadcast.destroy() + val thrown = intercept[SparkException] { broadcast.value } + assert(thrown.getMessage.contains("BroadcastSuite.scala")) + } + +}