From ed8d1531f93f697c54bbaecefe08c37c32b0d391 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Tue, 17 Nov 2015 19:02:44 -0800
Subject: [PATCH] [SPARK-11793][SQL] Dataset should set the resolved encoders internally for maps.

I also wrote a test case -- but unfortunately the test case is not working due to SPARK-11795.

Author: Reynold Xin <rxin@databricks.com>

Closes #9784 from rxin/SPARK-11503.
---
 .../src/main/scala/org/apache/spark/sql/Dataset.scala |  3 ++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala     | 11 +++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4cc3aa2465..bd01dd4dc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -199,11 +199,12 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+    encoderFor[T].assertUnresolved()
     new Dataset[U](
       sqlContext,
       MapPartitions[T, U](
         func,
-        encoderFor[T],
+        resolvedTEncoder,
         encoderFor[U],
         encoderFor[U].schema.toAttributes,
         logicalPlan))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c23dd46d37..a3922340cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -73,6 +73,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
+  ignore("Dataset should set the resolved encoders internally for maps") {
+    // TODO: Enable this once we fix SPARK-11795.
+    val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
+        .map(c => ClassData(c.a, c.b + 1))
+        .groupBy(p => p).count()
+
+    checkAnswer(
+      ds,
+      (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
+  }
+
   test("select") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
-- 
GitLab