diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 2d05e09b109482e9c051c8d8e97b883f61768881..4916d9b86cca57c4131746e9b16fb72591d0c60c 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.Set import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it @@ -108,6 +108,9 @@ private[spark] object ClosureCleaner extends Logging { val outerObjects = getOuterObjects(func) val accessedFields = Map[Class[_], Set[String]]() + + getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0) + for (cls <- outerClasses) accessedFields(cls) = Set[String]() for (cls <- func.getClass :: innerClasses) @@ -180,6 +183,24 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] +class ReturnStatementFinder extends ClassVisitor(ASM4) { + override def visitMethod(access: Int, name: String, desc: String, + sig: String, exceptions: Array[String]): MethodVisitor = { + if (name.contains("apply")) { + new MethodVisitor(ASM4) { + override def visitTypeInsn(op: Int, tp: String) { + if (op == NEW && tp.contains("scala/runtime/NonLocalReturnControl")) { + throw new SparkException("Return statements aren't allowed in Spark closures") + } + } + } + } else { + new MethodVisitor(ASM4) {} + } + } +} + private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) { override def visitMethod(access: Int, name: String, desc: String, diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index d7e48e633e0ee78dab65afda77b716ee135f7af1..054ef54e746a5ca6d20e4ebe1c2b0326fb9a0cb2 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import org.scalatest.FunSuite import org.apache.spark.LocalSparkContext._ -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkException} class ClosureCleanerSuite extends FunSuite { test("closures inside an object") { @@ -50,6 +50,19 @@ class ClosureCleanerSuite extends FunSuite { val obj = new TestClassWithNesting(1) assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1 } + + test("toplevel return statements in closures are identified at cleaning time") { + val ex = intercept[SparkException] { + TestObjectWithBogusReturns.run() + } + + assert(ex.getMessage.contains("Return statements aren't allowed in Spark closures")) + } + + test("return statements from named functions nested in closures don't raise exceptions") { + val result = TestObjectWithNestedReturns.run() + assert(result == 1) + } } // A non-serializable class we create in closures to make sure that we aren't @@ -108,6 +121,30 @@ class TestClassWithoutFieldAccess { } } +object TestObjectWithBogusReturns { + def run(): Int = { + withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + // this return is invalid since it will transfer control outside the closure + nums.map {x => return 1 ; x * 2} + 1 + } + } +} + +object TestObjectWithNestedReturns { + def run(): Int = { + withSpark(new SparkContext("local", "test")) { sc => + val nums = sc.parallelize(Array(1, 2, 3, 4)) + nums.map {x => + // this return is fine since it will not transfer control outside the closure + def foo(): Int = { return 5; 1 } + foo() + } + 1 + } + } +} object TestObjectWithNesting { def run(): Int = {