Skip to content
Snippets Groups Projects
Commit 9e4c79a4 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Closure cleaner unit test

parent f346e646
No related branches found
No related tags found
No related merge requests found
package spark
import java.io.NotSerializableException
import org.scalatest.FunSuite
import SparkContext._
class ClosureCleanerSuite extends FunSuite {
test("closures inside an object") {
assert(TestObject.run() === 30) // 6 + 7 + 8 + 9
}
test("closures inside a class") {
val obj = new TestClass
assert(obj.run() === 30) // 6 + 7 + 8 + 9
}
test("closures inside a class with no default constructor") {
val obj = new TestClassWithoutDefaultConstructor(5)
assert(obj.run() === 30) // 6 + 7 + 8 + 9
}
test("closures that don't use fields of the outer class") {
val obj = new TestClassWithoutFieldAccess
assert(obj.run() === 30) // 6 + 7 + 8 + 9
}
test("nested closures inside an object") {
assert(TestObjectWithNesting.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}
test("nested closures inside a class") {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}
}
// A non-serializable class we create in closures to make sure that we aren't
// keeping references to unneeded variables from our outer closures.
class NonSerializable {}
object TestObject {
def run(): Int = {
var nonSer = new NonSerializable
var x = 5
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + x).reduce(_ + _)
sc.stop()
return answer
}
}
class TestClass extends Serializable {
var x = 5
def getX = x
def run(): Int = {
var nonSer = new NonSerializable
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + getX).reduce(_ + _)
sc.stop()
return answer
}
}
class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def getX = x
def run(): Int = {
var nonSer = new NonSerializable
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + getX).reduce(_ + _)
sc.stop()
return answer
}
}
// This class is not serializable, but we aren't using any of its fields in our
// closures, so they won't have a $outer pointing to it and should still work.
class TestClassWithoutFieldAccess {
var nonSer = new NonSerializable
def run(): Int = {
var nonSer2 = new NonSerializable
var x = 5
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + x).reduce(_ + _)
sc.stop()
return answer
}
}
object TestObjectWithNesting {
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
var y = 1
for (i <- 1 to 4) {
var nonSer2 = new NonSerializable
var x = i
answer += nums.map(_ + x + y).reduce(_ + _)
}
sc.stop()
return answer
}
}
class TestClassWithNesting(val y: Int) extends Serializable {
def getY = y
def run(): Int = {
var nonSer = new NonSerializable
var answer = 0
val sc = new SparkContext("local", "test")
val nums = sc.parallelize(Array(1, 2, 3, 4))
for (i <- 1 to 4) {
var nonSer2 = new NonSerializable
var x = i
answer += nums.map(_ + x + getY).reduce(_ + _)
}
sc.stop()
return answer
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment