Commit 9aa340a2 authored by Shivaram Venkataraman

[SPARK-4030] Make destroy public for broadcast variables

This change makes the destroy function public for broadcast variables. Motivation for the change is described in https://issues.apache.org/jira/browse/SPARK-4030.
This patch also logs where destroy was called from if a broadcast variable is used after destruction.

Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>

Closes #2922 from shivaram/broadcast-destroy and squashes the following commits:

a11abab [Shivaram Venkataraman] Fix scala style in Utils.scala
bed9c9d [Shivaram Venkataraman] Make destroy blocking by default
e80c1ab [Shivaram Venkataraman] Make destroy public for broadcast variables. Also log where destroy was called from if a broadcast variable is used after destruction.
parent 6377adaf
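
For context, a minimal driver-side sketch of the new public API (not part of this patch; it assumes an existing SparkContext named sc):

    // Hypothetical usage of the now-public, blocking destroy().
    val data = sc.broadcast(Seq(1, 2, 3))
    println(data.value.sum)   // normal use of the broadcast value

    data.destroy()            // removes all data and metadata; blocks until done

    // Any later use throws a SparkException whose message now includes the
    // call site at which destroy() was invoked.
    try {
      data.value
    } catch {
      case e: org.apache.spark.SparkException => println(e.getMessage)
    }
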
core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -20,6 +20,8 @@ package org.apache.spark.broadcast
 import java.io.Serializable
 import org.apache.spark.SparkException
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
 import scala.reflect.ClassTag
@@ -52,7 +54,7 @@ import scala.reflect.ClassTag
  * @param id A unique identifier for the broadcast variable.
  * @tparam T Type of the data contained in the broadcast variable.
  */
-abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
   /**
    * Flag signifying whether the broadcast variable is valid
@@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
    */
   @volatile private var _isValid = true
+  private var _destroySite = ""
   /** Get the broadcasted value. */
   def value: T = {
     assertValid()
@@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
     doUnpersist(blocking)
   }
+  /**
+   * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+   * once a broadcast variable has been destroyed, it cannot be used again.
+   * This method blocks until destroy has completed
+   */
+  def destroy() {
+    destroy(blocking = true)
+  }
   /**
    * Destroy all data and metadata related to this broadcast variable. Use this with caution;
    * once a broadcast variable has been destroyed, it cannot be used again.
    * @param blocking Whether to block until destroy has completed
    */
   private[spark] def destroy(blocking: Boolean) {
     assertValid()
     _isValid = false
+    _destroySite = Utils.getCallSite().shortForm
+    logInfo("Destroying %s (from %s)".format(toString, _destroySite))
     doDestroy(blocking)
   }
@@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
   /** Check if this broadcast is valid. If not valid, exception is thrown. */
   protected def assertValid() {
     if (!_isValid) {
-      throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+      throw new SparkException(
+        "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
     }
   }

core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -992,7 +992,8 @@ private[spark] object Utils extends Logging {
   private def coreExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the "core" Spark API that we want to skip when
     // finding the call site of a method.
-    val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+    val SPARK_CORE_CLASS_REGEX =
+      """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r
     val SCALA_CLASS_REGEX = """^scala""".r
     val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
     val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
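
A quick illustration of the Utils.scala change above (not part of the patch): the new optional (\.broadcast)? group makes classes in the org.apache.spark.broadcast package count as "core" Spark classes, so Utils.getCallSite() skips over them and reports the user code that called destroy() rather than Broadcast.scala itself.

    // Standalone sketch of how the extended regex classifies class names.
    val SPARK_CORE_CLASS_REGEX =
      """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r

    // Broadcast implementation classes are now skipped when computing the call site...
    assert(SPARK_CORE_CLASS_REGEX.findFirstIn("org.apache.spark.broadcast.TorrentBroadcast").isDefined)
    // ...while user classes are still reported.
    assert(SPARK_CORE_CLASS_REGEX.findFirstIn("com.example.MyDriverApp").isEmpty)
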
core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.broadcast
 import scala.util.Random
-import org.scalatest.FunSuite
+import org.scalatest.{Assertions, FunSuite}
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
 import org.apache.spark.io.SnappyCompressionCodec
@@ -136,6 +136,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
     testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
   }
+  test("Using broadcast after destroy prints callsite") {
+    sc = new SparkContext("local", "test")
+    testPackage.runCallSiteTest(sc)
+  }
   /**
    * Verify the persistence of state associated with an HttpBroadcast in either local mode or
    * local-cluster mode (when distributed = true).
@@ -311,3 +317,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     conf
   }
 }
+package object testPackage extends Assertions {
+  def runCallSiteTest(sc: SparkContext) {
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val broadcast = sc.broadcast(rdd)
+    broadcast.destroy()
+    val thrown = intercept[SparkException] { broadcast.value }
+    assert(thrown.getMessage.contains("BroadcastSuite.scala"))
+  }
+}