Commit 40419022 authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin

[SPARK-12882][SQL] simplify bucket tests and add more comments

Right now, the bucket tests are kind of hard to understand. This PR simplifies them and adds more comments.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10813 from cloud-fan/bucket-comment.
parent 4f11e3f2
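For context, both suites drive bucketing through the DataFrameWriter API. A minimal sketch of the calls these tests exercise (the table name and the use of `df` here are illustrative, not lines from this commit):

    // Write `df` as a bucketed (and optionally sorted) table; rows are
    // hash-partitioned into 8 buckets by the values of columns i and j.
    df.write
      .format("parquet")
      .bucketBy(8, "i", "j")
      .sortBy("k")  // optional: sort rows within each bucket by k
      .saveAsTable("bucketed_table")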
BucketedReadSuite.scala

@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLC
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.execution.joins.SortMergeJoin
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -61,15 +62,30 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
   private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
 
+  /**
+   * A helper method to test the bucket read functionality using join. It saves `df1` and `df2`
+   * to hive tables, bucketed or not, according to the given bucket specs. It then joins the two
+   * tables, first making sure the answer is correct, and then checking that a shuffle exists on
+   * each side exactly when the caller expects one, per `shuffleLeft` and `shuffleRight`.
+   */
   private def testBucketing(
-      bucketing1: DataFrameWriter => DataFrameWriter,
-      bucketing2: DataFrameWriter => DataFrameWriter,
+      bucketSpecLeft: Option[BucketSpec],
+      bucketSpecRight: Option[BucketSpec],
       joinColumns: Seq[String],
       shuffleLeft: Boolean,
       shuffleRight: Boolean): Unit = {
     withTable("bucketed_table1", "bucketed_table2") {
-      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
-      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+      def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = {
+        bucketSpec.map { spec =>
+          writer.bucketBy(
+            spec.numBuckets,
+            spec.bucketColumnNames.head,
+            spec.bucketColumnNames.tail: _*)
+        }.getOrElse(writer)
+      }
+
+      withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
+      withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
 
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
         val t1 = hiveContext.table("bucketed_table1")
@@ -95,42 +111,42 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     }
   }
 
   test("avoid shuffle when join 2 bucketed tables") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   // Enable it after fixing https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
   }
 
   test("only shuffle one side when join bucketed table and non-bucketed table") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket number") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
-    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
+    val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
-    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
-    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
+    val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
+    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
   }
 
   test("shuffle when join keys are not equal to bucket keys") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
-    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
+    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
   }
 
   test("shuffle when join 2 bucketed tables with bucketing disabled") {
-    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
     }
   }
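The assertion body of the read-suite helper is collapsed out of the hunk above. A hedged sketch of how such a shuffle check is typically written against this plan API (`t2` and the exact expressions are assumptions, not lines from this commit): find the SortMergeJoin in the executed plan, then look for an Exchange operator on each side.

    // Hypothetical sketch of the collapsed assertion logic:
    val t2 = hiveContext.table("bucketed_table2")
    val joined = t1.join(t2, joinColumns.map(c => t1(c) === t2(c)).reduce(_ && _))
    val smj = joined.queryExecution.executedPlan.collect {
      case j: SortMergeJoin => j
    }.head
    // A shuffle on a side shows up as an Exchange operator in that subtree.
    assert(smj.left.collect { case e: Exchange => e }.nonEmpty == shuffleLeft)
    assert(smj.right.collect { case e: Exchange => e }.nonEmpty == shuffleRight)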
BucketedWriteSuite.scala

@@ -65,39 +65,55 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
+  /**
+   * A helper method to check the bucket write functionality at a low level, i.e. check the
+   * written bucket files to see if the data are correct. The caller should pass in the data dir
+   * that the bucket files are written to, the format of the data (parquet, json, etc.), and the
+   * bucketing information.
+   */
   private def testBucketing(
       dataDir: File,
       source: String,
+      numBuckets: Int,
       bucketCols: Seq[String],
       sortCols: Seq[String] = Nil): Unit = {
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
-    assert(groupedBucketFiles.size <= 8)
 
-    for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
-        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
-        val columns = (bucketCols ++ sortCols).zip(types).map {
-          case (colName, dt) => col(colName).cast(dt)
-        }
-        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
+    for (bucketFile <- allBucketFiles) {
+      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).get
+      assert(bucketId >= 0 && bucketId < numBuckets)
 
-        if (sortCols.nonEmpty) {
-          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
-        }
+      // We may lose the type information after write (e.g. the json format doesn't keep schema
+      // information), so here we get the types from the original dataframe.
+      val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+      val columns = (bucketCols ++ sortCols).zip(types).map {
+        case (colName, dt) => col(colName).cast(dt)
+      }
 
-        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
-        val rows = qe.toRdd.map(_.copy()).collect()
-        val getBucketId = UnsafeProjection.create(
-          HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
-          qe.analyzed.output)
+      // Read the bucket file into a dataframe, so that it's easier to test.
+      val readBack = sqlContext.read.format(source)
+        .load(bucketFile.getAbsolutePath)
+        .select(columns: _*)
 
-        for (row <- rows) {
-          val actualBucketId = getBucketId(row).getInt(0)
-          assert(actualBucketId == bucketId)
-        }
-      }
-    }
+      // If we specified sort columns while writing the bucket table, make sure the data in this
+      // bucket file is already sorted.
+      if (sortCols.nonEmpty) {
+        checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
+      }
+
+      // Go through all rows in this bucket file, calculate the bucket id from the bucket column
+      // values, and make sure it equals the expected bucket id inferred from the file name.
+      val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+      val rows = qe.toRdd.map(_.copy()).collect()
+      val getBucketId = UnsafeProjection.create(
+        HashPartitioning(qe.analyzed.output, numBuckets).partitionIdExpression :: Nil,
+        qe.analyzed.output)
+
+      for (row <- rows) {
+        val actualBucketId = getBucketId(row).getInt(0)
+        assert(actualBucketId == bucketId)
+      }
+    }
   }
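For context, `HashPartitioning.partitionIdExpression` boils down to a non-negative modulo of the row's hash over the number of partitions, so the check above replays, per row, roughly the following (a conceptual sketch, not the literal expression tree Spark builds):

    // Conceptual bucket-id computation: hash the bucket column values, then
    // take a non-negative modulo so the id always falls in [0, numBuckets).
    def expectedBucketId(hashOfBucketCols: Int, numBuckets: Int): Int = {
      val mod = hashOfBucketCols % numBuckets
      if (mod < 0) mod + numBuckets else mod
    }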
@@ -113,7 +129,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
       for (i <- 0 until 5) {
-        testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k"))
+        testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j", "k"))
       }
     }
   }
@@ -131,7 +147,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
       for (i <- 0 until 5) {
-        testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k"))
+        testBucketing(new File(tableDir, s"i=$i"), source, 8, Seq("j"), Seq("k"))
       }
     }
   }
@@ -146,7 +162,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         .saveAsTable("bucketed_table")
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-      testBucketing(tableDir, source, Seq("i", "j"))
+      testBucketing(tableDir, source, 8, Seq("i", "j"))
     }
   }
@@ -161,7 +177,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
         .saveAsTable("bucketed_table")
       val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
-      testBucketing(tableDir, source, Seq("i", "j"), Seq("k"))
+      testBucketing(tableDir, source, 8, Seq("i", "j"), Seq("k"))
     }
   }
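As a side note, `BucketingUtils.getBucketId` recovers the bucket id that the writer embedded in each output file name. The exact naming pattern is not shown in this diff; an illustrative sketch of that kind of parsing (the regex and the file-name shape below are assumptions, not the library's actual pattern):

    // Illustrative only: suppose bucket files end in an underscore-separated,
    // zero-padded bucket id before the extension, e.g. part-r-00000_00003.parquet
    // would map to bucket 3.
    val bucketFilePattern = """.*_(\d+)\.\w+$""".r
    def getBucketIdSketch(fileName: String): Option[Int] = fileName match {
      case bucketFilePattern(id) => Some(id.toInt)
      case _ => None
    }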