Skip to content
Snippets Groups Projects
Commit 8b75f8c1 authored by Wenchen Fan's avatar Wenchen Fan
Browse files

[SPARK-19587][SQL] bucket sorting columns should not be picked from partition columns

## What changes were proposed in this pull request?

We will throw an exception if bucket columns are part of partition columns, this should also apply to sort columns.

This PR also move the checking logic from `DataFrameWriter` to `PreprocessTableCreation`, which is the central place for checking and normailization.

## How was this patch tested?

updated test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16931 from cloud-fan/bucket.
parent 733c59ec
No related branches found
No related tags found
No related merge requests found
......@@ -215,7 +215,6 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec,
options = extraOptions.toMap)
dataSource.write(mode, df)
......@@ -270,52 +269,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
ifNotExists = false)).toRdd
}
private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols =>
cols.map(normalize(_, "Partition"))
}
private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols =>
cols.map(normalize(_, "Bucketing"))
}
private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols =>
cols.map(normalize(_, "Sorting"))
}
private def getBucketSpec: Option[BucketSpec] = {
if (sortColumnNames.isDefined) {
require(numBuckets.isDefined, "sortBy must be used together with bucketBy")
}
for {
n <- numBuckets
} yield {
numBuckets.map { n =>
require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.")
// partitionBy columns cannot be used in bucketBy
if (normalizedParCols.nonEmpty &&
normalizedBucketColNames.get.toSet.intersect(normalizedParCols.get.toSet).nonEmpty) {
throw new AnalysisException(
s"bucketBy columns '${bucketColumnNames.get.mkString(", ")}' should not be part of " +
s"partitionBy columns '${partitioningColumns.get.mkString(", ")}'")
}
BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil))
BucketSpec(n, bucketColumnNames.get, sortColumnNames.getOrElse(Nil))
}
}
/**
* The given column name may not be equal to any of the existing column names if we were in
* case-insensitive context. Normalize the given column name to the real one so that we don't
* need to care about case sensitivity afterwards.
*/
private def normalize(columnName: String, columnType: String): String = {
val validColumnNames = df.logicalPlan.output.map(_.name)
validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName))
.getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " +
s"existing columns (${validColumnNames.mkString(", ")})"))
}
private def assertNotBucketed(operation: String): Unit = {
if (numBuckets.isDefined || sortColumnNames.isDefined) {
throw new AnalysisException(s"'$operation' does not support bucketing right now")
......
......@@ -226,9 +226,21 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
}
checkDuplication(columnNames, "table definition of " + table.identifier)
table.copy(
partitionColumnNames = normalizePartitionColumns(schema, table),
bucketSpec = normalizeBucketSpec(schema, table))
val normalizedPartCols = normalizePartitionColumns(schema, table)
val normalizedBucketSpec = normalizeBucketSpec(schema, table)
normalizedBucketSpec.foreach { spec =>
for (bucketCol <- spec.bucketColumnNames if normalizedPartCols.contains(bucketCol)) {
throw new AnalysisException(s"bucketing column '$bucketCol' should not be part of " +
s"partition columns '${normalizedPartCols.mkString(", ")}'")
}
for (sortCol <- spec.sortColumnNames if normalizedPartCols.contains(sortCol)) {
throw new AnalysisException(s"bucket sorting column '$sortCol' should not be part of " +
s"partition columns '${normalizedPartCols.mkString(", ")}'")
}
}
table.copy(partitionColumnNames = normalizedPartCols, bucketSpec = normalizedBucketSpec)
}
private def normalizePartitionColumns(schema: StructType, table: CatalogTable): Seq[String] = {
......
......@@ -169,19 +169,20 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
}
}
test("write bucketed data with the overlapping bucketBy and partitionBy columns") {
intercept[AnalysisException](df.write
test("write bucketed data with the overlapping bucketBy/sortBy and partitionBy columns") {
val e1 = intercept[AnalysisException](df.write
.partitionBy("i", "j")
.bucketBy(8, "j", "k")
.sortBy("k")
.saveAsTable("bucketed_table"))
}
assert(e1.message.contains("bucketing column 'j' should not be part of partition columns"))
test("write bucketed data with the identical bucketBy and partitionBy columns") {
intercept[AnalysisException](df.write
.partitionBy("i")
.bucketBy(8, "i")
val e2 = intercept[AnalysisException](df.write
.partitionBy("i", "j")
.bucketBy(8, "k")
.sortBy("i")
.saveAsTable("bucketed_table"))
assert(e2.message.contains("bucket sorting column 'i' should not be part of partition columns"))
}
test("write bucketed data without partitionBy") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment