From ed8d1531f93f697c54bbaecefe08c37c32b0d391 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Tue, 17 Nov 2015 19:02:44 -0800
Subject: [PATCH] [SPARK-11793][SQL] Dataset should set the resolved encoders
 internally for maps.

I also wrote a test case, but unfortunately it is not working due to SPARK-11795, so it is added as an ignored test for now.
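
A minimal sketch of the pipeline the ignored test exercises (assuming `ClassData` is a simple case class such as `case class ClassData(a: String, b: Int)` defined in DatasetSuite, and that `import sqlContext.implicits._` is in scope):

    // Chain map -> groupBy -> count on a typed Dataset; map is built on
    // mapPartitions, which is where the resolved encoder now gets used.
    val ds: Dataset[(ClassData, Long)] =
      Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
        .map(c => ClassData(c.a, c.b + 1))  // (one,1) -> (one,2), (two,2) -> (two,3)
        .groupBy(p => p)                    // group by the mapped value
        .count()                            // yields Dataset[(ClassData, Long)]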

Author: Reynold Xin <rxin@databricks.com>

Closes #9784 from rxin/SPARK-11503.
---
 .../src/main/scala/org/apache/spark/sql/Dataset.scala |  3 ++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala     | 11 +++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4cc3aa2465..bd01dd4dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -199,11 +199,12 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+    encoderFor[T].assertUnresolved()
     new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
         func,
-        encoderFor[T],
+        resolvedTEncoder,
         encoderFor[U],
         encoderFor[U].schema.toAttributes,
         logicalPlan))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c23dd46d37..a3922340cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -73,6 +73,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
+  ignore("Dataset should set the resolved encoders internally for maps") {
+    // TODO: Enable this once we fix SPARK-11795.
+    val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
+        .map(c => ClassData(c.a, c.b + 1))
+        .groupBy(p => p).count()
+
+    checkAnswer(
+      ds,
+      (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
+  }
+
   test("select") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
-- 
GitLab