diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3fe3dc5e300e89dde1de7652b12b7617bcf2adb4..cf3820fcb6a359829e95c8a92dd94017d7cab2b6 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1159,8 +1159,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = {
     withScope {
       assertNotStopped()
-      val kc = kcf()
-      val vc = vcf()
+      val kc = clean(kcf)()
+      val vc = clean(vcf)()
       val format = classOf[SequenceFileInputFormat[Writable, Writable]]
       val writables = hadoopFile(path, format,
         kc.writableClass(km).asInstanceOf[Class[Writable]],
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index a6d5d2c94e17f43cb44a3e29551cf8db21ceea2e..8653cdee1adee351c3b7ceffac6c61b8516fdada 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -296,6 +296,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * before sending results to a reducer, similarly to a "combiner" in MapReduce.
    */
   def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope {
+    val cleanedF = self.sparkContext.clean(func)
 
     if (keyClass.isArray) {
       throw new SparkException("reduceByKeyLocally() does not support array keys")
@@ -305,7 +306,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val map = new JHashMap[K, V]
       iter.foreach { pair =>
         val old = map.get(pair._1)
-        map.put(pair._1, if (old == null) pair._2 else func(old, pair._2))
+        map.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2))
       }
       Iterator(map)
     } : Iterator[JHashMap[K, V]]
@@ -313,7 +314,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => {
       m2.foreach { pair =>
         val old = m1.get(pair._1)
-        m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2))
+        m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2))
       }
       m1
     } : JHashMap[K, V]
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index e41f6ee27764e0a6b35921e560b0d187ba5e821a..7b165fe28bdd3360eaee5353e640e9d8d777409b 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -112,6 +112,7 @@ class ClosureCleanerSuite extends FunSuite {
     expectCorrectException { TestUserClosuresActuallyCleaned.testAggregateByKey(pairRdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testFoldByKey(pairRdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKey(pairRdd) }
+    expectCorrectException { TestUserClosuresActuallyCleaned.testReduceByKeyLocally(pairRdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testMapValues(pairRdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapValues(pairRdd) }
     expectCorrectException { TestUserClosuresActuallyCleaned.testForeachAsync(rdd) }
@@ -315,6 +316,9 @@ private object TestUserClosuresActuallyCleaned {
   }
   def testFoldByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.foldByKey(0) { case (_, _) => return; 1 } }
   def testReduceByKey(rdd: RDD[(Int, Int)]): Unit = { rdd.reduceByKey { case (_, _) => return; 1 } }
+  def testReduceByKeyLocally(rdd: RDD[(Int, Int)]): Unit = {
+    rdd.reduceByKeyLocally { case (_, _) => return; 1 }
+  }
   def testMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.mapValues { _ => return; 1 } }
   def testFlatMapValues(rdd: RDD[(Int, Int)]): Unit = { rdd.flatMapValues { _ => return; Seq() } }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 7f181bcecd4bf807c0bd1800643620d33a9cc5af..fe614c4be590f842bc3039df693f9db2a33e4329 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -255,7 +255,7 @@ class StreamingContext private[streaming] (
    *
    * Note: Return statements are NOT allowed in the given body.
    */
-  private[streaming] def withNamedScope[U](name: String)(body: => U): U = {
+  private def withNamedScope[U](name: String)(body: => U): U = {
     RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body)
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 5977481e1f08148700ffdc69074c3c7d4d9c4d24..7c50a766a9bad2f362eb2a8f81b18d83d872d6f9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,7 +539,7 @@ abstract class DStream[T: ClassTag] (
 
   /** Return a new DStream containing only the elements that satisfy a predicate. */
   def filter(filterFunc: T => Boolean): DStream[T] = ssc.withScope {
-    new FilteredDStream(this, filterFunc)
+    new FilteredDStream(this, context.sparkContext.clean(filterFunc))
   }
 
   /**
@@ -624,7 +624,8 @@ abstract class DStream[T: ClassTag] (
    * 'this' DStream will be registered as an output stream and therefore materialized.
    */
   def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope {
-    this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r))
+    val cleanedF = context.sparkContext.clean(foreachFunc, false)
+    this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r))
   }
 
   /**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 884a8e8b5228917a524dba5f1f76dae2b8d163a4..fda22eb6ec42e1edfb61e577133b63ebf98fca7f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -38,6 +38,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
 {
   private[streaming] def ssc = self.ssc
 
+  private[streaming] def sparkContext = self.context.sparkContext
+
   private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
     new HashPartitioner(numPartitions)
   }
@@ -98,8 +100,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
   def reduceByKey(
       reduceFunc: (V, V) => V,
       partitioner: Partitioner): DStream[(K, V)] = ssc.withScope {
-    val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
-    combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner)
+    combineByKey((v: V) => v, reduceFunc, reduceFunc, partitioner)
   }
 
   /**
@@ -113,7 +114,15 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       mergeCombiner: (C, C) => C,
       partitioner: Partitioner,
       mapSideCombine: Boolean = true): DStream[(K, C)] = ssc.withScope {
-    new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner,
+    val cleanedCreateCombiner = sparkContext.clean(createCombiner)
+    val cleanedMergeValue = sparkContext.clean(mergeValue)
+    val cleanedMergeCombiner = sparkContext.clean(mergeCombiner)
+    new ShuffledDStream[K, V, C](
+      self,
+      cleanedCreateCombiner,
+      cleanedMergeValue,
+      cleanedMergeCombiner,
+      partitioner,
       mapSideCombine)
   }
 
@@ -264,10 +273,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       slideDuration: Duration,
       partitioner: Partitioner
     ): DStream[(K, V)] = ssc.withScope {
-    val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
-    self.reduceByKey(cleanedReduceFunc, partitioner)
+    self.reduceByKey(reduceFunc, partitioner)
         .window(windowDuration, slideDuration)
-        .reduceByKey(cleanedReduceFunc, partitioner)
+        .reduceByKey(reduceFunc, partitioner)
   }
 
   /**
@@ -385,8 +393,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       updateFunc: (Seq[V], Option[S]) => Option[S],
       partitioner: Partitioner
     ): DStream[(K, S)] = ssc.withScope {
+    val cleanedUpdateF = sparkContext.clean(updateFunc)
     val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
-      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+      iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
     }
     updateStateByKey(newUpdateFunc, partitioner, true)
   }
@@ -428,8 +437,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       partitioner: Partitioner,
       initialRDD: RDD[(K, S)]
     ): DStream[(K, S)] = ssc.withScope {
+    val cleanedUpdateF = sparkContext.clean(updateFunc)
     val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
-      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+      iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s)))
     }
     updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
   }
@@ -463,7 +473,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
    * 'this' DStream without changing the key.
    */
   def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = ssc.withScope {
-    new MapValuedDStream[K, V, U](self, mapValuesFunc)
+    new MapValuedDStream[K, V, U](self, sparkContext.clean(mapValuesFunc))
   }
 
   /**
@@ -473,7 +483,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
   def flatMapValues[U: ClassTag](
       flatMapValuesFunc: V => TraversableOnce[U]
     ): DStream[(K, U)] = ssc.withScope {
-    new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc)
+    new FlatMapValuedDStream[K, V, U](self, sparkContext.clean(flatMapValuesFunc))
   }
 
   /**
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6a1dd6949b204948d4fad167708a4f87616055f2
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.NotSerializableException
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.{HashPartitioner, SparkContext, SparkException}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.ReturnStatementInClosureException
+
+/**
+ * Test that closures passed to DStream operations are actually cleaned.
+ */
+class DStreamClosureSuite extends FunSuite with BeforeAndAfterAll {
+  private var ssc: StreamingContext = null
+
+  override def beforeAll(): Unit = {
+    val sc = new SparkContext("local", "test")
+    ssc = new StreamingContext(sc, Seconds(1))
+  }
+
+  override def afterAll(): Unit = {
+    ssc.stop(stopSparkContext = true)
+    ssc = null
+  }
+
+  test("user provided closures are actually cleaned") {
+    val dstream = new DummyInputDStream(ssc)
+    val pairDstream = dstream.map { i => (i, i) }
+    // DStream
+    testMap(dstream)
+    testFlatMap(dstream)
+    testFilter(dstream)
+    testMapPartitions(dstream)
+    testReduce(dstream)
+    testForeach(dstream)
+    testForeachRDD(dstream)
+    testTransform(dstream)
+    testTransformWith(dstream)
+    testReduceByWindow(dstream)
+    // PairDStreamFunctions
+    testReduceByKey(pairDstream)
+    testCombineByKey(pairDstream)
+    testReduceByKeyAndWindow(pairDstream)
+    testUpdateStateByKey(pairDstream)
+    testMapValues(pairDstream)
+    testFlatMapValues(pairDstream)
+    // StreamingContext
+    testTransform2(ssc, dstream)
+  }
+
+  /**
+   * Verify that the expected exception is thrown.
+   *
+   * We use return statements as an indication that a closure is actually being cleaned.
+   * We expect closure cleaner to find the return statements in the user provided closures.
+   */
+  private def expectCorrectException(body: => Unit): Unit = {
+    try {
+      body
+    } catch {
+      case rse: ReturnStatementInClosureException => // Success!
+      case e @ (_: NotSerializableException | _: SparkException) =>
+        throw new TestException(
+          s"Expected ReturnStatementInClosureException, but got $e.\n" +
+          "This means the closure provided by user is not actually cleaned.")
+    }
+  }
+
+  // DStream operations
+  private def testMap(ds: DStream[Int]): Unit = expectCorrectException {
+    ds.map { _ => return; 1 }
+  }
+  private def testFlatMap(ds: DStream[Int]): Unit = expectCorrectException {
+    ds.flatMap { _ => return; Seq.empty }
+  }
+  private def testFilter(ds: DStream[Int]): Unit = expectCorrectException {
+    ds.filter { _ => return; true }
+  }
+  private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException {
+    ds.mapPartitions { _ => return; Seq.empty.toIterator }
+  }
+  private def testReduce(ds: DStream[Int]): Unit = expectCorrectException {
+    ds.reduce { case (_, _) => return; 1 }
+  }
+  private def testForeach(ds: DStream[Int]): Unit = {
+    val foreachF1 = (rdd: RDD[Int], t: Time) => return
+    val foreachF2 = (rdd: RDD[Int]) => return
+    expectCorrectException { ds.foreach(foreachF1) }
+    expectCorrectException { ds.foreach(foreachF2) }
+  }
+  private def testForeachRDD(ds: DStream[Int]): Unit = {
+    val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return
+    val foreachRDDF2 = (rdd: RDD[Int]) => return
+    expectCorrectException { ds.foreachRDD(foreachRDDF1) }
+    expectCorrectException { ds.foreachRDD(foreachRDDF2) }
+  }
+  private def testTransform(ds: DStream[Int]): Unit = {
+    val transformF1 = (rdd: RDD[Int]) => { return; rdd }
+    val transformF2 = (rdd: RDD[Int], time: Time) => { return; rdd }
+    expectCorrectException { ds.transform(transformF1) }
+    expectCorrectException { ds.transform(transformF2) }
+  }
+  private def testTransformWith(ds: DStream[Int]): Unit = {
+    val transformF1 = (rdd1: RDD[Int], rdd2: RDD[Int]) => { return; rdd1 }
+    val transformF2 = (rdd1: RDD[Int], rdd2: RDD[Int], time: Time) => { return; rdd2 }
+    expectCorrectException { ds.transformWith(ds, transformF1) }
+    expectCorrectException { ds.transformWith(ds, transformF2) }
+  }
+  private def testReduceByWindow(ds: DStream[Int]): Unit = {
+    val reduceF = (_: Int, _: Int) => { return; 1 }
+    expectCorrectException { ds.reduceByWindow(reduceF, Seconds(1), Seconds(2)) }
+    expectCorrectException { ds.reduceByWindow(reduceF, reduceF, Seconds(1), Seconds(2)) }
+  }
+
+  // PairDStreamFunctions operations
+  private def testReduceByKey(ds: DStream[(Int, Int)]): Unit = {
+    val reduceF = (_: Int, _: Int) => { return; 1 }
+    expectCorrectException { ds.reduceByKey(reduceF) }
+    expectCorrectException { ds.reduceByKey(reduceF, 5) }
+    expectCorrectException { ds.reduceByKey(reduceF, new HashPartitioner(5)) }
+  }
+  private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = {
+    expectCorrectException {
+      ds.combineByKey[Int](
+        { _: Int => return; 1 },
+        { case (_: Int, _: Int) => return; 1 },
+        { case (_: Int, _: Int) => return; 1 },
+        new HashPartitioner(5)
+      )
+    }
+  }
+  private def testReduceByKeyAndWindow(ds: DStream[(Int, Int)]): Unit = {
+    val reduceF = (_: Int, _: Int) => { return; 1 }
+    val filterF = (_: (Int, Int)) => { return; false }
+    expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1)) }
+    expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2)) }
+    expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), 5) }
+    expectCorrectException {
+      ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), new HashPartitioner(5))
+    }
+    expectCorrectException { ds.reduceByKeyAndWindow(reduceF, reduceF, Seconds(2)) }
+    expectCorrectException {
+      ds.reduceByKeyAndWindow(
+        reduceF, reduceF, Seconds(2), Seconds(3), new HashPartitioner(5), filterF)
+    }
+  }
+  private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
+    val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
+    val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator }
+    val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) }
+    expectCorrectException { ds.updateStateByKey(updateF1) }
+    expectCorrectException { ds.updateStateByKey(updateF1, 5) }
+    expectCorrectException { ds.updateStateByKey(updateF1, new HashPartitioner(5)) }
+    expectCorrectException {
+      ds.updateStateByKey(updateF1, new HashPartitioner(5), initialRDD)
+    }
+    expectCorrectException {
+      ds.updateStateByKey(updateF2, new HashPartitioner(5), true)
+    }
+    expectCorrectException {
+      ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD)
+    }
+  }
+  private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException {
+    ds.mapValues { _ => return; 1 }
+  }
+  private def testFlatMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException {
+    ds.flatMapValues { _ => return; Seq.empty }
+  }
+
+  // StreamingContext operations
+  private def testTransform2(ssc: StreamingContext, ds: DStream[Int]): Unit = {
+    val transformF = (rdds: Seq[RDD[_]], time: Time) => { return; ssc.sparkContext.emptyRDD[Int] }
+    expectCorrectException { ssc.transform(Seq(ds), transformF) }
+  }
+
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
index 392933102097e18fd4730fce50b63f20ef429b03..e3fb2ef13085985471ffe7af9a16dc61d9ea2461 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.streaming
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.SparkContext
-import org.apache.spark.rdd.{RDD, RDDOperationScope}
-import org.apache.spark.streaming.dstream.{DStream, InputDStream}
+import org.apache.spark.rdd.RDDOperationScope
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.streaming.ui.UIUtils
 
 /**
@@ -170,21 +170,3 @@ class DStreamScopeSuite extends FunSuite with BeforeAndAfter with BeforeAndAfter
   }
 }
-
-/**
- * A dummy stream that does absolutely nothing.
- */
-private class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) {
-  override def dependencies: List[DStream[Int]] = List.empty
-  override def slideDuration: Duration = Seconds(1)
-  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int])
-}
-
-/**
- * A dummy input stream that does absolutely nothing.
- */
-private class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) {
-  override def start(): Unit = { }
-  override def stop(): Unit = { }
-  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int])
-}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 4f70ae7f1f187c98f61c5fd2138e4058e410dda7..554cd30223f444c4a67fde5359641de1553b2fe6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -35,6 +35,24 @@ import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{ManualClock, Utils}
 
+/**
+ * A dummy stream that does absolutely nothing.
+ */
+private[streaming] class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) {
+  override def dependencies: List[DStream[Int]] = List.empty
+  override def slideDuration: Duration = Seconds(1)
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int])
+}
+
+/**
+ * A dummy input stream that does absolutely nothing.
+ */
+private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) {
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int])
+}
+
 /**
  * This is a input stream just for the testsuites. This is equivalent to a checkpointable,
  * replayable, reliable message queue like Kafka. It requires a sequence as input, and
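
Reviewer note (not part of the patch): the behaviour these changes enforce can be seen from a small driver-side sketch. This is an illustrative, hypothetical example (the object name CleanedClosureDemo and the helper badSum are made up); it assumes a Spark build that includes this patch, where reduceByKeyLocally now routes the user function through SparkContext#clean, so a closure containing a return statement is rejected eagerly by the closure cleaner rather than slipping through uncleaned.

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical demo object, not part of the patch.
object CleanedClosureDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("clean-demo"))
    try {
      val pairs = sc.parallelize(1 to 10).map(i => (i % 2, i))

      // A well-behaved closure is cleaned and runs as before,
      // e.g. Map(0 -> 30, 1 -> 25).
      println(pairs.reduceByKeyLocally(_ + _))

      // A closure with an illegal `return`: with this patch, reduceByKeyLocally
      // cleans the function up front, so the closure cleaner rejects it
      // (ReturnStatementInClosureException) before any job is run.
      def badSum(): Unit = {
        pairs.reduceByKeyLocally { (a, b) => return; a + b }
      }
      try badSum() catch {
        case e: Exception => println(s"Rejected as expected: ${e.getClass.getSimpleName}")
      }
    } finally {
      sc.stop()
    }
  }
}

The DStream operations touched above (filter, foreachRDD, combineByKey, updateStateByKey, mapValues, flatMapValues) get the same treatment, which is what the new DStreamClosureSuite asserts by planting return statements in each user closure.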