diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2fb65bd43550750b69acc92288cedd2597196ed7..51faa333307b3b3ea687a3899ef1d9bb09331575 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -423,7 +423,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 }
 
-object QueryPlan {
+object QueryPlan extends PredicateHelper {
   /**
    * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
    * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
@@ -442,4 +442,17 @@ object QueryPlan {
         }
     }.canonicalized.asInstanceOf[T]
   }
+
+  /**
+   * Composes the given predicates into a conjunctive predicate, which is normalized and reordered.
+   * Then returns a new sequence of predicates by splitting the conjunctive predicate.
+   */
+  def normalizePredicates(predicates: Seq[Expression], output: AttributeSeq): Seq[Expression] = {
+    if (predicates.nonEmpty) {
+      val normalized = normalizeExprId(predicates.reduce(And), output)
+      splitConjunctivePredicates(normalized)
+    } else {
+      Nil
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 251098c9a884c94eee69af7ce91987a48fd5ff67..74fc23a52a141a907855e5567929d46171f43d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
-trait DataSourceScanExec extends LeafExecNode with CodegenSupport with PredicateHelper {
+trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   val relation: BaseRelation
   val metastoreTableIdentifier: Option[TableIdentifier]
 
@@ -519,18 +519,8 @@ case class FileSourceScanExec(
       relation,
       output.map(QueryPlan.normalizeExprId(_, output)),
       requiredSchema,
-      canonicalizeFilters(partitionFilters, output),
-      canonicalizeFilters(dataFilters, output),
+      QueryPlan.normalizePredicates(partitionFilters, output),
+      QueryPlan.normalizePredicates(dataFilters, output),
       None)
   }
-
-  private def canonicalizeFilters(filters: Seq[Expression], output: Seq[Attribute])
-    : Seq[Expression] = {
-    if (filters.nonEmpty) {
-      val normalizedFilters = QueryPlan.normalizeExprId(filters.reduce(And), output)
-      splitConjunctivePredicates(normalizedFilters)
-    } else {
-      Nil
-    }
-  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 666548d1a490b27a5c5297fe3a4590c6221aa1b8..e191071efbf18864d3b8b4975ec1183adb8ff232 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -206,7 +206,7 @@ case class HiveTableScanExec(
     HiveTableScanExec(
       requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
       relation.canonicalized.asInstanceOf[CatalogRelation],
-      partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
+      QueryPlan.normalizePredicates(partitionPruningPred, input))(sparkSession)
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 90e037e2927907315cd2495a2b98a659eb645946..ae64cb3210b533ff441f3973f57713a9c0837bc6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -164,16 +164,30 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
             |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e')
            |SELECT v.id
          """.stripMargin)
-        val plan = sql(
-          s"""
-             |SELECT * FROM $table
-           """.stripMargin).queryExecution.sparkPlan
-        val scan = plan.collectFirst {
-          case p: HiveTableScanExec => p
-        }.get
+        val scan = getHiveTableScanExec(s"SELECT * FROM $table")
         val numDataCols = scan.relation.dataCols.length
         scan.rawPartitions.foreach(p => assert(p.getCols.size == numDataCols))
       }
     }
   }
+
+  test("HiveTableScanExec canonicalization for different orders of partition filters") {
+    val table = "hive_tbl_part"
+    withTable(table) {
+      sql(
+        s"""
+           |CREATE TABLE $table (id int)
+           |PARTITIONED BY (a int, b int)
+         """.stripMargin)
+      val scan1 = getHiveTableScanExec(s"SELECT * FROM $table WHERE a = 1 AND b = 2")
+      val scan2 = getHiveTableScanExec(s"SELECT * FROM $table WHERE b = 2 AND a = 1")
+      assert(scan1.sameResult(scan2))
+    }
+  }
+
+  private def getHiveTableScanExec(query: String): HiveTableScanExec = {
+    sql(query).queryExecution.sparkPlan.collectFirst {
+      case p: HiveTableScanExec => p
+    }.get
+  }
 }
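
A minimal sketch, not part of the patch itself, of the property the new QueryPlan.normalizePredicates helper relies on: it assumes Catalyst's expression canonicalization gives the operands of commutative operators such as And a deterministic order. The attribute names below are hypothetical stand-ins for a scan's output; the patch's own HiveTableScanSuite test verifies the same behavior end-to-end through sameResult.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.QueryPlan
    import org.apache.spark.sql.types.IntegerType

    // Hypothetical scan output: two integer attributes.
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val output = Seq(a, b)

    // The same filters written in different orders...
    val filters1 = Seq(EqualTo(a, Literal(1)), EqualTo(b, Literal(2)))
    val filters2 = Seq(EqualTo(b, Literal(2)), EqualTo(a, Literal(1)))

    // ...are each reduced into a single And, normalized and canonicalized (which
    // reorders the conjunction deterministically), and split back into conjuncts,
    // so both orderings should yield the same predicate sequence.
    assert(QueryPlan.normalizePredicates(filters1, output) ==
      QueryPlan.normalizePredicates(filters2, output))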