diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb35849f7d71ee3f5b7edd9ab8b90b..324f40a051c381f555dc3a945a08874029f32f93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. + * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72c740b378326a6805219f8018f4df..3fca47a023dc68707bffecf4058360ecfc365cf2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580ea60c8a119498fc2ccc5cfb84f76..1370713975f2fbdf09a9f6a2aefe06762c5fc6e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } }