diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e30aa0a796924b6916945149943db0d7a6bbb025..cc11c0f35cdc964862cfd1a08c245e026278b827 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -601,6 +601,24 @@ class SQLTests(ReusedPySparkTestCase):
         point = df1.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
+    def test_unionAll_with_udt(self):
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
+        row1 = (1.0, ExamplePoint(1.0, 2.0))
+        row2 = (2.0, ExamplePoint(3.0, 4.0))
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        df1 = self.sqlCtx.createDataFrame([row1], schema)
+        df2 = self.sqlCtx.createDataFrame([row2], schema)
+
+        result = df1.unionAll(df2).orderBy("label").collect()
+        self.assertEqual(
+            result,
+            [
+                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
+                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
+            ]
+        )
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index d7a2c23be8a9a6908286ed185bda1524d48732cd..7664c30ee7650f7cf23bb9b9ac1a8f648e2dab33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -86,6 +86,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
     this.getClass == dataType.getClass
 
   override def sql: String = sqlType.sql
+
+  override def equals(other: Any): Boolean = other match {
+    case that: UserDefinedType[_] => this.acceptsType(that)
+    case _ => false
+  }
 }
 
 /**
@@ -112,4 +117,9 @@ private[sql] class PythonUserDefinedType(
     ("serializedClass" -> serializedPyClass) ~
       ("sqlType" -> sqlType.jsonValue)
   }
+
+  override def equals(other: Any): Boolean = other match {
+    case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
+    case _ => false
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 20a17ba82be9d96ba2f5c411e7ce1cb6f1db0ebe..e2c9fc421b83f671838c72729364a6af574281a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
  * @param y y coordinate
  */
 @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
-private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
+private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+  override def equals(other: Any): Boolean = other match {
+    case that: ExamplePoint => this.x == that.x && this.y == that.y
+    case _ => false
+  }
+}
 
 /**
  * User-defined type for [[ExamplePoint]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 50a246489ee5535ae8a54aeef4a2d4b79a46781d..4930c485da83fd6a4e55495b7c874682f48b2916 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -112,6 +112,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
+    val rowRDD1 = sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0))))
+    val schema1 = StructType(Array(StructField("label", IntegerType, false),
+      StructField("point", new ExamplePointUDT(), false)))
+    val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0))))
+    val schema2 = StructType(Array(StructField("label", IntegerType, false),
+      StructField("point", new ExamplePointUDT(), false)))
+    val df1 = sqlContext.createDataFrame(rowRDD1, schema1)
+    val df2 = sqlContext.createDataFrame(rowRDD2, schema2)
+
+    checkAnswer(
+      df1.unionAll(df2).orderBy("label"),
+      Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0)))
+    )
+  }
+
   test("empty data frame") {
     assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
     assert(sqlContext.emptyDataFrame.count() === 0)