diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4d53b232d551020e692bdae7503b70d9fa9c002e..62b241f05270aab63c769d6f1fe3ce0654127c34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -416,9 +416,10 @@ class Analyzer( case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => val newChildren = expandStarExpressions(args, child) UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil - case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => + case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => val newChildren = expandStarExpressions(args, child) - Alias(child = f.copy(children = newChildren), name)() :: Nil + Alias(child = f.copy(children = newChildren), name)( + isGenerated = a.isGenerated) :: Nil case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) @@ -528,7 +529,7 @@ class Analyzer( def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { - case a: Alias => Alias(a.child, a.name)() + case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) case other => other } } @@ -734,7 +735,10 @@ class Analyzer( // Try resolving the condition of the filter as though it is in the aggregate clause val aggregatedCondition = - Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) val resolvedOperator = execute(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator @@ -759,7 +763,8 @@ class Analyzer( // Try resolving the ordering as though it is in the aggregate clause. try { val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s)) - val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) + val aliasedOrdering = + unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")(isGenerated = true)) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = @@ -1190,7 +1195,7 @@ class Analyzer( leafNondeterministic.map { e => val ne = e match { case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + case _ => Alias(e, "_nondeterministic")(isGenerated = true) } new TreeNodeRef(e) -> ne } @@ -1355,7 +1360,8 @@ object CleanupAliases extends Rule[LogicalPlan] { def trimNonTopLevelAliases(e: Expression): Expression = e match { case a: Alias => - Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata) + Alias(trimAliases(a.child), a.name)( + a.exprId, a.qualifiers, a.explicitMetadata, a.isGenerated) case other => trimAliases(other) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 4e7d1341028ca517e32496c0287fa36c2b391f12..5dfce89bd68a6a4e0e4e6d2a49e9c538fa3b5f21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -126,7 +126,8 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP // Aggregation strategy can handle the query with single distinct if (distinctAggGroups.size > 1) { // Create the attributes for the grouping id and the group by clause. - val gid = new AttributeReference("gid", IntegerType, false)() + val gid = + new AttributeReference("gid", IntegerType, false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 7983501ada9bd3ce3fe0363633cc6cfc299a3f20..207b8a0a88556c5b130edb30f87a82d027517cce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -79,6 +79,9 @@ trait NamedExpression extends Expression { /** Returns the metadata when an expression is a reference to another expression with metadata. */ def metadata: Metadata = Metadata.empty + /** Returns true if the expression is generated by Catalyst */ + def isGenerated: java.lang.Boolean = false + /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression @@ -114,16 +117,21 @@ abstract class Attribute extends LeafExpression with NamedExpression { * Note that exprId and qualifiers are in a separate parameter list because * we only pattern match on child and name. * - * @param child the computation being performed - * @param name the name to be associated with the result of computing [[child]]. + * @param child The computation being performed + * @param name The name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this * alias. Auto-assigned if left blank. + * @param qualifiers A list of strings that can be used to referred to this attribute in a fully + * qualified way. Consider the examples tableName.name, subQueryAlias.name. + * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. + * @param isGenerated A flag to indicate if this alias is generated by Catalyst */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil, - val explicitMetadata: Option[Metadata] = None) + val explicitMetadata: Option[Metadata] = None, + override val isGenerated: java.lang.Boolean = false) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -148,11 +156,13 @@ case class Alias(child: Expression, name: String)( } def newInstance(): NamedExpression = - Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata) + Alias(child, name)( + qualifiers = qualifiers, explicitMetadata = explicitMetadata, isGenerated = isGenerated) override def toAttribute: Attribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)( + exprId, qualifiers, isGenerated) } else { UnresolvedAttribute(name) } @@ -161,7 +171,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifiers :: explicitMetadata :: Nil + exprId :: qualifiers :: explicitMetadata :: isGenerated :: Nil } override def equals(other: Any): Boolean = other match { @@ -174,7 +184,8 @@ case class Alias(child: Expression, name: String)( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - s"${child.sql} AS $qualifiersString`$name`" + val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name" + s"${child.sql} AS $qualifiersString`$aliasName`" } } @@ -187,9 +198,10 @@ case class Alias(child: Expression, name: String)( * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. - * @param qualifiers a list of strings that can be used to referred to this attribute in a fully + * @param qualifiers A list of strings that can be used to referred to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. + * @param isGenerated A flag to indicate if this reference is generated by Catalyst */ case class AttributeReference( name: String, @@ -197,7 +209,8 @@ case class AttributeReference( nullable: Boolean = true, override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, - val qualifiers: Seq[String] = Nil) + val qualifiers: Seq[String] = Nil, + override val isGenerated: java.lang.Boolean = false) extends Attribute with Unevaluable { /** @@ -234,7 +247,8 @@ case class AttributeReference( } override def newInstance(): AttributeReference = - AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) + AttributeReference(name, dataType, nullable, metadata)( + qualifiers = qualifiers, isGenerated = isGenerated) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -243,7 +257,7 @@ case class AttributeReference( if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers, isGenerated) } } @@ -251,7 +265,7 @@ case class AttributeReference( if (name == newName) { this } else { - AttributeReference(newName, dataType, nullable)(exprId, qualifiers) + AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifiers, isGenerated) } } @@ -262,7 +276,7 @@ case class AttributeReference( if (newQualifiers.toSet == qualifiers.toSet) { this } else { - AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers, isGenerated) } } @@ -270,12 +284,12 @@ case class AttributeReference( if (exprId == newExprId) { this } else { - AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers) + AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifiers, isGenerated) } } override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifiers :: Nil + exprId :: qualifiers :: isGenerated :: Nil } override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -287,7 +301,8 @@ case class AttributeReference( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") - s"$qualifiersString`$name`" + val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name" + s"$qualifiersString`$attrRefName`" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0ee124e88a9f2dc6ccdf695c13c9589c05b48bd..7302b63646d6672bfa65cc1b3f73e21de69298e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -78,10 +78,13 @@ object PhysicalOperation extends PredicateHelper { private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { expr.transform { case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(ref) + .map(Alias(_, name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)) + .getOrElse(a) case a: AttributeReference => - aliases.get(a).map(Alias(_, a.name)(a.exprId, a.qualifiers)).getOrElse(a) + aliases.get(a) + .map(Alias(_, a.name)(a.exprId, a.qualifiers, isGenerated = a.isGenerated)).getOrElse(a) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index d8944a424156eae49b6b6d1eb069386284b4ca77..18b7bde906fda61caa56450ca6f4c7d6822d7b06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -139,7 +139,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { case a: Alias => // As the root of the expression, Alias will always take an arbitrary exprId, we need // to erase that for equality testing. - val cleanedExprId = Alias(a.child, a.name)(ExprId(-1), a.qualifiers) + val cleanedExprId = + Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated) BindReferences.bindReference(cleanedExprId, input, allowFailures = true) case other => BindReferences.bindReference(other, input, allowFailures = true) } @@ -222,7 +223,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { - if (resolver(attribute.name, nameParts.head)) { + if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { Option((attribute.withName(nameParts.head), nameParts.tail.toList)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2df0683f9fa161e9308a515a768354a27289c655..30df2a84f62c4726d47c202c77f6c4c9b09c59f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -656,6 +656,8 @@ object TreeNode { case t if t <:< definitions.DoubleTpe => value.asInstanceOf[JDouble].num: java.lang.Double + case t if t <:< localTypeOf[java.lang.Boolean] => + value.asInstanceOf[JBool].value: java.lang.Boolean case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fc35959f205473e16a780fd802674ea4c47e6c5c..e0cec09742eba7954b2299fd982ccacfc210102a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @BeanInfo @@ -176,6 +176,13 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation.select('abcd), "cannot resolve" :: "abcd" :: Nil) + errorTest( + "unresolved attributes with a generated name", + testRelation2.groupBy('a)(max('b)) + .where(sum('b) > 0) + .orderBy('havingCondition.asc), + "cannot resolve" :: "havingCondition" :: Nil) + errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c02133ffc85402e6f59379d67d238c7a86250b83..3ea4adcaa6424ede95acb3159c069c7ada56aaef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -998,12 +998,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } - test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + test("Alias uses internally generated names 'aggOrder' and 'havingCondition'") { val df = Seq(1 -> 2).toDF("i", "j") - val query = df.groupBy('i) - .agg(max('j).as("aggOrdering")) + val query1 = df.groupBy('i) + .agg(max('j).as("aggOrder")) .orderBy(sum('j)) - checkAnswer(query, Row(1, 2)) + checkAnswer(query1, Row(1, 2)) + + // In the plan, there are two attributes having the same name 'havingCondition' + // One is a user-provided alias name; another is an internally generated one. + val query2 = df.groupBy('i) + .agg(max('j).as("havingCondition")) + .where(sum('j) > 0) + .orderBy('havingCondition.asc) + checkAnswer(query2, Row(1, 2)) } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 1f731db26f3875a2dfb7a80cd8bbf5f35cf79772..129bfe0a7dfd8154f1824c0d136ced6b63db036f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -92,12 +92,11 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") } - // TODO Fix name collision introduced by ResolveAggregateFunction analysis rule // When there are multiple aggregate functions in ORDER BY clause, all of them are extracted into // Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query // execution since these aliases have different expression ID. But this introduces name collision // when converting resolved plans back to SQL query strings as expression IDs are stripped. - ignore("aggregate function in order by clause with multiple order keys") { + test("aggregate function in order by clause with multiple order keys") { checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") }