Commit da7bbb94 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-8104] [SQL] auto alias expressions in analyzer

Currently we auto alias expressions in the parser. However, during the parse phase we don't have enough information to assign the right alias. For example, a Generator that produces more than one kind of element needs a MultiAlias, and an ExtractValue doesn't need an Alias if it sits in the middle of an ExtractValue chain.
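A hedged sketch of both cases (the table `src` and its columns are hypothetical, assuming `key`/`value` columns and a struct column `a` whose nested struct `b` has a field `c`):

    // 1. A generator with multiple output columns: a single Alias assigned at parse
    //    time is wrong; once explode is resolved, the analyzer can assign a MultiAlias.
    sqlContext.sql("SELECT explode(map(key, value)) FROM src")

    // 2. A chain of field extractions: only the outermost ExtractValue should be
    //    aliased (as "c"); the intermediate extraction of "b" must stay unaliased.
    sqlContext.sql("SELECT a.b.c FROM src")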

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6647 from cloud-fan/alias and squashes the following commits:

552eba4 [Wenchen Fan] fix python
5b5786d [Wenchen Fan] fix agg
73a90cb [Wenchen Fan] fix case-preserve of ExtractValue
4cfd23c [Wenchen Fan] fix order by
d18f401 [Wenchen Fan] refine
9f07359 [Wenchen Fan] address comments
39c1aef [Wenchen Fan] small fix
33640ec [Wenchen Fan] auto alias expressions in analyzer
parent 5d89d9f0
Showing changed files with 150 additions and 117 deletions
......@@ -86,7 +86,8 @@ class SQLContext(object):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
[Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
......@@ -176,17 +177,17 @@ class SQLContext(object):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
[Row(_c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
[Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
......
......@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")
protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
case (e, i) => Alias(e, s"c$i")()
}
}
protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte
......@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
.map(Aggregate(_, assignAliases(p), withFilter))
.getOrElse(Project(assignAliases(p), withFilter))
.map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
.getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
......
......@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
......@@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
......@@ -132,12 +130,38 @@ class Analyzer(
}
/**
* Removes no-op Alias expressions from the plan.
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object TrimGroupingAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(groups, aggs, child) =>
Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[NamedExpression]) = {
// The `UnresolvedAlias`s will appear only at the root of an expression tree, so we
// don't need to transform down the whole tree.
exprs.zipWithIndex.map {
case (u @ UnresolvedAlias(child), i) =>
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
}
case (other, _) => other
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Aggregate(groups, aggs, child)
if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
Aggregate(groups, assignAliases(aggs), child)
case g: GroupingAnalytics
if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
g.withNewAggs(assignAliases(g.aggregations))
case Project(projectList, child)
if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
Project(assignAliases(projectList), child)
}
}
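To summarize the rule above: once the child plan is resolved, each `UnresolvedAlias` at the root of a projection or aggregation expression is replaced according to what it wraps. A self-contained sketch of that decision table (toy types, not the real catalyst classes):

sealed trait Expr { def resolved: Boolean = true }
case class Attr(name: String) extends Expr                       // an already-named expression
case class Unres(name: String) extends Expr { override def resolved = false }
case class FieldAccess(child: Expr, field: String) extends Expr  // stands in for ExtractValueWithStruct
case class Call(description: String) extends Expr                // any other computed expression

def aliasFor(e: Expr, i: Int): Option[String] = e match {
  case Attr(n)            => Some(n)       // NamedExpression: keep its own name
  case x if !x.resolved   => None          // leave the placeholder for the next round
  case FieldAccess(_, f)  => Some(f)       // name after the final struct field
  case _                  => Some(s"_c$i") // positional default, e.g. _c0, _c1
}

// e.g. SELECT a.b.c, key + 1 FROM src  =>  columns named "c" and "_c1"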
......@@ -228,7 +252,7 @@ class Analyzer(
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
......@@ -248,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(child = f.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateArray(args), name) if containsStar(args) =>
UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
Alias(c.copy(children = expandedArgs), name)() :: Nil
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
......@@ -353,7 +377,9 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
withPosition(u) {
q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
}
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
......@@ -379,6 +405,11 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
case UnresolvedAlias(child) => child
case other => other
}
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
......@@ -388,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).getOrElse(u)
plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
......@@ -586,18 +617,6 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
case Alias(g: Generator, name)
if g.resolved &&
g.elementTypes.size > 1 &&
java.util.regex.Pattern.matches("_c[0-9]+", name) => {
// Assume the default name given by parser is "_c[0-9]+",
// TODO in long term, move the naming logic from Parser to Analyzer.
// In projection, Parser gave default name for TGF as does for normal UDF,
// but the TGF probably have multiple output columns/names.
// e.g. SELECT explode(map(key, value)) FROM src;
// Let's simply ignore the default given name for this case.
Some((g, Nil))
}
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
......
......@@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
val cleaned = aggregateExprs.map(_.transform {
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
case Alias(g, _) => g
})
cleaned.foreach(checkValidAggregateExpression)
aggregateExprs.foreach(checkValidAggregateExpression)
case _ => // Fallbacks to the following checks
}
......
......@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
......@@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override def toString: String = s"$child[$extraction]"
}
/**
* Holds the expression that has yet to be aliased.
*/
case class UnresolvedAlias(child: Expression) extends NamedExpression
with trees.UnaryNode[Expression] {
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def name: String = throw new UnresolvedException(this, "name")
override lazy val resolved = false
override def eval(input: InternalRow = null): Any =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}
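Every accessor above throws on purpose: nothing downstream may ask an `UnresolvedAlias` for its name, type, or nullability before the `ResolveAliases` rule has rewritten it. A minimal sketch of this fail-fast placeholder pattern (toy types, not the catalyst API):

trait MaybeNamed { def resolved: Boolean; def name: String }
case class Named(name: String) extends MaybeNamed { val resolved = true }
case class Pending(child: String) extends MaybeNamed {
  val resolved = false
  // fail loudly if anyone asks before the analyzer has assigned an alias
  def name: String = throw new IllegalStateException(s"$child is not aliased yet")
}

def schemaNames(exprs: Seq[MaybeNamed]): Seq[String] = {
  require(exprs.forall(_.resolved), "run the analyzer first")
  exprs.map(_.name)
}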
......@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
import org.apache.spark.sql.{catalyst, AnalysisException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
......@@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {
(child.dataType, extraction) match {
case (StructType(fields), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetStructField(child, fields(ordinal), ordinal)
case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
val ordinal = findField(fields, fieldName.toString, resolver)
GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
case (StructType(fields), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
case (_: MapType, _) =>
GetMapValue(child, extraction)
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
......@@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}
abstract class ExtractValueWithStruct extends ExtractValue {
self: Product =>
def field: StructField
override def toString: String = s"$child.${field.name}"
}
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends ExtractValue {
extends ExtractValueWithStruct {
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
......@@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
containsNull: Boolean) extends ExtractValue {
containsNull: Boolean) extends ExtractValueWithStruct {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
......
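The `copy(name = fieldName)` in the hunk above is the case-preservation fix: field lookup goes through the resolver (which may be case-insensitive), but the name carried on the resolved field, and hence the alias that `ResolveAliases` later derives from `field.name`, should echo the spelling the user wrote. A standalone sketch, assuming a case-insensitive resolver:

case class Field(name: String)

def findAndPreserve(fields: Seq[Field], asWritten: String): Field = {
  // case-insensitive lookup, but keep the user's spelling for naming
  val ordinal = fields.indexWhere(_.name.equalsIgnoreCase(asWritten))
  require(ordinal >= 0, s"no field named $asWritten")
  fields(ordinal).copy(name = asWritten)
}

// findAndPreserve(Seq(Field("c")), "C").name == "C"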
......@@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression =>
// Should trim aliases around `GetField`s. These aliases are introduced while
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
......
......@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
......@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliases it with the last part of the identifier.
// and wraps it in an UnresolvedAlias, which is removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// the final expression as "c".
// Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
// UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
Some(UnresolvedAlias(fieldExprs))
// No matches.
case Seq() =>
......
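A standalone sketch of the foldLeft above: for "a.b.c" with "a" resolved to an attribute, the remaining parts fold into a chain of extractions that is wrapped once, at the root (toy types, not the real catalyst classes):

sealed trait Expr
case class Attr(name: String) extends Expr
case class Extract(child: Expr, field: String) extends Expr
case class UAlias(child: Expr) extends Expr   // stands in for UnresolvedAlias

val a = Attr("a")
val nestedFields = Seq("b", "c")
val chain = nestedFields.foldLeft(a: Expr)((expr, field) => Extract(expr, field))
val wrapped = UAlias(chain)
// wrapped == UAlias(Extract(Extract(Attr("a"), "b"), "c"));
// ResolveAliases later names the result "c", and the inner extraction stays unaliased.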
......@@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode {
def aggregations: Seq[NamedExpression]
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
/**
......@@ -266,7 +268,11 @@ case class GroupingSets(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}
/**
* Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
......@@ -284,7 +290,11 @@ case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}
/**
* Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
......@@ -303,7 +313,11 @@ case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
this.copy(aggregations = aggs)
}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
......
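The new `withNewAggs` hook lets the `ResolveAliases` rule shown earlier rewrite the aggregation list of any `GroupingAnalytics` node uniformly, instead of matching `GroupingSets`, `Cube`, and `Rollup` separately. A toy model of the pattern:

trait GA { def aggregations: Seq[String]; def withNewAggs(aggs: Seq[String]): GA }
case class CubeNode(aggregations: Seq[String]) extends GA {
  def withNewAggs(aggs: Seq[String]): GA = copy(aggregations = aggs)
}
case class RollupNode(aggregations: Seq[String]) extends GA {
  def withNewAggs(aggs: Seq[String]): GA = copy(aggregations = aggs)
}

// one rule body serves every subtype:
def assignAliases(g: GA): GA = g.withNewAggs(g.aggregations.map("aliased:" + _))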
......@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
......
......@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
......@@ -629,6 +629,10 @@ class DataFrame private[sql](
@scala.annotation.varargs
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
// Wrap UnresolvedAttribute in UnresolvedAlias: resolving an UnresolvedAttribute strips
// the intermediate aliases of an ExtractValue chain, so the result must be aliased
// again to remain a NamedExpression.
case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
case Column(expr: NamedExpression) => expr
// Leave an unaliased explode with an empty list of names since the analyzer will generate the
// correct defaults after the nested expression's type has been resolved.
......
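A hedged usage sketch (the DataFrame `df` and its nested column are hypothetical): selecting a nested field by name takes the `UnresolvedAttribute` branch above, so the intermediate extractions stay unaliased while the result is still a named column.

// hypothetical schema: a: struct<b: struct<c: int>>
val picked = df.select(df("a.b.c"))
// picked.columns should be Array("c"): the UnresolvedAlias wrapper lets the
// analyzer name the column after the final field instead of the whole chain.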
......@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
......@@ -70,27 +70,31 @@ class GroupedData protected[sql](
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {
private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
val retainedExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
retainedExprs ++ aggExprs
} else {
aggExprs
}
groupingExprs ++ aggExprs
} else {
aggExprs
}
val aliasedAgg = aggregates.map {
// Wrap UnresolvedAttribute in UnresolvedAlias: resolving an UnresolvedAttribute strips
// the intermediate aliases of an ExtractValue chain, so the result must be aliased
// again to remain a NamedExpression.
case u: UnresolvedAttribute => UnresolvedAlias(u)
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
groupType match {
case GroupedData.GroupByType =>
DataFrame(
df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
}
}
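A hedged usage sketch (`df`, `dept`, and `salary` are hypothetical): every `agg` entry point below now defers naming to this single `aliasedAgg` pass instead of aliasing eagerly at each call site.

import org.apache.spark.sql.functions.max

val byDept = df.groupBy("dept")
byDept.agg(Map("salary" -> "max"))   // string form: strToExpr, then aliased in toDF
byDept.agg(max(df("salary")))        // Column form: no eager Alias at the call site
byDept.max("salary")                 // numeric shortcut: funnels through the same toDF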
......@@ -112,10 +116,7 @@ class GroupedData protected[sql](
namedExpr
}
}
toDF(columnExprs.map { c =>
val a = f(c)
Alias(a, a.prettyString)()
})
toDF(columnExprs.map(f))
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
......@@ -169,8 +170,7 @@ class GroupedData protected[sql](
*/
def agg(exprs: Map[String, String]): DataFrame = {
toDF(exprs.map { case (colName, expr) =>
val a = strToExpr(expr)(df(colName).expr)
Alias(a, a.prettyString)()
strToExpr(expr)(df(colName).expr)
}.toSeq)
}
......@@ -224,10 +224,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
toDF((expr +: exprs).map(_.expr).map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
})
toDF((expr +: exprs).map(_.expr))
}
/**
......
......@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
case plan: LogicalPlan =>
case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
......
......@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-6145: special cases") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
"""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
......
......@@ -19,7 +19,6 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
......
......@@ -415,13 +415,6 @@ private[hive] object HiveQl {
throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
}
protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
exprs.zipWithIndex.map {
case (ne: NamedExpression, _) => ne
case (e, i) => Alias(e, s"_c$i")()
}
}
protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
val (db, tableName) =
tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
......@@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq)
select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>
......