From 54b4f2ad43c0ad333a3751a7f10da711b94677a0 Mon Sep 17 00:00:00 2001
From: wangzhenhua <wangzhenhua@huawei.com>
Date: Fri, 12 May 2017 20:43:22 +0800
Subject: [PATCH] [SPARK-20718][SQL][FOLLOWUP] Fix canonicalization for
 HiveTableScanExec

## What changes were proposed in this pull request?

Fix canonicalization of `HiveTableScanExec` so that plans that differ only in the order of their partition pruning predicates canonicalize to the same result. The predicate normalization previously private to `FileSourceScanExec` (`canonicalizeFilters`) is moved into `QueryPlan.normalizePredicates` and reused by both scan nodes.
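
Folding the filters into a single conjunction makes their order irrelevant because the combined predicate is reordered during canonicalization. A hypothetical, expression-level sketch using the Catalyst test DSL (the attributes `a` and `b` are made up for illustration):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, Expression}

// Hypothetical attributes, built with the Catalyst test DSL purely for illustration.
val a = 'a.int
val b = 'b.int

// The same conjuncts, given in two different orders...
val p1: Expression = Seq(a === 1, b === 2).reduce(And)
val p2: Expression = Seq(b === 2, a === 1).reduce(And)

// ...compare as semantically equal once folded into a single conjunction, because
// canonicalization reorders the conjuncts of the combined predicate.
assert(p1.semanticEquals(p2))
```

`QueryPlan.normalizePredicates` applies the same idea at plan canonicalization time and then splits the conjunction back into individual predicates.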

## How was this patch tested?

Added a new test case to `HiveTableScanSuite` that checks two scans whose partition filters differ only in order are `sameResult` after canonicalization.
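
For reference, a minimal sketch of what the new test exercises (a Hive-enabled `SparkSession` is assumed; the table and filters mirror the test case):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.execution.HiveTableScanExec

// A Hive-enabled SparkSession is assumed.
val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
spark.sql("CREATE TABLE hive_tbl_part (id int) PARTITIONED BY (a int, b int)")

// Helper to pull the HiveTableScanExec node out of a query's physical plan.
def scanOf(query: String): HiveTableScanExec =
  spark.sql(query).queryExecution.sparkPlan.collectFirst { case p: HiveTableScanExec => p }.get

// The two queries differ only in the order of the partition filters.
val scan1 = scanOf("SELECT * FROM hive_tbl_part WHERE a = 1 AND b = 2")
val scan2 = scanOf("SELECT * FROM hive_tbl_part WHERE b = 2 AND a = 1")

// With this patch, the two scans canonicalize to the same form.
assert(scan1.sameResult(scan2))
```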

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #17962 from wzhfy/canonicalizeHiveTableScanExec.
---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 15 +++++++++-
 .../sql/execution/DataSourceScanExec.scala    | 16 ++---------
 .../hive/execution/HiveTableScanExec.scala    |  2 +-
 .../hive/execution/HiveTableScanSuite.scala   | 28 ++++++++++++++-----
 4 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2fb65bd435..51faa33330 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -423,7 +423,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 }
 
-object QueryPlan {
+object QueryPlan extends PredicateHelper {
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
@@ -442,4 +442,17 @@ object QueryPlan {
         }
     }.canonicalized.asInstanceOf[T]
   }
+
+  /**
+   * Composes the given predicates into a conjunctive predicate, which is normalized and reordered.
+   * Then returns a new sequence of predicates by splitting the conjunctive predicate.
+   */
+  def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = {
+    if (predicates.nonEmpty) {
+      val normalized = normalizeExprId(predicates.reduce(And), output)
+      splitConjunctivePredicates(normalized)
+    } else {
+      Nil
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 251098c9a8..74fc23a52a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-trait DataSourceScanExec extends LeafExecNode with CodegenSupport with PredicateHelper {
+trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
   val metastoreTableIdentifier: Option[TableIdentifier]
 
@@ -519,18 +519,8 @@ case class FileSourceScanExec(
       relation,
       output.map(QueryPlan.normalizeExprId(_, output)),
       requiredSchema,
-      canonicalizeFilters(partitionFilters, output),
-      canonicalizeFilters(dataFilters, output),
+      QueryPlan.normalizePredicates(partitionFilters, output),
+      QueryPlan.normalizePredicates(dataFilters, output),
       None)
   }
-
-  private def canonicalizeFilters(filters: Seq[Expression], output: Seq[Attribute])
-    : Seq[Expression] = {
-    if (filters.nonEmpty) {
-      val normalizedFilters = QueryPlan.normalizeExprId(filters.reduce(And), output)
-      splitConjunctivePredicates(normalizedFilters)
-    } else {
-      Nil
-    }
-  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 666548d1a4..e191071efb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -206,7 +206,7 @@ case class HiveTableScanExec(
     HiveTableScanExec(
       requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
       relation.canonicalized.asInstanceOf[CatalogRelation],
-      partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
+      QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession)
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 90e037e292..ae64cb3210 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -164,16 +164,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
              |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e')
              |SELECT v.id
            """.stripMargin)
-        val plan = sql(
-          s"""
-             |SELECT * FROM $table
-           """.stripMargin).queryExecution.sparkPlan
-        val scan = plan.collectFirst {
-          case p: HiveTableScanExec => p
-        }.get
+        val scan = getHiveTableScanExec(s"SELECT * FROM $table")
         val numDataCols = scan.relation.dataCols.length
         scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols))
       }
     }
   }
+
+  test("HiveTableScanExec canonicalization for different orders of partition filters") {
+    val table = "hive_tbl_part"
+    withTable(table) {
+      sql(
+        s"""
+           |CREATE TABLE $table (id int)
+           |PARTITIONED BY (a int, b int)
+         """.stripMargin)
+      val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2")
+      val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1")
+      assert(scan1.sameResult(scan2))
+    }
+  }
+
+  private def getHiveTableScanExec(query: String): HiveTableScanExec = {
+    sql(query).queryExecution.sparkPlan.collectFirst {
+      case p: HiveTableScanExec => p
+    }.get
+  }
 }
-- 
GitLab