diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 222097008546277adb91741319ffef011238cb91..8bfd0471d9c7a6d32693ba6829911ace9b968f8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -72,6 +72,11 @@ trait ScalaReflection {
     case (d: BigDecimal, _) => Decimal(d)
     case (d: java.math.BigDecimal, _) => Decimal(d)
     case (d: java.sql.Date, _) => DateUtils.fromJavaDate(d)
+    case (r: Row, structType: StructType) =>
+      new GenericRow(
+        r.toSeq.zip(structType.fields).map { case (elem, field) =>
+          convertToCatalyst(elem, field.dataType)
+        }.toArray)
     case (other, _) => other
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ce0890906bf1b2e06e44936c861bc508ab14a871..34be17325b2b04d65e2a078c1cbf65083b829218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -904,7 +904,8 @@ class DataFrame private[sql](
    */
   override def repartition(numPartitions: Int): DataFrame = {
     sqlContext.createDataFrame(
-      queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+      queryExecution.toRdd.map(_.copy()).repartition(numPartitions),
+      schema, needsConversion = false)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1794936a52c6d4129568bd901c5575418051e43d..39dd14e796f06910b6758c3b6b145d29783ff529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -392,9 +392,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+    createDataFrame(rowRDD, schema, needsConversion = true)
+  }
+
+  /**
+   * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be
+   * converted to Catalyst rows.
+   */
+  private[sql]
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
-    val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
+    val catalystRows = if (needsConversion) {
+      rowRDD.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])
+    } else {
+      rowRDD
+    }
+    val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
     DataFrame(this, logicalPlan)
   }
 
@@ -604,7 +618,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
@@ -633,7 +647,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       JsonRDD.nullTypeToStringType(
         JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
     val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
-    createDataFrame(rowRDD, appliedSchema)
+    createDataFrame(rowRDD, appliedSchema, needsConversion = false)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 43f260d3ef8d3bbcdfc2977f89eae26600289ddb..e12531480ce92ce928dfcdc80172f95d1f429d2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -122,7 +122,8 @@ private[sql] class DefaultSource
         val df =
           sqlContext.createDataFrame(
             data.queryExecution.toRdd,
-            data.schema.asNullable)
+            data.schema.asNullable,
+            needsConversion = false)
         val createdRelation =
           createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
         createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index 9bbe06e59ba3074f01e3a78cef354a05d92d50f8..dbdb0d39c26a1543d87f8358862679125d2be3ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -31,7 +31,8 @@ private[sql] case class InsertIntoDataSource(
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
     val data = DataFrame(sqlContext, query)
     // Apply the schema of the existing table to the new data.
-    val df = sqlContext.createDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
+    val df = sqlContext.createDataFrame(
+      data.queryExecution.toRdd, logicalRelation.schema, needsConversion = false)
     relation.insert(df, overwrite)
 
     // Invalidate the cache.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index c11d0ae5bf1cca6771647a65687420b611163a93..2fdd798b44bb618c8b777188eec3ea5213fca057 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double)
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
 
 /**
  * User-defined type for [[ExamplePoint]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6761d996fd9754cca5a2a0bc2302126a72b04a8c..5297cc01eddfc4724650e083ea9eb50ca565b520 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -21,7 +21,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
 import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.test.TestSQLContext.sql
@@ -506,4 +506,11 @@ class DataFrameSuite extends QueryTest {
     testData.select($"*").show()
     testData.select($"*").show(1000)
   }
+
+  test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
+    val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+    val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
+    val df = TestSQLContext.createDataFrame(rowRDD, schema)
+    df.rdd.collect()
+  }
 }
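For context, a minimal standalone sketch of the scenario this patch fixes, mirroring the regression test added above. It is not part of the patch: the `Spark6672Sketch` object and the `main`/`println` wrapper are illustrative assumptions, and the `package org.apache.spark.sql` declaration is only there because `ExamplePoint` and `ExamplePointUDT` are `private[sql]` test helpers.

package org.apache.spark.sql

import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext}
import org.apache.spark.sql.types.{StructField, StructType}

// Before this patch, the two-argument createDataFrame passed external rows
// straight to LogicalRDD, so the raw ExamplePoint below was never run through
// ExamplePointUDT.serialize and collect() failed at runtime. After the patch,
// the two-argument overload delegates with needsConversion = true, mapping
// each row through ScalaReflection.convertToCatalyst against the schema.
object Spark6672Sketch {
  def main(args: Array[String]): Unit = {
    val rowRDD = TestSQLContext.sparkContext.parallelize(
      Seq(Row(new ExamplePoint(1.0, 2.0))))
    val schema = StructType(Array(
      StructField("point", new ExamplePointUDT(), nullable = false)))
    val df = TestSQLContext.createDataFrame(rowRDD, schema)
    df.rdd.collect().foreach(println) // succeeds once rows are converted to Catalyst form
  }
}

Note the complementary changes in the patch: internal call sites whose rows are already in Catalyst form (JSON, Parquet, InsertIntoDataSource, repartition) opt out with needsConversion = false to avoid a redundant pass, and ExamplePoint gains extends Serializable because the converted rows are shipped to executors and collected.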