diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 3c59b982c2dca7f21d0d52fe3c442715c347c60e..06e588f56f1e98ad4e001bb8a7703751a52136b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} @@ -81,6 +83,21 @@ case class AnalyzeTableCommand( object AnalyzeTableCommand extends Logging { def calculateTotalSize(sessionState: SessionState, catalogTable: CatalogTable): Long = { + if (catalogTable.partitionColumnNames.isEmpty) { + calculateLocationSize(sessionState, catalogTable.identifier, catalogTable.storage.locationUri) + } else { + // Calculate table size as a sum of the visible partitions. See SPARK-21079 + val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) + partitions.map(p => + calculateLocationSize(sessionState, catalogTable.identifier, p.storage.locationUri) + ).sum + } + } + + private def calculateLocationSize( + sessionState: SessionState, + tableId: TableIdentifier, + locationUri: Option[URI]): Long = { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -91,13 +108,13 @@ object AnalyzeTableCommand extends Logging { // countFileSize to count the table size. val stagingDir = sessionState.conf.getConfString("hive.exec.stagingdir", ".hive-staging") - def calculateTableSize(fs: FileSystem, path: Path): Long = { + def calculateLocationSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) val size = if (fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath.getName.startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) + calculateLocationSize(fs, status.getPath) } else { 0L } @@ -109,16 +126,16 @@ object AnalyzeTableCommand extends Logging { size } - catalogTable.storage.locationUri.map { p => + locationUri.map { p => val path = new Path(p) try { val fs = path.getFileSystem(sessionState.newHadoopConf()) - calculateTableSize(fs, path) + calculateLocationSize(fs, path) } catch { case NonFatal(e) => logWarning( - s"Failed to get the size of table ${catalogTable.identifier.table} in the " + - s"database ${catalogTable.identifier.database} because of ${e.toString}", e) + s"Failed to get the size of table ${tableId.table} in the " + + s"database ${tableId.database} because of ${e.toString}", e) 0L } }.getOrElse(0L) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 279db9a397258e91094b4ed18aefdde30e92999b..0ee18bbe9befe372d89b0f7e524c62ee80724ec2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton { @@ -128,6 +129,77 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto TableIdentifier("tempTable"), ignoreIfNotExists = true, purge = false) } + test("SPARK-21079 - analyze table with location different than that of individual partitions") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val tableName = "analyzeTable_part" + withTable(tableName) { + withTempPath { path => + sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)") + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src") + } + + sql(s"ALTER TABLE $tableName SET LOCATION '$path'") + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + + assert(queryTotalSize(tableName) === BigInt(17436)) + } + } + } + + test("SPARK-21079 - analyze partitioned table with only a subset of partitions visible") { + def queryTotalSize(tableName: String): BigInt = + spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes + + val sourceTableName = "analyzeTable_part" + val tableName = "analyzeTable_part_vis" + withTable(sourceTableName, tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a single top-level directory 'path' + sql( + s""" + |CREATE TABLE $sourceTableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + partitionDates.foreach { ds => + sql( + s""" + |INSERT INTO TABLE $sourceTableName PARTITION (ds='$ds') + |SELECT * FROM src + """.stripMargin) + } + + // Create another table referring to the same location + sql( + s""" + |CREATE TABLE $tableName (key STRING, value STRING) + |PARTITIONED BY (ds STRING) + |LOCATION '$path' + """.stripMargin) + + // Register only one of the partitions found on disk + val ds = partitionDates.head + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds')").collect() + + // Analyze original table - expect 3 partitions + sql(s"ANALYZE TABLE $sourceTableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(sourceTableName) === BigInt(3 * 5812)) + + // Analyze partial-copy table - expect only 1 partition + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan") + assert(queryTotalSize(tableName) === BigInt(5812)) + } + } + } + test("analyzing views is not supported") { def assertAnalyzeUnsupported(analyzeCommand: String): Unit = { val err = intercept[AnalysisException] {