From 6bb56faea8d238ea22c2de33db93b1b39f492b3a Mon Sep 17 00:00:00 2001
From: Sandy Ryza <sandy@cloudera.com>
Date: Tue, 21 Oct 2014 21:53:09 -0700
Subject: [PATCH] SPARK-1813. Add a utility to SparkConf that makes using Kryo
 really easy

Author: Sandy Ryza <sandy@cloudera.com>

Closes #789 from sryza/sandy-spark-1813 and squashes the following commits:

48b05e9 [Sandy Ryza] Simplify
b824932 [Sandy Ryza] Allow both spark.kryo.classesToRegister and spark.kryo.registrator at the same time
6a15bb7 [Sandy Ryza] Small fix
a2278c0 [Sandy Ryza] Respond to review comments
6ef592e [Sandy Ryza] SPARK-1813. Add a utility to SparkConf that makes using Kryo really easy
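
A minimal usage sketch of the new API (MyClass1 and MyClass2 are placeholder class names):

    val conf = new SparkConf().setAppName("example")
    conf.registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2]))
    val sc = new SparkContext(conf)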
---
 .../scala/org/apache/spark/SparkConf.scala    | 17 ++++-
 .../spark/serializer/KryoSerializer.scala     | 43 ++++++++-----
 .../java/org/apache/spark/JavaAPISuite.java   | 12 ++++
 .../org/apache/spark/SparkConfSuite.scala     | 62 +++++++++++++++++++
 .../serializer/KryoSerializerSuite.scala      |  6 +-
 docs/configuration.md                         | 15 ++++-
 docs/tuning.md                                | 17 +----
 .../spark/examples/bagel/PageRankUtils.scala  | 17 -----
 .../examples/bagel/WikipediaPageRank.scala    |  3 +-
 .../spark/examples/graphx/Analytics.scala     |  6 +-
 .../examples/graphx/SynthBenchmark.scala      |  5 +-
 .../spark/examples/mllib/MovieLensALS.scala   | 12 +---
 .../spark/graphx/GraphKryoRegistrator.scala   |  2 +-
 .../org/apache/spark/graphx/GraphXUtils.scala | 47 ++++++++++++++
 .../spark/graphx/LocalSparkContext.scala      |  6 +-
 .../graphx/impl/EdgePartitionSuite.scala      |  6 +-
 .../graphx/impl/VertexPartitionSuite.scala    |  6 +-
 17 files changed, 195 insertions(+), 87 deletions(-)
 create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala

diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 605df0e929..dbbcc23305 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,7 +18,8 @@
 package org.apache.spark
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, LinkedHashSet}
+import org.apache.spark.serializer.KryoSerializer
 
 /**
  * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
@@ -140,6 +141,21 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     this
   }
 
+  /**
+   * Use Kryo serialization and register the given set of classes with Kryo.
+   * If called multiple times, the classes from all calls are merged, with duplicates ignored.
+   */
+  def registerKryoClasses(classes: Array[Class[_]]): SparkConf = {
+    val allClassNames = new LinkedHashSet[String]()
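+    // LinkedHashSet preserves first-seen order and drops duplicate class names across calls.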
+    allClassNames ++= get("spark.kryo.classesToRegister", "").split(',').filter(!_.isEmpty)
+    allClassNames ++= classes.map(_.getName)
+
+    set("spark.kryo.classesToRegister", allClassNames.mkString(","))
+    set("spark.serializer", classOf[KryoSerializer].getName)
+    this
+  }
+
   /** Remove a parameter from the configuration */
   def remove(key: String): SparkConf = {
     settings.remove(key)
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index d6386f8c06..621a951c27 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -53,7 +53,20 @@ class KryoSerializer(conf: SparkConf)
   private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024
   private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
   private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false)
-  private val registrator = conf.getOption("spark.kryo.registrator")
+  private val userRegistrator = conf.getOption("spark.kryo.registrator")
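+  // Resolve the class names in spark.kryo.classesToRegister eagerly, so a misspelled name
+  // fails fast when the serializer is constructed rather than at first use.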
+  private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
+    .split(',')
+    .filter(!_.isEmpty)
+    .map { className =>
+      try {
+        Class.forName(className)
+      } catch {
+        case e: Exception =>
+          throw new SparkException(s"Failed to load class to register with Kryo: $className", e)
+      }
+    }
 
   def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
 
@@ -80,22 +91,20 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
     kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
 
-    // Allow the user to register their own classes by setting spark.kryo.registrator
-    for (regCls <- registrator) {
-      logDebug("Running user registrator: " + regCls)
-      try {
-        val reg = Class.forName(regCls, true, classLoader).newInstance()
-          .asInstanceOf[KryoRegistrator]
-
-        // Use the default classloader when calling the user registrator.
-        Thread.currentThread.setContextClassLoader(classLoader)
-        reg.registerClasses(kryo)
-      } catch {
-        case e: Exception =>
-          throw new SparkException(s"Failed to invoke $regCls", e)
-      } finally {
-        Thread.currentThread.setContextClassLoader(oldClassLoader)
-      }
+    try {
+      // Use the default classloader when calling the user registrator.
+      Thread.currentThread.setContextClassLoader(classLoader)
+      // Register classes given through spark.kryo.classesToRegister.
+      classesToRegister.foreach { clazz => kryo.register(clazz) }
+      // Allow the user to register their own classes by setting spark.kryo.registrator.
+      userRegistrator
+        .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
+        .foreach { reg => reg.registerClasses(kryo) }
+    } catch {
+      case e: Exception =>
+        throw new SparkException(s"Failed to register classes with Kryo", e)
+    } finally {
+      Thread.currentThread.setContextClassLoader(oldClassLoader)
     }
 
     // Register Chill's classes; we do this after our ranges and the user's own classes to let
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 3190148fb5..814e40c4f7 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1418,4 +1418,17 @@ public class JavaAPISuite implements Serializable {
     }
   }
 
+  static class Class1 {}
+  static class Class2 {}
+
+  @Test
+  public void testRegisterKryoClasses() {
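+    // Verify that the Scala-defined method is callable from Java and records class names in order.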
+    SparkConf conf = new SparkConf();
+    conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class });
+    Assert.assertEquals(
+        Class1.class.getName() + "," + Class2.class.getName(),
+        conf.get("spark.kryo.classesToRegister"));
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 87e9012622..5d018ea986 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark
 
 import org.scalatest.FunSuite
+import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer}
+import com.esotericsoftware.kryo.Kryo
 
 class SparkConfSuite extends FunSuite with LocalSparkContext {
   test("loading from system properties") {
@@ -133,4 +135,64 @@ class SparkConfSuite extends FunSuite with LocalSparkContext {
       System.clearProperty("spark.test.a.b.c")
     }
   }
+
+  test("register kryo classes through registerKryoClasses") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+    conf.registerKryoClasses(Array(classOf[Class1], classOf[Class2]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName)
+
+    conf.registerKryoClasses(Array(classOf[Class3]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+    conf.registerKryoClasses(Array(classOf[Class2]))
+    assert(conf.get("spark.kryo.classesToRegister") ===
+      classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+    val serializer = new KryoSerializer(conf)
+    serializer.newInstance().serialize(new Class1())
+    serializer.newInstance().serialize(new Class2())
+    serializer.newInstance().serialize(new Class3())
+  }
+
+  test("register kryo classes through registerKryoClasses and custom registrator") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+    conf.registerKryoClasses(Array(classOf[Class1]))
+    assert(conf.get("spark.kryo.classesToRegister") === classOf[Class1].getName)
+
+    conf.set("spark.kryo.registrator", classOf[CustomRegistrator].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+    val serializer = new KryoSerializer(conf)
+    serializer.newInstance().serialize(new Class1())
+    serializer.newInstance().serialize(new Class2())
+  }
+
+  test("register kryo classes through conf") {
+    val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+    conf.set("spark.kryo.classesToRegister", "java.lang.StringBuffer")
+    conf.set("spark.serializer", classOf[KryoSerializer].getName)
+
+    // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+    // blow up.
+    val serializer = new KryoSerializer(conf)
+    serializer.newInstance().serialize(new StringBuffer())
+  }
+
+}
+
+class Class1 {}
+class Class2 {}
+class Class3 {}
+
+class CustomRegistrator extends KryoRegistrator {
+  def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[Class2])
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index e1e35b688d..64ac6d2d92 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -210,13 +210,13 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
   }
 
   test("kryo with nonexistent custom registrator should fail") {
-    import org.apache.spark.{SparkConf, SparkException}
+    import org.apache.spark.SparkException
 
     val conf = new SparkConf(false)
     conf.set("spark.kryo.registrator", "this.class.does.not.exist")
-    
+
     val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
-    assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
+    assert(thrown.getMessage.contains("Failed to register classes with Kryo"))
   }
 
   test("default class loader can be set by a different thread") {
diff --git a/docs/configuration.md b/docs/configuration.md
index 96fa1377ec..66738d3ca7 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -124,12 +124,23 @@ of the most common options to set are:
     <code>org.apache.spark.Serializer</code></a>.
   </td>
 </tr>
+<tr>
+  <td><code>spark.kryo.classesToRegister</code></td>
+  <td>(none)</td>
+  <td>
+    If you use Kryo serialization, set this to a comma-separated list of custom class names to
+    register with Kryo.
+    See the <a href="tuning.html#data-serialization">tuning guide</a> for more details.
+  </td>
+</tr>
 <tr>
   <td><code>spark.kryo.registrator</code></td>
   <td>(none)</td>
   <td>
-    If you use Kryo serialization, set this class to register your custom classes with Kryo.
-    It should be set to a class that extends
+    If you use Kryo serialization, set this property to a class that registers your custom classes
+    with Kryo. It is useful when you need to register classes in a custom way, e.g. to specify a
+    custom field serializer; otherwise <code>spark.kryo.classesToRegister</code> is simpler. The
+    class should extend
     <a href="api/scala/index.html#org.apache.spark.serializer.KryoRegistrator">
     <code>KryoRegistrator</code></a>.
     See the <a href="tuning.html#data-serialization">tuning guide</a> for more details.
diff --git a/docs/tuning.md b/docs/tuning.md
index 8fb2a0433b..9b5c9adac6 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -47,24 +47,20 @@ registration requirement, but we recommend trying it in any network-intensive ap
 Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
 in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
 
-To register your own custom classes with Kryo, create a public class that extends
-[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
-`spark.kryo.registrator` config property to point to it, as follows:
+To register your own custom classes with Kryo, use the `registerKryoClasses` method on your `SparkConf`:
 
 {% highlight scala %}
-import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.serializer.KryoRegistrator
-
-class MyRegistrator extends KryoRegistrator {
-  override def registerClasses(kryo: Kryo) {
-    kryo.register(classOf[MyClass1])
-    kryo.register(classOf[MyClass2])
-  }
-}
-
 val conf = new SparkConf().setMaster(...).setAppName(...)
-conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-conf.set("spark.kryo.registrator", "mypackage.MyRegistrator")
+conf.registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2]))
 val sc = new SparkContext(conf)
 {% endhighlight %}
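+
+The same registration can also be expressed purely through configuration properties, which is
+convenient when Kryo is enabled outside the application code (the class names below are
+placeholders):
+
+{% highlight scala %}
+conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+conf.set("spark.kryo.classesToRegister", "mypackage.MyClass1,mypackage.MyClass2")
+{% endhighlight %}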
 
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index e06f4dcd54..e322d4ce5a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -18,17 +18,7 @@
 package org.apache.spark.examples.bagel
 
 import org.apache.spark._
-import org.apache.spark.SparkContext._
-import org.apache.spark.serializer.KryoRegistrator
-
 import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-
-import com.esotericsoftware.kryo._
 
 class PageRankUtils extends Serializable {
   def computeWithCombiner(numVertices: Long, epsilon: Double)(
@@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable {
   }
 }
 
-class PRKryoRegistrator extends KryoRegistrator {
-  def registerClasses(kryo: Kryo) {
-    kryo.register(classOf[PRVertex])
-    kryo.register(classOf[PRMessage])
-  }
-}
-
 class CustomPartitioner(partitions: Int) extends Partitioner {
   def numPartitions = partitions
 
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index e4db3ec513..859abedf2a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -38,8 +38,7 @@ object WikipediaPageRank {
     }
     val sparkConf = new SparkConf()
     sparkConf.setAppName("WikipediaPageRank")
-    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-    sparkConf.set("spark.kryo.registrator",  classOf[PRKryoRegistrator].getName)
+    sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
 
     val inputFile = args(0)
     val threshold = args(1).toDouble
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index 45527d9382..d70d93608a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,10 +46,8 @@ object Analytics extends Logging {
     }
     val options = mutable.Map(optionsList: _*)
 
-    val conf = new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
-      .set("spark.locality.wait", "100000")
+    val conf = new SparkConf().set("spark.locality.wait", "100000")
+    GraphXUtils.registerKryoClasses(conf)
 
     val numEPart = options.remove("numEPart").map(_.toInt).getOrElse {
       println("Set the number of edge partitions using --numEPart.")
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
index 5f35a58364..0567602171 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.examples.graphx
 
 import org.apache.spark.SparkContext._
-import org.apache.spark.graphx.PartitionStrategy
+import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy}
 import org.apache.spark.{SparkContext, SparkConf}
 import org.apache.spark.graphx.util.GraphGenerators
 import java.io.{PrintWriter, FileOutputStream}
@@ -80,8 +80,7 @@ object SynthBenchmark {
 
     val conf = new SparkConf()
       .setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)")
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+    GraphXUtils.registerKryoClasses(conf)
 
     val sc = new SparkContext(conf)
 
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index fc6678013b..8796c28db8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib
 
 import scala.collection.mutable
 
-import com.esotericsoftware.kryo.Kryo
 import org.apache.log4j.{Level, Logger}
 import scopt.OptionParser
 
@@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
 
 /**
  * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
  */
 object MovieLensALS {
 
-  class ALSRegistrator extends KryoRegistrator {
-    override def registerClasses(kryo: Kryo) {
-      kryo.register(classOf[Rating])
-      kryo.register(classOf[mutable.BitSet])
-    }
-  }
-
   case class Params(
       input: String = null,
       kryo: Boolean = false,
@@ -108,8 +99,7 @@ object MovieLensALS {
   def run(params: Params) {
     val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
     if (params.kryo) {
-      conf.set("spark.serializer", classOf[KryoSerializer].getName)
-        .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+      conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating]))
         .set("spark.kryoserializer.buffer.mb", "8")
     }
     val sc = new SparkContext(conf)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index 1948c978c3..563c948957 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._
 import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 import org.apache.spark.util.collection.OpenHashSet
 
-
 /**
  * Registers GraphX classes with Kryo for improved performance.
  */
+@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0")
 class GraphKryoRegistrator extends KryoRegistrator {
 
   def registerClasses(kryo: Kryo) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
new file mode 100644
index 0000000000..2cb07937ea
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkConf
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+
+import org.apache.spark.util.collection.{OpenHashSet, BitSet}
+import org.apache.spark.util.BoundedPriorityQueue
+
+object GraphXUtils {
+  /**
+   * Registers classes that GraphX uses with Kryo.
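+   * Call this on a SparkConf before creating a SparkContext, e.g.
+   * `GraphXUtils.registerKryoClasses(conf)`; it delegates to `SparkConf.registerKryoClasses`.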
+   */
+  def registerKryoClasses(conf: SparkConf) {
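+    // Internal GraphX representations that may be serialized during graph computation.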
+    conf.registerKryoClasses(Array(
+      classOf[Edge[Object]],
+      classOf[(VertexId, Object)],
+      classOf[EdgePartition[Object, Object]],
+      classOf[BitSet],
+      classOf[VertexIdToIndexMap],
+      classOf[VertexAttributeBlock[Object]],
+      classOf[PartitionStrategy],
+      classOf[BoundedPriorityQueue[Object]],
+      classOf[EdgeDirection],
+      classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]],
+      classOf[OpenHashSet[Int]],
+      classOf[OpenHashSet[Long]]))
+  }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
index 47594a800a..a3e28efc75 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark.graphx
 
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterEach
-
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkContext
 
@@ -31,8 +28,7 @@ trait LocalSparkContext {
   /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
   def withSpark[T](f: SparkContext => T) = {
     val conf = new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+    GraphXUtils.registerKryoClasses(conf)
     val sc = new SparkContext("local", "test", conf)
     try {
       f(sc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 9d00f76327..db1dac6160 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -129,9 +129,9 @@ class EdgePartitionSuite extends FunSuite {
     val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
     val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
     val javaSer = new JavaSerializer(new SparkConf())
-    val kryoSer = new KryoSerializer(new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+    val conf = new SparkConf()
+    GraphXUtils.registerKryoClasses(conf)
+    val kryoSer = new KryoSerializer(conf)
 
     for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
       val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index f9e771a900..fe8304c1cd 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite {
     val verts = Set((0L, 1), (1L, 1), (2L, 1))
     val vp = VertexPartition(verts.iterator)
     val javaSer = new JavaSerializer(new SparkConf())
-    val kryoSer = new KryoSerializer(new SparkConf()
-      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+    val conf = new SparkConf()
+    GraphXUtils.registerKryoClasses(conf)
+    val kryoSer = new KryoSerializer(conf)
 
     for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
       val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
-- 
GitLab