diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 905b8683e10bdeedf91103312b959c2615b7c35b..f5df1848a38c4bb41c558f034c061065780a1f8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -59,8 +60,11 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(sparkSession) - val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation) - + // Change table stats based on the sizeInBytes of pruned files + val withStats = logicalRelation.catalogTable.map(_.copy( + stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes))))) + val prunedLogicalRelation = logicalRelation.copy( + relation = prunedFsRelation, catalogTable = withStats) // Keep partition-pruning predicates so that they are visible in physical planning val filterExpression = filters.reduceLeft(And) val filter = Filter(filterExpression, prunedLogicalRelation) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala index f818e29555468000ba2ecf24689ba8d6d0558f31..d91f25a4da013883397a9c9929f74e3aef6d5671 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} @@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te } } } + + test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") { + withTable("tbl") { + spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl") + sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS") + val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats + assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost") + + val df = sql("SELECT * FROM tbl WHERE p = 1") + val sizes1 = df.queryExecution.analyzed.collect { + case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes + } + assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizes1(0) == tableStats.get.sizeInBytes) + + val relations = df.queryExecution.optimizedPlan.collect { + case relation: LogicalRelation => relation + } + assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}") + val size2 = relations(0).computeStats(conf).sizeInBytes + assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes) + assert(size2 < tableStats.get.sizeInBytes) + } + } }