From db8cc6f28abe4326cea6f53feb604920e4867a27 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN <ueshin@happy-camper.st>
Date: Thu, 15 May 2014 11:20:21 -0700
Subject: [PATCH] [SPARK-1845] [SQL] Use AllScalaRegistrar for
 SparkSqlSerializer to register serializers of ...

...Scala collections.

When I execute `orderBy` or `limit` for `SchemaRDD` including `ArrayType` or `MapType`, `SparkSqlSerializer` throws the following exception:

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.$colon$colon
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.Vector
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.HashMap$HashTrieMap
```

and so on.

This is because registrations of serializers for each concrete collections are missing in `SparkSqlSerializer`.
I believe it should use `AllScalaRegistrar`.
`AllScalaRegistrar` covers a lot of serializers for concrete classes of `Seq`, `Map` for `ArrayType`, `MapType`.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #790 from ueshin/issues/SPARK-1845 and squashes the following commits:

d1ed992 [Takuya UESHIN] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.
---
 .../sql/execution/SparkSqlSerializer.scala    | 28 ++---------------
 .../org/apache/spark/sql/DslQuerySuite.scala  | 24 +++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 30 +++++++++++++++++++
 .../scala/org/apache/spark/sql/TestData.scala | 10 +++++++
 4 files changed, 66 insertions(+), 26 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 94c2a249ef..34b355e906 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar
 
 import org.apache.spark.{SparkEnv, SparkConf}
 import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     val kryo = new Kryo()
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
-    kryo.register(classOf[Array[Any]])
-    // This is kinda hacky...
-    kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
-    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
                   new HyperLogLogSerializer)
-    kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
     kryo.setClassLoader(Utils.getSparkClassLoader)
+    new AllScalaRegistrar().apply(kryo)
     kryo
   }
 }
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
     HyperLogLog.Builder.build(bytes)
   }
 }
-
-/**
- * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
- * them as `Array[(k,v)]`.
- */
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
-  def write(kryo: Kryo, output: Output, map: Map[_,_]) {
-    kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
-    kryo.readObject(input, classOf[Array[Any]])
-      .sliding(2,2)
-      .map { case Array(k,v) => (k,v) }
-      .toMap
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 92a707ea57..f43e98d614 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
     checkAnswer(
       testData2.orderBy('a.desc, 'b.asc),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).asc),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).desc),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).asc),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).desc),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
   test("limit") {
     checkAnswer(
       testData.limit(10),
       testData.take(10).toSeq)
+
+    checkAnswer(
+      arrayData.limit(1),
+      arrayData.take(1).toSeq)
+
+    checkAnswer(
+      mapData.limit(1),
+      mapData.take(1).toSeq)
   }
 
   test("average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 524549eb54..189dccd525 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+  }
+
+  test("limit") {
+    checkAnswer(
+      sql("SELECT * FROM testData LIMIT 10"),
+      testData.take(10).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData LIMIT 1"),
+      arrayData.collect().take(1).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData LIMIT 1"),
+      mapData.collect().take(1).toSeq)
   }
 
   test("average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index aa71e274f7..1aca387252 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -74,6 +74,16 @@ object TestData {
       ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
   arrayData.registerAsTable("arrayData")
 
+  case class MapData(data: Map[Int, String])
+  val mapData =
+    TestSQLContext.sparkContext.parallelize(
+      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+      MapData(Map(1 -> "a4", 2 -> "b4")) ::
+      MapData(Map(1 -> "a5")) :: Nil)
+  mapData.registerAsTable("mapData")
+
   case class StringData(s: String)
   val repeatedData =
     TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
-- 
GitLab