From 201236628a344194f7c20ba8e9afeeaefbe9318c Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Tue, 24 Feb 2015 10:52:18 -0800
Subject: [PATCH] [SPARK-5532][SQL] Repartition should not use external RDD
 representation

Author: Michael Armbrust <michael@databricks.com>

DataFrame.repartition previously round-tripped through .rdd, the
external row representation. For a column backed by a UserDefinedType,
that conversion unpacks each value into its user-facing class, and
createDataFrame does not convert it back to the UDT's sqlType, so
downstream consumers such as the Parquet writer see the wrong
representation. The fix repartitions the internal rows
(queryExecution.toRdd) instead, copying each row first because the
execution engine reuses mutable row objects and a shuffle buffers
references (see the sketch below). count() likewise no longer detours
through .rdd, the debug type checker gains a recursive case for UDTs,
and Parquet round-trip tests cover both code paths.

Closes #4738 from marmbrus/udtRepart and squashes the following commits:

c06d7b5 [Michael Armbrust] fix compilation
91c8829 [Michael Armbrust] [SQL][SPARK-5532] Repartition should not use external rdd representation
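
A minimal sketch of the row-reuse hazard that the copy() guards against
(illustrative only, not part of the patch; reusedRow, wrong, and right
are hypothetical names):

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

    val reusedRow = new GenericMutableRow(1)

    // Buffering the shared instance, as a shuffle would: every element
    // ends up referencing the same object and observes the last value.
    val wrong = (1 to 3).map { i => reusedRow.update(0, i); reusedRow }

    // Snapshotting with copy() before buffering preserves each value.
    val right = (1 to 3).map { i => reusedRow.update(0, i); reusedRow.copy() }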
---
 .../scala/org/apache/spark/sql/DataFrame.scala   |  5 +++--
 .../spark/sql/execution/debug/package.scala      |  1 +
 .../apache/spark/sql/UserDefinedTypeSuite.scala  | 16 +++++++++++++++-
 3 files changed, 19 insertions(+), 3 deletions(-)
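
Note for reviewers (text here, after the diffstat, is ignored by git am):
the debug package change recurses because a UDT value travels through the
engine encoded as its sqlType, not as the user-facing class. A hedged
sketch, assuming a UDT such as this suite's MyDenseVectorUDT, whose
sqlType is ArrayType(DoubleType, containsNull = false):

    // Internally the vector is a Seq[Double], not a MyDenseVector, so
    // the checker must compare against udt.sqlType:
    typeCheck(Seq(0.1, 0.2), udt)          // new UDT case delegates to...
    typeCheck(Seq(0.1, 0.2), udt.sqlType)  // ...the existing ArrayType check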

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 27ac398063..04bf5d9b0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -799,14 +799,15 @@ class DataFrame protected[sql](
    * Returns the number of rows in the [[DataFrame]].
    * @group action
    */
-  override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
+  override def count(): Long = groupBy().count().collect().head.getLong(0)
 
   /**
    * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
    * @group rdd
    */
   override def repartition(numPartitions: Int): DataFrame = {
-    sqlContext.createDataFrame(rdd.repartition(numPartitions), schema)
+    sqlContext.createDataFrame(
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 73162b22fa..ffe388cfa9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -168,6 +168,7 @@ package object debug {
       case (_: Short, ShortType) =>
       case (_: Boolean, BooleanType) =>
       case (_: Double, DoubleType) =>
+      case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType)
 
       case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 5f21d990e2..9c098df24c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.io.File
+
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -91,4 +92,17 @@ class UserDefinedTypeSuite extends QueryTest {
       sql("SELECT testType(features) from points"),
       Seq(Row(true), Row(true)))
   }
+
+
+  test("UDTs with Parquet") {
+    val tempDir = File.createTempFile("parquet", "test")
+    tempDir.delete()
+    pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath)
+  }
+
+  test("Repartition UDTs with Parquet") {
+    val tempDir = File.createTempFile("parquet", "test")
+    tempDir.delete()
+    pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath)
+  }
 }
-- 
GitLab