diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 236476900a5197862dbf9857b9a0e3d2e90d739e..8595762988b4b1497c704ac153258932eb8864b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -296,10 +296,13 @@ class Analyzer(
 
         val nonNullBitmask = x.bitmasks.reduce(_ & _)
 
-        val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
           a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
         }
 
+        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+        val groupingAttrs = expand.output.drop(x.child.output.length)
+
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
           // collect all the found AggregateExpression, so we can check an expression is part of
           // any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
             if (index == -1) {
               e
             } else {
-              groupByAttributes(index)
+              groupingAttrs(index)
             }
           }.asInstanceOf[NamedExpression]
         }
 
-        Aggregate(
-          groupByAttributes :+ gid,
-          aggregations,
-          Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+        Aggregate(groupingAttrs, aggregations, expand)
 
       case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
         val groupingExprs = findGroupingExprs(child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d773e7753650547f55c4cbcd2401186e5614..e6d554565d442924625aca30dc69b6e26c805f97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     case filter @ Filter(_, f: Filter) => filter
     // should not push predicates through sample, or will generate different results.
     case filter @ Filter(_, s: Sample) => filter
-    // TODO: push predicates through expand
-    case filter @ Filter(_, e: Expand) => filter
 
     case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
       pushDownPredicate(filter, u.child) { predicate =>
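With the blanket `Filter`-over-`Expand` case removed, such filters now fall through to the generic deterministic-`UnaryNode` case of `PushDownPredicate`. Below is a minimal sketch of what that case amounts to when the child happens to be an `Expand`, assuming the Spark 2.0-era Catalyst APIs (`PredicateHelper.splitConjunctivePredicates`, `AttributeSet.subsetOf`); the object name is illustrative, and this is not a new rule in the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LogicalPlan}

// Sketch only: mirrors the generic UnaryNode case of PushDownPredicate.
object PushThroughExpandSketch extends PredicateHelper {
  def apply(filter: Filter, expand: Expand): LogicalPlan = {
    // A conjunct may move below the Expand only if it references nothing but
    // the pass-through columns coming from Expand's child. Conjuncts on the
    // expanded grouping attributes (nullable per projection, and carrying
    // fresh exprIds after this patch) must stay above.
    val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition)
      .partition(_.references.subsetOf(expand.child.outputSet))

    if (pushDown.isEmpty) {
      filter
    } else {
      val pushed = expand.copy(child = Filter(pushDown.reduce(And), expand.child))
      stayUp.reduceOption(And).map(Filter(_, pushed)).getOrElse(pushed)
    }
  }
}
```

This is also why the `newInstance` change below is a prerequisite: without distinct exprIds, `_.references.subsetOf(expand.child.outputSet)` would wrongly approve predicates on the possibly-null grouping attributes.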
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d4fc9e4da944aebd97539e1fc0b9fb46fda5800b..a445ce694750a57d5ee16d1481e0bd61ac80a28d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -516,7 +516,10 @@ private[sql] object Expand {
         // groupingId is the last output, here we use the bit mask as the concrete value for it.
       } :+ Literal.create(bitmask, IntegerType)
     }
-    val output = child.output ++ groupByAttrs :+ gid
+
+    // `groupByAttrs` have a different meaning in `Expand.output`: they could be the original
+    // grouping expressions or null, so we create new instances of them here.
+    val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
     Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index df7529d83f7c829955f16d733844fb1723dc487f..9174b4e649a6eae7d3a44915f172082529f50850 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("expand") {
+    val agg = testRelation
+      .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
+      .analyze
+      .asInstanceOf[Aggregate]
+
+    val a = agg.output(0)
+    val b = agg.output(1)
+
+    val query = agg.where(a > 1 && b > 2)
+    val optimized = Optimize.execute(query)
+    val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
+    comparePlans(optimized, correctedAnswer)
+  }
 }
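The `newInstance` call in `Expand.apply` is what makes this test sound: the filter on `a` and `b` is pushed below the `Aggregate` but deliberately stops above the `Expand`, because the attributes it references are the Expand-side instances, which can be null in some projections and are absent from the Expand's child output. A small sketch of the identities involved, using only core Catalyst classes:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// "a" as the Project below the Expand outputs it...
val fromProject = AttributeReference("a", IntegerType)()
// ...and the possibly-null version the Expand itself outputs: same name,
// fresh exprId, so AttributeSet comparisons keep the two apart.
val fromExpand = fromProject.newInstance().withNullability(true)

assert(fromProject.name == fromExpand.name)
assert(fromProject.exprId != fromExpand.exprId)
```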
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e54358e657690c1aa15465880d6111f7db1a4c0e..2d44813f0eac55968c9e2ac840b5c374e929a469 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
 
   private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
     assert(a.child == e && e.child == p)
-    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
-      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
+      e.output.drop(p.child.output.length),
+      a.groupingExpressions.map(_.asInstanceOf[Attribute]))
   }
 
   private def groupingSetToSQL(
@@ -303,25 +304,28 @@
     val numOriginalOutput = project.child.output.length
     // Assumption: Aggregate's groupingExpressions is composed of
-    //   1) the attributes of aliased group by expressions
+    //   1) the grouping attributes
     //   2) gid, which is always the last one
     val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
 
     // Assumption: Project's projectList is composed of
     //   1) the original output (Project's child.output),
     //   2) the aliased group by expressions.
+    val expandedAttributes = project.output.drop(numOriginalOutput)
     val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
     val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
 
     // a map from group by attributes to the original group by expressions.
     val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+    // a map from expanded attributes to the original group by expressions.
+    val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
 
     val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
       // Assumption: expand.projections is composed of
       //   1) the original output (Project's child.output),
-      //   2) group by attributes(or null literal)
+      //   2) expanded attributes(or null literal)
       //   3) gid, which is always the last one in each project in Expand
       project.drop(numOriginalOutput).dropRight(1).collect {
-        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+        case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
       }
     }
     val groupingSetSQL = "GROUPING SETS(" +
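After this change, `SQLBuilder` has to keep two attribute namespaces apart: the attributes inside `expand.projections` carry the Project's exprIds, while `agg.groupingExpressions` carry the fresh ones from `expand.output`. Since `AttributeMap` resolves by exprId, a map keyed by the aggregate-side attributes can no longer see the projection-side ones, hence `expandedAttrMap`. A hedged sketch of that lookup split (the `Literal` is a stand-in for an original group-by expression):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

val aliased = AttributeReference("a", IntegerType)()         // Project-side "a"
val expanded = aliased.newInstance().withNullability(true)   // Expand-side "a"
val originalExpr = Literal(1)  // stand-in for the original group-by expression

val groupByAttrMap = AttributeMap(Seq(expanded -> originalExpr))   // aggregate side
val expandedAttrMap = AttributeMap(Seq(aliased -> originalExpr))   // projection side

// groupingSetToSQL scans expand.projections, which reference the
// Project-side attribute, so only the new map can resolve it:
assert(!groupByAttrMap.contains(aliased))
assert(expandedAttrMap.contains(aliased))
```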