Skip to content
Snippets Groups Projects
Commit 2692bdb7 authored by Wenchen Fan's avatar Wenchen Fan Committed by Yin Huai
Browse files

[SPARK-11455][SQL] fix case sensitivity of partition by

Depend on `caseSensitive` to perform the column-name equality check, instead of using plain `==`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9410 from cloud-fan/partition.
parent e352de0d
No related branches found
No related tags found
No related merge requests found
......@@ -287,10 +287,11 @@ private[sql] object PartitioningUtils {
def validatePartitionColumnDataTypes(
schema: StructType,
partitionColumns: Array[String]): Unit = {
partitionColumns: Array[String],
caseSensitive: Boolean): Unit = {
ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns).foreach { field =>
field.dataType match {
ResolvedDataSource.partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach {
field => field.dataType match {
case _: AtomicType => // OK
case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
}
......
......@@ -99,7 +99,8 @@ object ResolvedDataSource extends Logging {
val maybePartitionsSchema = if (partitionColumns.isEmpty) {
None
} else {
Some(partitionColumnsSchema(schema, partitionColumns))
Some(partitionColumnsSchema(
schema, partitionColumns, sqlContext.conf.caseSensitiveAnalysis))
}
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
......@@ -172,14 +173,24 @@ object ResolvedDataSource extends Logging {
/**
 * Resolves each requested partition column against `schema` and returns the resulting
 * partition schema (made nullable, since partition values read back from paths may be null).
 *
 * @param schema the full schema of the data
 * @param partitionColumns the user-specified partition column names
 * @param caseSensitive whether column-name matching honors case (from SQLConf)
 * @throws RuntimeException if a partition column cannot be found in `schema`
 */
def partitionColumnsSchema(
    schema: StructType,
    partitionColumns: Array[String],
    caseSensitive: Boolean): StructType = {
  // Use the session's resolution rule rather than plain `==` so that, e.g.,
  // partitionBy("yEAr") matches a column named "year" when case-insensitive.
  val equality = columnNameEquality(caseSensitive)
  StructType(partitionColumns.map { col =>
    schema.find(f => equality(f.name, col)).getOrElse {
      throw new RuntimeException(s"Partition column $col not found in schema $schema")
    }
  }).asNullable
}
/** Selects the column-name comparison function matching the session's case sensitivity. */
private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
  import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
  if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
}
/** Create a [[ResolvedDataSource]] for saving the content of the given DataFrame. */
def apply(
sqlContext: SQLContext,
......@@ -207,14 +218,18 @@ object ResolvedDataSource extends Logging {
path.makeQualified(fs.getUri, fs.getWorkingDirectory)
}
PartitioningUtils.validatePartitionColumnDataTypes(data.schema, partitionColumns)
val caseSensitive = sqlContext.conf.caseSensitiveAnalysis
PartitioningUtils.validatePartitionColumnDataTypes(
data.schema, partitionColumns, caseSensitive)
val dataSchema = StructType(data.schema.filterNot(f => partitionColumns.contains(f.name)))
val equality = columnNameEquality(caseSensitive)
val dataSchema = StructType(
data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
val r = dataSource.createRelation(
sqlContext,
Array(outputPath.toString),
Some(dataSchema.asNullable),
Some(partitionColumnsSchema(data.schema, partitionColumns)),
Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)),
caseInsensitiveOptions)
// For partitioned relation r, r.schema's column ordering can be different from the column
......
......@@ -140,7 +140,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
PartitioningUtils.validatePartitionColumnDataTypes(r.schema, part.keySet.toArray)
PartitioningUtils.validatePartitionColumnDataTypes(
r.schema, part.keySet.toArray, catalog.conf.caseSensitiveAnalysis)
// Get all input data source relations of the query.
val srcRelations = query.collect {
......@@ -190,7 +191,8 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan =>
// OK
}
PartitioningUtils.validatePartitionColumnDataTypes(query.schema, partitionColumns)
PartitioningUtils.validatePartitionColumnDataTypes(
query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis)
case _ => // OK
}
......
......@@ -1118,4 +1118,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
if (!allSequential) throw new SparkException("Partition should contain all sequential values")
})
}
test("fix case sensitivity of partition by") {
  withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
    withTempPath { dir =>
      val outputPath = dir.getAbsolutePath
      // Partition column spelled with different casing than the schema ("yEAr" vs "year");
      // with case-insensitive analysis the write and subsequent read must both resolve it.
      val df = Seq(2012 -> "a").toDF("year", "val")
      df.write.partitionBy("yEAr").parquet(outputPath)
      checkAnswer(sqlContext.read.parquet(outputPath).select("YeaR"), Row(2012))
    }
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment