From 8b5b2e272f48f7ddf8aeece0205cb4a5853c364e Mon Sep 17 00:00:00 2001
From: lianhuiwang <lianhuiwang09@gmail.com>
Date: Wed, 14 Jun 2017 09:57:56 +0800
Subject: [PATCH] [SPARK-20986][SQL] Reset table's statistics after
 PruneFileSourcePartitions rule.

## What changes were proposed in this pull request?
The `PruneFileSourcePartitions` rule can filter out partitions that a query does not need, so after it runs the table's statistics no longer match the data that will actually be read. This patch resets the table's `sizeInBytes` statistic to the total size of the files in the remaining (pruned) partitions, so that downstream size-based optimizations (e.g., choosing a broadcast join when the pruned size falls below `spark.sql.autoBroadcastJoinThreshold`) see an accurate estimate.
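
As an illustration (not part of this patch), the following spark-shell sketch shows the kind of size-based decision that benefits from the reset; the table `tbl` with partition column `p` is an assumption matching the test below:

```scala
// A sketch, assuming a Spark session `spark` and a partitioned table `tbl`
// like the one built in the new test (partition column `p`).
val big = spark.table("tbl")     // whole-table stats may be large
val small = big.where("p = 1")   // pruned to a single partition

// With stale whole-table stats, `small` can look too big to broadcast even
// though only one partition's files will be read. With this patch the
// optimized plan carries the pruned sizeInBytes, so the planner can compare
// it against spark.sql.autoBroadcastJoinThreshold when picking a join.
small.join(spark.range(100).toDF("id"), "id").explain()
```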

## How was this patch tested?
Added a unit test.
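
For a quick manual check, a spark-shell session mirroring the new test (a sketch, not part of the patch) could look like:

```scala
// Build a small partitioned table and collect its stats.
spark.range(10).selectExpr("id", "id % 3 as p")
  .write.partitionBy("p").saveAsTable("tbl")
spark.sql("ANALYZE TABLE tbl COMPUTE STATISTICS")

val df = spark.sql("SELECT * FROM tbl WHERE p = 1")
// Before this patch the optimized plan still reported the whole-table size;
// with it, the relation's stats cover only the files under p = 1.
df.queryExecution.optimizedPlan.collect {
  case r: org.apache.spark.sql.execution.datasources.LogicalRelation =>
    r.catalogTable.get.stats.get.sizeInBytes
}
```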

Author: lianhuiwang <lianhuiwang09@gmail.com>

Closes #18205 from lianhuiwang/SPARK-20986.
---
 .../PruneFileSourcePartitions.scala           |  8 ++++--
 .../PruneFileSourcePartitionsSuite.scala      | 25 +++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 905b8683e1..f5df1848a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
@@ -59,8 +60,11 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
           fsRelation.copy(location = prunedFileIndex)(sparkSession)
-        val prunedLogicalRelation = logicalRelation.copy(relation = prunedFsRelation)
-
+        // Change table stats based on the sizeInBytes of pruned files
+        val withStats = logicalRelation.catalogTable.map(_.copy(
+          stats = Some(CatalogStatistics(sizeInBytes = BigInt(prunedFileIndex.sizeInBytes)))))
+        val prunedLogicalRelation = logicalRelation.copy(
+          relation = prunedFsRelation, catalogTable = withStats)
         // Keep partition-pruning predicates so that they are visible in physical planning
         val filterExpression = filters.reduceLeft(And)
         val filter = Filter(filterExpression, prunedLogicalRelation)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index f818e29555..d91f25a4da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
@@ -66,4 +67,28 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
       }
     }
   }
+
+  test("SPARK-20986 Reset table's statistics after PruneFileSourcePartitions rule") {
+    withTable("tbl") {
+      spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("tbl")
+      sql(s"ANALYZE TABLE tbl COMPUTE STATISTICS")
+      val tableStats = spark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")).stats
+      assert(tableStats.isDefined && tableStats.get.sizeInBytes > 0, "tableStats is lost")
+
+      val df = sql("SELECT * FROM tbl WHERE p = 1")
+      val sizes1 = df.queryExecution.analyzed.collect {
+        case relation: LogicalRelation => relation.catalogTable.get.stats.get.sizeInBytes
+      }
+      assert(sizes1.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+      assert(sizes1(0) == tableStats.get.sizeInBytes)
+
+      val relations = df.queryExecution.optimizedPlan.collect {
+        case relation: LogicalRelation => relation
+      }
+      assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+      val size2 = relations(0).computeStats(conf).sizeInBytes
+      assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
+      assert(size2 < tableStats.get.sizeInBytes)
+    }
+  }
 }
-- 
GitLab