diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 545807ffbce55ee05addb8fb6cfe093393c368e1..76305237b03d5e048e73cb0267b1d4466f9f7525 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1002,9 +1002,7 @@ class SparkContext(config: SparkConf) extends Logging {
       require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
     }
     val callSite = getCallSite
-    // There's no need to check this function for serializability,
-    // since it will be run right away.
-    val cleanedFunc = clean(func, false)
+    val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1137,18 +1135,14 @@ class SparkContext(config: SparkConf) extends Logging {
   def cancelAllJobs() {
     dagScheduler.cancelAllJobs()
   }
-  
+
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
-   *
-   * @param f closure to be cleaned and optionally serialized
-   * @param captureNow whether or not to serialize this closure and capture any free
-   * variables immediately; defaults to true. If this is set and f is not serializable,
-   * it will raise an exception.
    */
-  private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
-    ClosureCleaner.clean(f, captureNow)
+  private[spark] def clean[F <: AnyRef](f: F): F = {
+    ClosureCleaner.clean(f)
+    f
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e363ea777d8eb3e5c4582b91b892d8e97b33eef5..3437b2cac19c2cfa73f67ae5b48b05c4c617bef4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -660,16 +660,14 @@ abstract class RDD[T: ClassTag](
    * Applies a function f to all elements of this RDD.
    */
   def foreach(f: T => Unit) {
-    val cleanF = sc.clean(f)
-    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
   }
 
   /**
   * Applies a function f to each partition of this RDD.
   */
  def foreachPartition(f: Iterator[T] => Unit) {
-    val cleanF = sc.clean(f)
-    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+    sc.runJob(this, (iter: Iterator[T]) => f(iter))
  }
 
  /**
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e474b1a850d65e7eb918d0b28a4dbde13ef58e0c..cdbbc65292188813fb78a4f0c25f68360812a64f 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,14 +22,10 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
 import scala.collection.mutable.Map
 import scala.collection.mutable.Set
-import scala.reflect.ClassTag
-
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
 import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
-import org.apache.spark.SparkException
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -105,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
+  def clean(func: AnyRef) {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -154,21 +150,6 @@ private[spark] object ClosureCleaner extends Logging {
       field.setAccessible(true)
       field.set(func, outer)
     }
-
-    if (captureNow) {
-      cloneViaSerializing(func)
-    } else {
-      func
-    }
-  }
-
-  private def cloneViaSerializing[T: ClassTag](func: T): T = {
-    try {
-      val serializer = SparkEnv.get.closureSerializer.newInstance()
-      serializer.deserialize[T](serializer.serialize[T](func))
-    } catch {
-      case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
-    }
   }
 
   private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 4f9300419e6f8af2b07500346a633f717f03bd3c..12dbebcb286443fb18e9ad69f918d05f1691449e 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
-  test("failure because closure in final-stage task is not serializable") {
+  test("failure because task closure is not serializable") {
     sc = new SparkContext("local[1,1]", "test")
     val a = new NonSerializable
 
@@ -118,13 +118,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     assert(thrown.getClass === classOf[SparkException])
     assert(thrown.getMessage.contains("NotSerializableException"))
 
-    FailureSuiteState.clear()
-  }
-
-  test("failure because closure in early-stage task is not serializable") {
-    sc = new SparkContext("local[1,1]", "test")
-    val a = new NonSerializable
-
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -131,15 +124,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     }
     assert(thrown1.getClass === classOf[SparkException])
     assert(thrown1.getMessage.contains("NotSerializableException"))
 
-    FailureSuiteState.clear()
-  }
-
-  test("failure because closure in foreach task is not serializable") {
-    sc = new SparkContext("local[1,1]", "test")
-    val a = new NonSerializable
-
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
     }
@@ -149,6 +135,5 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
-
   // TODO: Need to add tests with shuffle fetch failures.
 }
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
deleted file mode 100644
index 76662264e7e94e21900fa538f696d7265e46d735..0000000000000000000000000000000000000000
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.serializer;
-
-import java.io.NotSerializableException
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkException
-import org.apache.spark.SharedSparkContext
-
-/* A trivial (but unserializable) container for trivial functions */
-class UnserializableClass {
-  def op[T](x: T) = x.toString
-
-  def pred[T](x: T) = x.toString.length % 2 == 0
-}
-
-class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
-
-  def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
-
-  test("throws expected serialization exceptions on actions") {
-    val (data, uc) = fixture
-
-    val ex = intercept[SparkException] {
-      data.map(uc.op(_)).count
-    }
-
-    assert(ex.getMessage.matches(".*Task not serializable.*"))
-  }
-
-  // There is probably a cleaner way to eliminate boilerplate here, but we're
-  // iterating over a map from transformation names to functions that perform that
-  // transformation on a given RDD, creating one test case for each
-
-  for (transformation <-
-      Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
-          "mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
-          "mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
-    val (name, xf) = transformation
-
-    test(s"$name transformations throw proactive serialization exceptions") {
-      val (data, uc) = fixture
-
-      val ex = intercept[SparkException] {
-        xf(data, uc)
-      }
-
-      assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
-    }
-  }
-
-  def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.map(y => uc.op(y))
-
-  def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
-
-  def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.flatMap(y=>Seq(uc.op(y)))
-
-  def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.filter(y=>uc.pred(y))
-
-  def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.filterWith(x => x.toString)((x,y) => uc.pred(y))
-
-  def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitions(_.map(y => uc.op(y)))
-
-  def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
-
-  def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
-
-}
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index c635da6cacd708060491c8f9d80d3483cb5c08df..439e5644e20a37bae720cb8592cddf6e535d23ff 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,27 +50,6 @@ class ClosureCleanerSuite extends FunSuite {
     val obj = new TestClassWithNesting(1)
     assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
   }
-
-  test("capturing free variables in closures at RDD definition") {
-    val obj = new TestCaptureVarClass()
-    val (ones, onesPlusZeroes) = obj.run()
-
-    assert(ones === onesPlusZeroes)
-  }
-
-  test("capturing free variable fields in closures at RDD definition") {
-    val obj = new TestCaptureFieldClass()
-    val (ones, onesPlusZeroes) = obj.run()
-
-    assert(ones === onesPlusZeroes)
-  }
-
-  test("capturing arrays in closures at RDD definition") {
-    val obj = new TestCaptureArrayEltClass()
-    val (observed, expected) = obj.run()
-
-    assert(observed === expected)
-  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -164,50 +143,3 @@ class TestClassWithNesting(val y: Int) extends Serializable {
     }
   }
 }
-
-class TestCaptureFieldClass extends Serializable {
-  class ZeroBox extends Serializable {
-    var zero = 0
-  }
-
-  def run(): (Int, Int) = {
-    val zb = new ZeroBox
-
-    withSpark(new SparkContext("local", "test")) {sc =>
-      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
-      val onesPlusZeroes = ones.map(_ + zb.zero)
-
-      zb.zero = 5
-
-      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
-    }
-  }
-}
-
-class TestCaptureArrayEltClass extends Serializable {
-  def run(): (Int, Int) = {
-    withSpark(new SparkContext("local", "test")) {sc =>
-      val rdd = sc.parallelize(1 to 10)
-      val data = Array(1, 2, 3)
-      val expected = data(0)
-      val mapped = rdd.map(x => data(0))
-      data(0) = 4
-      (mapped.first, expected)
-    }
-  }
-}
-
-class TestCaptureVarClass extends Serializable {
-  def run(): (Int, Int) = {
-    var zero = 0
-
-    withSpark(new SparkContext("local", "test")) {sc =>
-      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
-      val onesPlusZeroes = ones.map(_ + zero)
-
-      zero = 5
-
-      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
-    }
-  }
-}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index c65e36636fe10b51b7a8ea4c72b778cae6148add..28d34dd9a1a414b82caec571dd98f202fefee6e9 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       assert( graph.edges.count() === rawEdges.size )
       // Vertices not explicitly provided but referenced by edges should be created automatically
       assert( graph.vertices.count() === 100)
-      graph.triplets.collect.map { et =>
+      graph.triplets.map { et =>
         assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
         assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
       }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 4759b629a9931e3fc07ba477ecb11745376b53cc..d043200f71a0b7cdab67c5181dcc6ad155ab5a92 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -539,7 +539,7 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
-    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
+    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
   }
 
   /**
@@ -547,7 +547,7 @@ abstract class DStream[T: ClassTag] (
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
-    val cleanedF = context.sparkContext.clean(transformFunc, false)
+    val cleanedF = context.sparkContext.clean(transformFunc)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 1)
       cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -562,7 +562,7 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
+    val cleanedF = ssc.sparkContext.clean(transformFunc)
     transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
   }
 
@@ -573,7 +573,7 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
+    val cleanedF = ssc.sparkContext.clean(transformFunc)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 2)
       val rdd1 = rdds(0).asInstanceOf[RDD[T]]
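For reviewers, a minimal sketch (not part of the patch) of the behaviour the merged FailureSuite test above now pins down: with eager capture removed from ClosureCleaner.clean, defining an RDD over a non-serializable closure still succeeds, and the NotSerializableException only surfaces, wrapped in a SparkException, when an action submits the job. The class and object names below are illustrative only, assuming the Spark 1.0-era APIs exercised by the suite.

import org.apache.spark.{SparkContext, SparkException}

// Hypothetical class for illustration; deliberately not Serializable.
class NonSerializableThing

object CleanBehaviourSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[1,1]", "sketch")
    val a = new NonSerializableThing
    // Defining the transformation no longer serializes the captured closure.
    val rdd = sc.parallelize(1 to 10, 2).map(x => (x, a))
    try {
      rdd.count() // the failure surfaces here, at job submission
    } catch {
      case e: SparkException =>
        assert(e.getMessage.contains("NotSerializableException"))
    } finally {
      sc.stop()
    }
  }
}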