Skip to content
Snippets Groups Projects
Commit bbd8f5be authored by Takuya UESHIN's avatar Takuya UESHIN Committed by Michael Armbrust
Browse files

[SPARK-4245][SQL] Fix containsNull of the result ArrayType of CreateArray expression.

The `containsNull` of the result `ArrayType` of `CreateArray` should be `true` only if the children is empty or there exists nullable child.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #3110 from ueshin/issues/SPARK-4245 and squashes the following commits:

6f64746 [Takuya UESHIN] Move equalsIgnoreNullability method into DataType.
5a90e02 [Takuya UESHIN] Refine InsertIntoHiveType and add some comments.
cbecba8 [Takuya UESHIN] Fix a test title.
884ec37 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4245
3c5274b [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table.
41a94a9 [Takuya UESHIN] Replace InsertIntoTable with InsertIntoHiveTable if data types ignoring nullability are same.
43e6ef5 [Takuya UESHIN] Fix containsNull for empty array.
778e997 [Takuya UESHIN] Fix containsNull of the result ArrayType of CreateArray expression.
parent ade72c43
No related branches found
No related tags found
No related merge requests found
......@@ -115,7 +115,9 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def dataType: DataType = {
assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
ArrayType(childTypes.headOption.getOrElse(NullType))
ArrayType(
childTypes.headOption.getOrElse(NullType),
containsNull = children.exists(_.nullable))
}
override def nullable: Boolean = false
......
......@@ -171,6 +171,27 @@ object DataType {
case _ =>
}
}
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
equalsIgnoreNullability(leftValueType, rightValueType)
case (StructType(leftFields), StructType(rightFields)) =>
leftFields.size == rightFields.size &&
leftFields.zip(rightFields)
.forall{
case (left, right) =>
left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
}
case (left, right) => left == right
}
}
}
abstract class DataType {
......
......@@ -286,6 +286,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
if (childOutputDataTypes == tableOutputDataTypes) {
p
} else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
childOutputDataTypes.zip(tableOutputDataTypes)
.forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) {
// If both types ignoring nullability of ArrayType, MapType, StructType are the same,
// use InsertIntoHiveTable instead of InsertIntoTable.
InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite)
} else {
// Only do the casting when child output data types differ from table output data types.
val castedChildOutput = child.output.zip(table.output).map {
......@@ -316,6 +322,27 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
override def unregisterAllTables() = {}
}
/**
* A logical plan representing insertion into Hive table.
* This plan ignores nullability of ArrayType, MapType, StructType unlike InsertIntoTable
* because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types.
*/
private[hive] case class InsertIntoHiveTable(
table: LogicalPlan,
partition: Map[String, Option[String]],
child: LogicalPlan,
overwrite: Boolean)
extends LogicalPlan {
override def children = child :: Nil
override def output = child.output
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
case (childAttr, tableAttr) =>
DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType)
}
}
/**
* :: DeveloperApi ::
* Provides conversions between Spark SQL data types and Hive Metastore types.
......
......@@ -161,7 +161,11 @@ private[hive] trait HiveStrategies {
object DataSinks extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
execution.InsertIntoHiveTable(
table, partition, planLater(child), overwrite)(hiveContext) :: Nil
case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) =>
execution.InsertIntoHiveTable(
table, partition, planLater(child), overwrite)(hiveContext) :: Nil
case logical.CreateTableAsSelect(
Some(database), tableName, child, allowExisting, Some(extra: ASTNode)) =>
CreateTableAsSelect(
......
......@@ -121,4 +121,54 @@ class InsertIntoHiveTableSuite extends QueryTest {
sql("DROP TABLE table_with_partition")
sql("DROP TABLE tmp_table")
}
test("Insert ArrayType.containsNull == false") {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
val schemaRDD = applySchema(rowRDD, schema)
schemaRDD.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
checkAnswer(
sql("SELECT * FROM hiveTableWithArrayValue"),
rowRDD.collect().toSeq)
sql("DROP TABLE hiveTableWithArrayValue")
}
test("Insert MapType.valueContainsNull == false") {
val schema = StructType(Seq(
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
val schemaRDD = applySchema(rowRDD, schema)
schemaRDD.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
checkAnswer(
sql("SELECT * FROM hiveTableWithMapValue"),
rowRDD.collect().toSeq)
sql("DROP TABLE hiveTableWithMapValue")
}
test("Insert StructType.fields.exists(_.nullable == false)") {
val schema = StructType(Seq(
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
val schemaRDD = applySchema(rowRDD, schema)
schemaRDD.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
checkAnswer(
sql("SELECT * FROM hiveTableWithStructValue"),
rowRDD.collect().toSeq)
sql("DROP TABLE hiveTableWithStructValue")
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment