diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa426073999c86330d843a006af1d42..4bc1c1af40bf4b8effd7e4440f5f01437ee92c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -82,7 +82,9 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic) + PullOutNondeterministic), + Batch("Cleanup", fixedPoint, + CleanupAliases) ) /** @@ -146,8 +148,6 @@ class Analyzer( child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne - case g: GetStructField => Alias(g, g.field.name)() - case g: GetArrayStructFields => Alias(g, g.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u case other => Alias(other, s"_c$i")() @@ -384,9 +384,7 @@ class Analyzer( case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { - q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) - } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -412,11 +410,6 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } - private def trimUnresolvedAlias(ne: NamedExpression) = ne match { - case UnresolvedAlias(child) => child - case other => other - } - private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. @@ -426,7 +419,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + plan.resolve(nameParts, resolver).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] { case Subquery(_, child) => child } } + +/** + * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level + * expression in Project(project list) or Aggregate(aggregate expressions) or + * Window(window expressions). + */ +object CleanupAliases extends Rule[LogicalPlan] { + private def trimAliases(e: Expression): Expression = { + var stop = false + e.transformDown { + // CreateStruct is a special case, we need to retain its top level Aliases as they decide the + // name of StructField. We also need to stop transform down this expression, or the Aliases + // under CreateStruct will be mistakenly trimmed. + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } + + def trimNonTopLevelAliases(e: Expression): Expression = e match { + case a: Alias => + Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + case other => trimAliases(other) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Project(projectList, child) => + val cleanedProjectList = + projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Project(cleanedProjectList, child) + + case Aggregate(grouping, aggs, child) => + val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + Aggregate(grouping.map(trimAliases), cleanedAggs, child) + + case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + val cleanedWindowExprs = + windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression]) + Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases), + orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child) + + case other => + var stop = false + other transformExpressionsDown { + case c: CreateStruct if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case c: CreateStructUnsafe if !stop => + stop = true + c.copy(children = c.children.map(trimNonTopLevelAliases)) + case Alias(child, _) if !stop => child + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 4a071e663e0d1f6d264b819d328b6cdbc9076341..298aee34992753476448d05a963b850766a3015d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { val fields = children.zipWithIndex.map { case (child, idx) => child match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ab5ac2c61e3ce3d316326dcc93d5124c6d80a63..47b06cae154363d2b084c6aded3f0d08ef53ebc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter @@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] { val substitutedProjection = projectList1.map(_.transform { case a: Attribute => aliasMap.getOrElse(a, a) }).asInstanceOf[Seq[NamedExpression]] - - Project(substitutedProjection, child) + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + Project(cleanedProjection, child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c290e6acb361c19e9562798e5b8cdc7eb105379e..9bb466ac2d29cf6a0a3648ebb098ccfe84aa320e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and wrap it with UnresolvedAlias which will be removed later. + // and aliased it with the last part of the name. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as - // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - Some(UnresolvedAlias(fieldExprs)) + Some(Alias(fieldExprs, nestedFields.last)()) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 7c404722d811c368e8d30ccdd22a357db8cea620..73b8261260acba889f288cbdf248f114d5407dc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,7 @@ case class Window( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = - (projectList ++ windowExpressions).map(_.toAttribute) + projectList ++ windowExpressions.map(_.toAttribute) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c944bc69e25b0cb2471e07b6a46a0cc7f509b9fb..1e0cc81dae974e5fb9a30a26dab89ac024da4231 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") { + val a = testRelation.output.head + var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col")) + var expected = testRelation.select((a + 1 + 2).as("col")) + checkAnalysis(plan, expected) + + plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col")) + expected = testRelation.groupBy(a)((min(a) + 1).as("col")) + checkAnalysis(plan, expected) + + // CreateStruct is a special case that we should not trim Alias for it. + plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col")) + checkAnalysis(plan, plan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 27bd084847346ea1465971e86c20daedac70fd43..807bc8c30c12da228d9e426b0d0be61fd5fccc45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as("colB")) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: String): Column = Alias(expr, alias)() + def as(alias: String): Column = expr match { + case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias)() + } /** * (Scala-specific) Assigns the given aliases to the results of a table generating function. @@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging { * df.select($"colA".as('colB)) * }}} * + * If the current column has metadata associated with it, this metadata will be propagated + * to the new column. If this not desired, use `as` with explicitly empty metadata. + * * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = Alias(expr, alias.name)() + def as(alias: Symbol): Column = expr match { + case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) + case other => Alias(other, alias.name)() + } /** * Gives the column an alias with metadata. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee74e3e83da5a8db2f4c08b32d1057849541f7f9..37738ec5b3c1d1b2341336f9f5bad8feb919f7e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} @@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { assert(df.select(df("a").alias("b")).columns.head === "b") } + test("as propagates metadata") { + val metadata = new MetadataBuilder + metadata.putString("key", "value") + val origCol = $"a".as("b", metadata.build()) + val newCol = origCol.as("c") + assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") + } + test("single explode") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 10bfa9b64f00db81e7b4fe34b320ca7c8b1207d3..cf22797752b971adf30cb624d559612de7e4ff1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) assert(expected === actual) } + + test("SPARK-9323: DataFrame.orderBy should support nested column name") { + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1}}""" :: Nil)) + checkAnswer(df.orderBy("a.b"), Row(Row(1))) + } }