From 424e987dfebbbaa37f4496d44090d469a931ce76 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Thu, 2 Apr 2015 17:57:01 +0800
Subject: [PATCH] [SPARK-6672][SQL] convert row to catalyst in
 createDataFrame(RDD[Row], ...)

We assume that `RDD[Row]` contains Scala types, so we need to convert them into Catalyst types in `createDataFrame`. cc liancheng

Author: Xiangrui Meng <meng@databricks.com>

Closes #5329 from mengxr/SPARK-6672 and squashes the following commits:

2d52644 [Xiangrui Meng] set needsConversion = false in jsonRDD
06896e4 [Xiangrui Meng] add createDataFrame without conversion
4a3767b [Xiangrui Meng] convert Row to catalyst
---
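(Reviewer note, not part of the commit.) A minimal sketch of the failure mode this
patch addresses, using the `ExamplePoint` UDT from `org.apache.spark.sql.test` as in
the new test below; before this change the UDT value was passed through to the
physical plan unconverted and failed at execution time:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}
    import org.apache.spark.sql.types._

    // A Row holding a Scala-typed value (a UDT instance), not a Catalyst value.
    val rowRDD = TestSQLContext.sparkContext.parallelize(
      Seq(Row(new ExamplePoint(1.0, 2.0))))
    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
    // With this patch, createDataFrame converts each Row to Catalyst types first.
    TestSQLContext.createDataFrame(rowRDD, schema).rdd.collect()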
 .../spark/sql/catalyst/ScalaReflection.scala  |  5 +++++
 .../org/apache/spark/sql/DataFrame.scala      |  3 ++-
 .../org/apache/spark/sql/SQLContext.scala     | 20 ++++++++++++++++---
 .../apache/spark/sql/parquet/newParquet.scala |  3 ++-
 .../apache/spark/sql/sources/commands.scala   |  3 ++-
 .../spark/sql/test/ExamplePointUDT.scala      |  2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala |  9 ++++++++-
 7 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 2220970085..8bfd0471d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -72,6 +72,11 @@ trait ScalaReflection {
     case (d: BigDecimal, _) => Decimal(d)
     case (d: java.math.BigDecimal, _) => Decimal(d)
     case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+    case (r: Row, structType: StructType) =>
+      new GenericRow(
+        r.toSeq.zip(structType.fields).map { case (elem, field) =>
+          convertToCatalyst(elem, field.dataType)
+        }.toArray)
     case (other, _) => other
   }
 
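(Reviewer note, not part of the commit.) The new case makes `convertToCatalyst`
recurse into nested rows: each field is zipped with its `DataType` and converted
element-wise. A hedged sketch with a made-up two-level schema:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types._

    val schema = StructType(Seq(
      StructField("amount", DecimalType.Unlimited, nullable = false),
      StructField("inner", StructType(Seq(
        StructField("day", DateType, nullable = false))), nullable = false)))
    // Scala-typed values, nested one level deep.
    val scalaRow = Row(BigDecimal("1.5"), Row(java.sql.Date.valueOf("2015-04-02")))
    // The BigDecimal becomes a Catalyst Decimal, and the java.sql.Date inside the
    // nested Row is converted via DateUtils.fromJavaDate, matching the cases above.
    val catalystRow = ScalaReflection.convertToCatalyst(scalaRow, schema)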
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ce0890906b..34be17325b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -904,7 +904,8 @@ class DataFrame private[sql](
    */
   override def repartition(numPartitions: Int): DataFrame = {
     sqlContext.createDataFrame(
-      queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+      schema, needsConversion = false)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1794936a52..39dd14e796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -392,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+    createDataFrame(rowRDD, schema, needsConversion = true)
+  }
+
+  /**
+   * Creates a DataFrame from an RDD[Row], letting the caller specify whether the input
+   * rows should be converted to Catalyst rows.
+   */
+  private[sql]
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean): DataFrame = {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
+    val catalystRows = if (needsConversion) {
+      rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+    } else {
+      rowRDD
+    }
+    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
     DataFrame(this, logicalPlan)
   }
 
@@ -604,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
         JsonRDD.nullTypeToStringType(
           JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
@@ -633,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
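(Reviewer note, not part of the commit.) The private overload lets internal call
sites skip the per-row conversion when the rows are already in Catalyst form:
`queryExecution.toRdd` (as in `repartition` above and the Parquet/insert paths
below) and `JsonRDD.jsonStringToRow` both produce Catalyst rows, so they pass
`needsConversion = false`. A sketch of the two paths, assuming a `sqlContext`, a
Scala-typed `scalaRowRDD: RDD[Row]`, a Catalyst-typed `catalystRowRDD: RDD[Row]`,
and a `schema: StructType` are in scope (these names are placeholders):

    // Public API: rows hold Scala types, so each Row is mapped through
    // ScalaReflection.convertToCatalyst before building the logical plan.
    val external = sqlContext.createDataFrame(scalaRowRDD, schema)

    // Internal path: rows are already Catalyst values; skip the extra map.
    val internal = sqlContext.createDataFrame(
      catalystRowRDD, schema, needsConversion = false)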
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 43f260d3ef..e12531480c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -122,7 +122,8 @@ private[sql] class DefaultSource
       val df =
         sqlContext.createDataFrame(
           data.queryExecution.toRdd,
-          data.schema.asNullable)
+          data.schema.asNullable,
+          needsConversion = false)
       val createdRelation =
         createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
       createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 9bbe06e59b..dbdb0d39c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = DataFrame(sqlContext, query)
     // Apply the schema of the existing table to the new data.
-    val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sqlContext.createDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
     relation.insert(df, overwrite)
 
     // Invalidate the cache.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index c11d0ae5bf..2fdd798b44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double)
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
 
 /**
  * User-defined type for [[ExamplePoint]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6761d996fd..5297cc01ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,7 +21,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.test.TestSQLContext.sql
@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
     testData.select($"*").show()
     testData.select($"*").show(1000)
   }
+
+  test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
+    val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
+    val df = TestSQLContext.createDataFrame(rowRDD, schema)
+    df.rdd.collect()
+  }
 }
-- 
GitLab