From 3b5ccb12b8d33d99df0f206fecf00f51c2b88fdb Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Fri, 15 Jan 2016 17:20:01 -0800
Subject: [PATCH] [SPARK-12649][SQL] support reading bucketed table

This PR adds support for reading bucketed tables and correctly populates `outputPartitioning`, so that we can avoid a shuffle in some cases.
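
For illustration, here is a minimal sketch of the intended end-to-end usage, assuming a HiveContext-backed `sqlContext` as in the new `BucketedReadSuite` (table and column names are illustrative, not part of this patch):

```scala
// Write two tables bucketed by the same columns into the same number of buckets.
val df = sqlContext.range(0, 100)
  .selectExpr("id % 5 as i", "id % 13 as j", "cast(id as string) as k")
df.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t1")
df.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("t2")

// With broadcast joins disabled, reading the tables back exposes HashPartitioning over the
// bucket columns as the outputPartitioning, so the sort-merge join below needs no Exchange
// on either side.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "0")
val t1 = sqlContext.table("t1")
val t2 = sqlContext.table("t2")
t1.join(t2, t1("i") === t2("i") && t1("j") === t2("j")).explain()
```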

TODO(follow-up PRs):

* bucket pruning
* avoid shuffle for bucketed table joins when the join keys are any super-set of the bucketing keys
 (we should re-visit this after https://issues.apache.org/jira/browse/SPARK-12704 is fixed)
* recognize Hive bucketed tables
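
For reference, the reader recovers the bucket id from the file names produced by the bucketed write path, and falls back to a normal scan if any file name lacks a bucket id or if `spark.sql.sources.bucketing.enabled` is set to `false`. A minimal sketch of the extraction logic, mirroring the `BucketingUtils` helper added in this patch:

```scala
// A bucketed file name ends with "-<bucket id>", optionally followed by a file extension.
val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r

def getBucketId(fileName: String): Option[Int] = fileName match {
  case bucketedFileName(bucketId) => Some(bucketId.toInt)
  case _ => None
}

getBucketId("part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet") // Some(3)
getBucketId("part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.gz.parquet")       // None -> non-bucketed scan
```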

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10604 from cloud-fan/bucket-read.
---
 .../apache/spark/sql/DataFrameReader.scala    |   1 +
 .../scala/org/apache/spark/sql/SQLConf.scala  |   6 +
 .../spark/sql/execution/ExistingRDD.scala     |  28 ++-
 .../InsertIntoHadoopFsRelation.scala          |   2 +-
 .../datasources/ResolvedDataSource.scala      |   4 +-
 .../datasources/WriterContainer.scala         |   2 +-
 .../sql/execution/datasources/bucket.scala    |  21 ++-
 .../spark/sql/execution/datasources/ddl.scala |   2 +-
 .../datasources/json/JSONRelation.scala       |   4 +-
 .../datasources/parquet/ParquetRelation.scala |   2 +-
 .../sql/execution/datasources/rules.scala     |   1 +
 .../apache/spark/sql/sources/interfaces.scala |  55 +++++-
 .../datasources/json/JsonSuite.scala          |   2 +
 .../spark/sql/hive/HiveMetastoreCatalog.scala |  26 +--
 .../spark/sql/hive/execution/commands.scala   |   7 +-
 .../spark/sql/hive/orc/OrcRelation.scala      |   2 +-
 .../spark/sql/sources/BucketedReadSuite.scala | 178 ++++++++++++++++++
 .../sql/sources/BucketedWriteSuite.scala      |  16 +-
 18 files changed, 314 insertions(+), 45 deletions(-)
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 8f852e5216..634c1bd473 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -109,6 +109,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
       sqlContext,
       userSpecifiedSchema = userSpecifiedSchema,
       partitionColumns = Array.empty[String],
+      bucketSpec = None,
       provider = source,
       options = extraOptions.toMap)
     DataFrame(sqlContext, LogicalRelation(resolved.relation))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 7976795ff5..4e3662724c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -422,6 +422,10 @@ private[spark] object SQLConf {
       doc = "The maximum number of concurrent files to open before falling back on sorting when " +
             "writing out files using dynamic partitioning.")
 
+  val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled",
+    defaultValue = Some(true),
+    doc = "When false, we will treat bucketed table as normal table")
+
   // The output committer class used by HadoopFsRelation. The specified class needs to be a
   // subclass of org.apache.hadoop.mapreduce.OutputCommitter.
   //
@@ -590,6 +594,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with ParserCon
   private[spark] def parallelPartitionDiscoveryThreshold: Int =
     getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD)
 
+  private[spark] def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED)
+
   // Do not use a value larger than 4000 as the default value of this property.
   // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
   private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 569a21feaa..92cfd5f841 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
 import org.apache.spark.sql.types.DataType
 
@@ -98,7 +99,8 @@ private[sql] case class PhysicalRDD(
     rdd: RDD[InternalRow],
     override val nodeName: String,
     override val metadata: Map[String, String] = Map.empty,
-    isUnsafeRow: Boolean = false)
+    isUnsafeRow: Boolean = false,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0))
   extends LeafNode {
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -130,6 +132,24 @@ private[sql] object PhysicalRDD {
       metadata: Map[String, String] = Map.empty): PhysicalRDD = {
     // All HadoopFsRelations output UnsafeRows
     val outputUnsafeRows = relation.isInstanceOf[HadoopFsRelation]
-    PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
+
+    val bucketSpec = relation match {
+      case r: HadoopFsRelation => r.getBucketSpec
+      case _ => None
+    }
+
+    def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse {
+      throw new AnalysisException(s"bucket column $colName not found in existing columns " +
+        s"(${output.map(_.name).mkString(", ")})")
+    }
+
+    bucketSpec.map { spec =>
+      val numBuckets = spec.numBuckets
+      val bucketColumns = spec.bucketColumnNames.map(toAttribute)
+      val partitioning = HashPartitioning(bucketColumns, numBuckets)
+      PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows, partitioning)
+    }.getOrElse {
+      PhysicalRDD(output, rdd, relation.toString, metadata, outputUnsafeRows)
+    }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index 7a8691e7cb..314c957d57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -125,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
              |Actual: ${partitionColumns.mkString(", ")}
           """.stripMargin)
 
-        val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) {
+        val writerContainer = if (partitionColumns.isEmpty && relation.getBucketSpec.isEmpty) {
           new DefaultWriterContainer(relation, job, isAppend)
         } else {
           val output = df.queryExecution.executedPlan.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index ece9b8a9a9..cc8dcf5930 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -97,6 +97,7 @@ object ResolvedDataSource extends Logging {
       sqlContext: SQLContext,
       userSpecifiedSchema: Option[StructType],
       partitionColumns: Array[String],
+      bucketSpec: Option[BucketSpec],
       provider: String,
       options: Map[String, String]): ResolvedDataSource = {
     val clazz: Class[_] = lookupDataSource(provider)
@@ -142,6 +143,7 @@ object ResolvedDataSource extends Logging {
             paths,
             Some(dataSchema),
             maybePartitionsSchema,
+            bucketSpec,
             caseInsensitiveOptions)
         case dataSource: org.apache.spark.sql.sources.RelationProvider =>
           throw new AnalysisException(s"$className does not allow user-specified schemas.")
@@ -173,7 +175,7 @@ object ResolvedDataSource extends Logging {
                 SparkHadoopUtil.get.globPathIfNecessary(qualified).map(_.toString)
               }
           }
-          dataSource.createRelation(sqlContext, paths, None, None, caseInsensitiveOptions)
+          dataSource.createRelation(sqlContext, paths, None, None, None, caseInsensitiveOptions)
         case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
           throw new AnalysisException(
             s"A schema needs to be specified when using $className.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index fc77529b7d..563fd9eefc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -311,7 +311,7 @@ private[sql] class DynamicPartitionWriterContainer(
     isAppend: Boolean)
   extends BaseWriterContainer(relation, job, isAppend) {
 
-  private val bucketSpec = relation.bucketSpec
+  private val bucketSpec = relation.getBucketSpec
 
   private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
     spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
index 9976829638..c7ecd6125d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala
@@ -44,9 +44,7 @@ private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProv
       dataSchema: Option[StructType],
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation =
-    // TODO: throw exception here as we won't call this method during execution, after bucketed read
-    // support is finished.
-    createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters)
+    throw new UnsupportedOperationException("use the overload version with bucketSpec parameter")
 }
 
 private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory {
@@ -54,5 +52,20 @@ private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFact
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter =
-    throw new UnsupportedOperationException("use bucket version")
+    throw new UnsupportedOperationException("use the overload version with bucketSpec parameter")
+}
+
+private[sql] object BucketingUtils {
+  // The file name of bucketed data should have 3 parts:
+  //   1. some other information at the head of the file name, ending with `-`
+  //   2. the bucket id part, consisting of digits
+  //   3. an optional file extension at the tail of the file name, starting with `.`
+  // An example of a bucketed parquet file name with bucket id 3:
+  //   part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb-00003.gz.parquet
+  private val bucketedFileName = """.*-(\d+)(?:\..*)?$""".r
+
+  def getBucketId(fileName: String): Option[Int] = fileName match {
+    case bucketedFileName(bucketId) => Some(bucketId.toInt)
+    case other => None
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 0897fcadbc..c3603936df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -91,7 +91,7 @@ case class CreateTempTableUsing(
 
   def run(sqlContext: SQLContext): Seq[Row] = {
     val resolved = ResolvedDataSource(
-      sqlContext, userSpecifiedSchema, Array.empty[String], provider, options)
+      sqlContext, userSpecifiedSchema, Array.empty[String], bucketSpec = None, provider, options)
     sqlContext.catalog.registerTable(
       tableIdent,
       DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 8a6fa4aeeb..20c60b9c43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -57,7 +57,7 @@ class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegi
       maybeDataSchema = dataSchema,
       maybePartitionSpec = None,
       userDefinedPartitionColumns = partitionColumns,
-      bucketSpec = bucketSpec,
+      maybeBucketSpec = bucketSpec,
       paths = paths,
       parameters = parameters)(sqlContext)
   }
@@ -68,7 +68,7 @@ private[sql] class JSONRelation(
     val maybeDataSchema: Option[StructType],
     val maybePartitionSpec: Option[PartitionSpec],
     override val userDefinedPartitionColumns: Option[StructType],
-    override val bucketSpec: Option[BucketSpec] = None,
+    override val maybeBucketSpec: Option[BucketSpec] = None,
     override val paths: Array[String] = Array.empty[String],
     parameters: Map[String, String] = Map.empty[String, String])
     (@transient val sqlContext: SQLContext)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 991a5d5aef..30ddec686c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -112,7 +112,7 @@ private[sql] class ParquetRelation(
     // This is for metastore conversion.
     private val maybePartitionSpec: Option[PartitionSpec],
     override val userDefinedPartitionColumns: Option[StructType],
-    override val bucketSpec: Option[BucketSpec],
+    override val maybeBucketSpec: Option[BucketSpec],
     parameters: Map[String, String])(
     val sqlContext: SQLContext)
   extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index dd3e66d8a9..9358c9c37b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -36,6 +36,7 @@ private[sql] class ResolveDataSource(sqlContext: SQLContext) extends Rule[Logica
           sqlContext,
           userSpecifiedSchema = None,
           partitionColumns = Array(),
+          bucketSpec = None,
           provider = u.tableIdentifier.database.get,
           options = Map("path" -> u.tableIdentifier.table))
         val plan = LogicalRelation(resolved.relation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 9f3607369c..7800776fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -28,13 +28,13 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, UnionRDD}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
-import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec}
+import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
@@ -458,7 +458,12 @@ abstract class HadoopFsRelation private[sql](
 
   private var _partitionSpec: PartitionSpec = _
 
-  private[sql] def bucketSpec: Option[BucketSpec] = None
+  private[this] var malformedBucketFile = false
+
+  private[sql] def maybeBucketSpec: Option[BucketSpec] = None
+
+  final private[sql] def getBucketSpec: Option[BucketSpec] =
+    maybeBucketSpec.filter(_ => sqlContext.conf.bucketingEnabled() && !malformedBucketFile)
 
   private class FileStatusCache {
     var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus]
@@ -664,6 +669,35 @@ abstract class HadoopFsRelation private[sql](
     })
   }
 
+  /**
+   * Groups the input files by bucket id, if bucketing is enabled and this data source is bucketed.
+   * Returns None if any bucket file is malformed.
+   */
+  private def groupBucketFiles(
+      files: Array[FileStatus]): Option[scala.collection.Map[Int, Array[FileStatus]]] = {
+    malformedBucketFile = false
+    if (getBucketSpec.isDefined) {
+      val groupedBucketFiles = mutable.HashMap.empty[Int, mutable.ArrayBuffer[FileStatus]]
+      var i = 0
+      while (!malformedBucketFile && i < files.length) {
+        val bucketId = BucketingUtils.getBucketId(files(i).getPath.getName)
+        if (bucketId.isEmpty) {
+          logError(s"File ${files(i).getPath} is expected to be a bucket file, but there is no " +
+            "bucket id information in file name. Fall back to non-bucketing mode.")
+          malformedBucketFile = true
+        } else {
+          val bucketFiles =
+            groupedBucketFiles.getOrElseUpdate(bucketId.get, mutable.ArrayBuffer.empty)
+          bucketFiles += files(i)
+        }
+        i += 1
+      }
+      if (malformedBucketFile) None else Some(groupedBucketFiles.mapValues(_.toArray))
+    } else {
+      None
+    }
+  }
+
   final private[sql] def buildInternalScan(
       requiredColumns: Array[String],
       filters: Array[Filter],
@@ -683,7 +717,20 @@ abstract class HadoopFsRelation private[sql](
       }
     }
 
-    buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf)
+    groupBucketFiles(inputStatuses).map { groupedBucketFiles =>
+      // For each bucket id, we first get all files belonging to this bucket by parsing the bucket
+      // id from the file name, then read these files into an RDD (using a one-partition empty RDD
+      // for an empty bucket) and coalesce it to one partition. Finally we union all bucket RDDs.
+      val perBucketRows = (0 until maybeBucketSpec.get.numBuckets).map { bucketId =>
+        groupedBucketFiles.get(bucketId).map { inputStatuses =>
+          buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf).coalesce(1)
+        }.getOrElse(sqlContext.emptyResult)
+      }
+
+      new UnionRDD(sqlContext.sparkContext, perBucketRows)
+    }.getOrElse {
+      buildInternalScan(requiredColumns, filters, inputStatuses, broadcastedConf)
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index e70eb2a060..8de8ba355e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1223,6 +1223,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         sqlContext,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
+        bucketSpec = None,
         provider = classOf[DefaultSource].getCanonicalName,
         options = Map("path" -> path))
 
@@ -1230,6 +1231,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         sqlContext,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
+        bucketSpec = None,
         provider = classOf[DefaultSource].getCanonicalName,
         options = Map("path" -> path))
       assert(d1 === d2)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 3d54048c24..0cfe03ba91 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -143,19 +143,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
           }
         }
 
-        def partColsFromParts: Option[Seq[String]] = {
-          table.properties.get("spark.sql.sources.schema.numPartCols").map { numPartCols =>
-            (0 until numPartCols.toInt).map { index =>
-              val partCol = table.properties.get(s"spark.sql.sources.schema.partCol.$index").orNull
-              if (partCol == null) {
+        def getColumnNames(colType: String): Seq[String] = {
+          table.properties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map {
+            numCols => (0 until numCols.toInt).map { index =>
+              table.properties.get(s"spark.sql.sources.schema.${colType}Col.$index").getOrElse {
                 throw new AnalysisException(
-                  "Could not read partitioned columns from the metastore because it is corrupted " +
-                    s"(missing part $index of the it, $numPartCols parts are expected).")
+                  s"Could not read $colType columns from the metastore because it is corrupted " +
+                    s"(missing part $index of it, $numCols parts are expected).")
               }
-
-              partCol
             }
-          }
+          }.getOrElse(Nil)
         }
 
         // Originally, we used spark.sql.sources.schema to store the schema of a data source table.
@@ -170,7 +167,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
         // We only need names at here since userSpecifiedSchema we loaded from the metastore
         // contains partition columns. We can always get datatypes of partitioning columns
         // from userSpecifiedSchema.
-        val partitionColumns = partColsFromParts.getOrElse(Nil)
+        val partitionColumns = getColumnNames("part")
+
+        val bucketSpec = table.properties.get("spark.sql.sources.schema.numBuckets").map { n =>
+          BucketSpec(n.toInt, getColumnNames("bucket"), getColumnNames("sort"))
+        }
 
         // It does not appear that the ql client for the metastore has a way to enumerate all the
         // SerDe properties directly...
@@ -181,6 +182,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
             hive,
             userSpecifiedSchema,
             partitionColumns.toArray,
+            bucketSpec,
             table.properties("spark.sql.sources.provider"),
             options)
 
@@ -282,7 +284,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
 
     val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf)
     val dataSource = ResolvedDataSource(
-      hive, userSpecifiedSchema, partitionColumns, provider, options)
+      hive, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options)
 
     def newSparkSQLSpecificMetastoreTable(): HiveTable = {
       HiveTable(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 07a352873d..e703ac0164 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -213,7 +213,12 @@ case class CreateMetastoreDataSourceAsSelect(
         case SaveMode.Append =>
           // Check if the specified data source match the data source of the existing table.
           val resolved = ResolvedDataSource(
-            sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath)
+            sqlContext,
+            Some(query.schema.asNullable),
+            partitionColumns,
+            bucketSpec,
+            provider,
+            optionsWithPath)
           val createdRelation = LogicalRelation(resolved.relation)
           EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent)) match {
             case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 14fa152c23..40409169b0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -156,7 +156,7 @@ private[sql] class OrcRelation(
     maybeDataSchema: Option[StructType],
     maybePartitionSpec: Option[PartitionSpec],
     override val userDefinedPartitionColumns: Option[StructType],
-    override val bucketSpec: Option[BucketSpec],
+    override val maybeBucketSpec: Option[BucketSpec],
     parameters: Map[String, String])(
     @transient val sqlContext: SQLContext)
   extends HadoopFsRelation(maybePartitionSpec, parameters)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
new file mode 100644
index 0000000000..58ecdd3b80
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.{Column, DataFrame, DataFrameWriter, QueryTest, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.Exchange
+import org.apache.spark.sql.execution.joins.SortMergeJoin
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+  import testImplicits._
+
+  test("read bucketed data") {
+    val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+    withTable("bucketed_table") {
+      df.write
+        .format("parquet")
+        .partitionBy("i")
+        .bucketBy(8, "j", "k")
+        .saveAsTable("bucketed_table")
+
+      for (i <- 0 until 5) {
+        val rdd = hiveContext.table("bucketed_table").filter($"i" === i).queryExecution.toRdd
+        assert(rdd.partitions.length == 8)
+
+        val attrs = df.select("j", "k").schema.toAttributes
+        val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+          val getBucketId = UnsafeProjection.create(
+            HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+            attrs)
+          rows.map(row => getBucketId(row).getInt(0) == index)
+        })
+
+        assert(checkBucketId.collect().reduce(_ && _))
+      }
+    }
+  }
+
+  private val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
+  private val df2 = (0 until 50).map(i => (i % 7, i % 11, i.toString)).toDF("i", "j", "k").as("df2")
+
+  private def testBucketing(
+      bucketing1: DataFrameWriter => DataFrameWriter,
+      bucketing2: DataFrameWriter => DataFrameWriter,
+      joinColumns: Seq[String],
+      shuffleLeft: Boolean,
+      shuffleRight: Boolean): Unit = {
+    withTable("bucketed_table1", "bucketed_table2") {
+      bucketing1(df1.write.format("parquet")).saveAsTable("bucketed_table1")
+      bucketing2(df2.write.format("parquet")).saveAsTable("bucketed_table2")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+        val t1 = hiveContext.table("bucketed_table1")
+        val t2 = hiveContext.table("bucketed_table2")
+        val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+
+        // First check that the result is correct.
+        checkAnswer(
+          joined.sort("bucketed_table1.k", "bucketed_table2.k"),
+          df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+
+        assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
+        val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
+
+        assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
+        assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+      }
+    }
+  }
+
+  private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+    joinCols.map(col => left(col) === right(col)).reduce(_ && _)
+  }
+
+  test("avoid shuffle when join 2 bucketed tables") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+  }
+
+  // Enable this test after https://issues.apache.org/jira/browse/SPARK-12704 is fixed.
+  ignore("avoid shuffle when join keys are a super-set of bucket keys") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+  }
+
+  test("only shuffle one side when join bucketed table and non-bucketed table") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    testBucketing(bucketing, identity, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("only shuffle one side when 2 bucketed tables have different bucket number") {
+    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(5, "i", "j")
+    testBucketing(bucketing1, bucketing2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("only shuffle one side when 2 bucketed tables have different bucket keys") {
+    val bucketing1 = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    val bucketing2 = (writer: DataFrameWriter) => writer.bucketBy(8, "j")
+    testBucketing(bucketing1, bucketing2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+  }
+
+  test("shuffle when join keys are not equal to bucket keys") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i")
+    testBucketing(bucketing, bucketing, Seq("j"), shuffleLeft = true, shuffleRight = true)
+  }
+
+  test("shuffle when join 2 bucketed tables with bucketing disabled") {
+    val bucketing = (writer: DataFrameWriter) => writer.bucketBy(8, "i", "j")
+    withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
+      testBucketing(bucketing, bucketing, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+    }
+  }
+
+  test("avoid shuffle when grouping keys are equal to bucket keys") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table")
+      val tbl = hiveContext.table("bucketed_table")
+      val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+      checkAnswer(
+        agged.sort("i", "j"),
+        df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+    }
+  }
+
+  test("avoid shuffle when grouping keys are a super-set of bucket keys") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+      val tbl = hiveContext.table("bucketed_table")
+      val agged = tbl.groupBy("i", "j").agg(max("k"))
+
+      checkAnswer(
+        agged.sort("i", "j"),
+        df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
+
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+    }
+  }
+
+  test("fallback to non-bucketing mode if there exists any malformed bucket files") {
+    withTable("bucketed_table") {
+      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
+      val tableDir = new File(hiveContext.warehousePath, "bucketed_table")
+      Utils.deleteRecursively(tableDir)
+      df1.write.parquet(tableDir.getAbsolutePath)
+
+      val agged = hiveContext.table("bucketed_table").groupBy("i").count()
+      // make sure we fall back to non-bucketing mode and can't avoid shuffle
+      assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+      checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
+    }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 3ea9826544..e812439bed 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
 import org.apache.spark.sql.{AnalysisException, QueryTest}
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
@@ -62,15 +63,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
   }
 
-  private val testFileName = """.*-(\d+)$""".r
-  private val otherFileName = """.*-(\d+)\..*""".r
-  private def getBucketId(fileName: String): Int = {
-    fileName match {
-      case testFileName(bucketId) => bucketId.toInt
-      case otherFileName(bucketId) => bucketId.toInt
-    }
-  }
-
   private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
 
   private def testBucketing(
@@ -81,7 +73,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     val allBucketFiles = dataDir.listFiles().filterNot(f =>
       f.getName.startsWith(".") || f.getName.startsWith("_")
     )
-    val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName))
+    val groupedBucketFiles = allBucketFiles.groupBy(f => BucketingUtils.getBucketId(f.getName).get)
     assert(groupedBucketFiles.size <= 8)
 
     for ((bucketId, bucketFiles) <- groupedBucketFiles) {
@@ -98,12 +90,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 
         val qe = readBack.select(bucketCols.map(col): _*).queryExecution
         val rows = qe.toRdd.map(_.copy()).collect()
-        val getHashCode = UnsafeProjection.create(
+        val getBucketId = UnsafeProjection.create(
           HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
           qe.analyzed.output)
 
         for (row <- rows) {
-          val actualBucketId = getHashCode(row).getInt(0)
+          val actualBucketId = getBucketId(row).getInt(0)
           assert(actualBucketId == bucketId)
         }
       }
-- 
GitLab