From 6ae2746d1e6d73c2628a67afbbed828d6efce5c4 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Wed, 6 Jun 2012 16:13:02 -0700
Subject: [PATCH] Handle arrays that contain the same element many times better
 in SizeEstimator. Also added a test for SizeEstimator. Fixes #136.

---
 core/src/main/scala/spark/SizeEstimator.scala | 38 +++++----
 .../test/scala/spark/SizeEstimatorSuite.scala | 77 +++++++++++++++++++
 2 files changed, 102 insertions(+), 13 deletions(-)
 create mode 100644 core/src/test/scala/spark/SizeEstimatorSuite.scala

diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index 4b89503e84..b3bd4daa73 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -9,6 +9,8 @@ import java.util.Random
 
 import scala.collection.mutable.ArrayBuffer
 
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet
+
 /**
  * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in 
  * memory-aware caches.
@@ -39,8 +41,7 @@ object SizeEstimator {
    * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects
    * to visit.
    */
-  private class SearchState {
-    val visited = new IdentityHashMap[AnyRef, AnyRef]
+  private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) {
     val stack = new ArrayBuffer[AnyRef]
     var size = 0L
 
@@ -61,16 +62,18 @@ object SizeEstimator {
   }
 
   /**
-   * Cached information about each class. We remember two things: the
-   * "shell size" of the class (size of all non-static fields plus the
-   * java.lang.Object size), and any fields that are pointers to objects.
+   * Cached information about each class. We remember two things: the "shell size" of the class
+   * (size of all non-static fields plus the java.lang.Object size), and any fields that are
+   * pointers to objects.
    */
   private class ClassInfo(
     val shellSize: Long,
     val pointerFields: List[Field]) {}
 
-  def estimate(obj: AnyRef): Long = {
-    val state = new SearchState
+  def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef])
+
+  private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
+    val state = new SearchState(visited)
     state.enqueue(obj)
     while (!state.isFinished) {
       visitSingleObject(state.dequeue(), state)
@@ -91,6 +94,10 @@ object SizeEstimator {
     }
   }
 
+  // Estimat the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling.
+  private val ARRAY_SIZE_FOR_SAMPLING = 200
+  private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
+
   private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
     val length = JArray.getLength(array)
     val elementClass = cls.getComponentType
@@ -98,18 +105,23 @@ object SizeEstimator {
       state.size += length * primitiveSize(elementClass)
     } else {
       state.size += length * POINTER_SIZE
-      if (length <= 100) {
+      if (length <= ARRAY_SIZE_FOR_SAMPLING) {
         for (i <- 0 until length) {
           state.enqueue(JArray.get(array, i))
         }
       } else {
-        // Estimate the size of a large array by sampling elements.
-        // TODO: Add a config setting for turning this off?
+        // Estimate the size of a large array by sampling elements without replacement.
         var size = 0.0
         val rand = new Random(42)
-        for (i <- 0 until 100) {
-          val elem = JArray.get(array, rand.nextInt(length))
-          size += SizeEstimator.estimate(elem)
+        val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE)
+        for (i <- 0 until ARRAY_SAMPLE_SIZE) {
+          var index = 0
+          do {
+            index = rand.nextInt(length)
+          } while (drawn.contains(index))
+          drawn.add(index)
+          val elem = JArray.get(array, index)
+          size += SizeEstimator.estimate(elem, state.visited)
         }
         state.size += ((length / 100.0) * size).toLong
       }
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
new file mode 100644
index 0000000000..63bc951858
--- /dev/null
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -0,0 +1,77 @@
+package spark
+
+import org.scalatest.FunSuite
+
+class DummyClass1 {}
+
+class DummyClass2 {
+  val x: Int = 0
+}
+
+class DummyClass3 {
+  val x: Int = 0
+  val y: Double = 0.0
+}
+
+class DummyClass4(val d: DummyClass3) {
+  val x: Int = 0
+}
+
+class SizeEstimatorSuite extends FunSuite {
+  test("simple classes") {
+    expect(8)(SizeEstimator.estimate(new DummyClass1))
+    expect(12)(SizeEstimator.estimate(new DummyClass2))
+    expect(20)(SizeEstimator.estimate(new DummyClass3))
+    expect(16)(SizeEstimator.estimate(new DummyClass4(null)))
+    expect(36)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
+  }
+
+  test("strings") {
+    expect(24)(SizeEstimator.estimate(""))
+    expect(26)(SizeEstimator.estimate("a"))
+    expect(28)(SizeEstimator.estimate("ab"))
+    expect(40)(SizeEstimator.estimate("abcdefgh"))
+  }
+
+  test("primitive arrays") {
+    expect(10)(SizeEstimator.estimate(new Array[Byte](10)))
+    expect(20)(SizeEstimator.estimate(new Array[Char](10)))
+    expect(20)(SizeEstimator.estimate(new Array[Short](10)))
+    expect(40)(SizeEstimator.estimate(new Array[Int](10)))
+    expect(80)(SizeEstimator.estimate(new Array[Long](10)))
+    expect(40)(SizeEstimator.estimate(new Array[Float](10)))
+    expect(80)(SizeEstimator.estimate(new Array[Double](10)))
+    expect(4000)(SizeEstimator.estimate(new Array[Int](1000)))
+    expect(8000)(SizeEstimator.estimate(new Array[Long](1000)))
+  }
+
+  test("object arrays") {
+    // Arrays containing nulls should just have one pointer per element
+    expect(40)(SizeEstimator.estimate(new Array[String](10)))
+    expect(40)(SizeEstimator.estimate(new Array[AnyRef](10)))
+
+    // For object arrays with non-null elements, each object should take one pointer plus
+    // however many bytes that class takes. (Note that Array.fill calls the code in its
+    // second parameter separately for each object, so we get distinct objects.)
+    expect(120)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1))) 
+    expect(160)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2))) 
+    expect(240)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3))) 
+    expect(12 + 16)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
+
+    // Past size 100, our samples 100 elements, but we should still get the right size.
+    expect(24000)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
+
+    // If an array contains the *same* element many times, we should only count it once.
+    val d1 = new DummyClass1
+    expect(48)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
+    expect(408)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
+
+    // Same thing with huge array containing the same element many times. Note that this won't
+    // return exactly 4008 because it can't tell that *all* the elements will equal the first
+    // one it samples, but it should be close to that.
+    val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1))
+    assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000")
+    assert(estimatedSize <= 4100, "Estimated size " + estimatedSize + " should be less than 4100")
+  }
+}
+
-- 
GitLab