Skip to content
Snippets Groups Projects
Commit 4754e16f authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-6898][SQL] completely support special chars in column names

Even if we wrap column names in backticks like `` `a#$b.c` ``,  we still handle the "." inside column name specially. I think it's fragile to use a special char to split name parts, why not put name parts in `UnresolvedAttribute` directly?

Author: Wenchen Fan <cloud0fan@outlook.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #5511 from cloud-fan/6898 and squashes the following commits:

48e3e57 [Wenchen Fan] more style fix
820dc45 [Wenchen Fan] do not ignore newName in UnresolvedAttribute
d81ad43 [Wenchen Fan] fix style
11699d6 [Wenchen Fan] completely support special chars in column names
parent 557a797a
No related branches found
No related tags found
No related merge requests found
Showing
with 52 additions and 33 deletions
......@@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| "(" ~> expression <~ ")"
| function
| dotExpressionHeader
| ident ^^ UnresolvedAttribute
| ident ^^ {case i => UnresolvedAttribute.quoted(i)}
| signedPrimary
| "~" ~> expression ^^ BitwiseNot
)
protected lazy val dotExpressionHeader: Parser[Expression] =
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest)
}
}
......@@ -297,14 +297,15 @@ class Analyzer(
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
q.asInstanceOf[GroupingAnalytics].gid
case u @ UnresolvedAttribute(name) =>
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) }
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedGetField(child, fieldName) if child.resolved =>
......@@ -383,12 +384,12 @@ class Analyzer(
child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
// Find any attributes that remain unresolved in the sort.
val unresolved: Seq[String] =
ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
val unresolved: Seq[Seq[String]] =
ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
// Create a map from name, to resolved attributes, when the desired name can be found
// prior to the projection.
val resolved: Map[String, NamedExpression] =
val resolved: Map[Seq[String], NamedExpression] =
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
// Construct a set that contains all of the attributes that we need to evaluate the
......
......@@ -46,8 +46,12 @@ trait CheckAnalysis {
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
if (operator.childrenResolved) {
val nameParts = a match {
case UnresolvedAttribute(nameParts) => nameParts
case _ => Seq(a.name)
}
// Throw errors for specific problems with get field.
operator.resolveChildren(a.name, resolver, throwErrors = true)
operator.resolveChildren(nameParts, resolver, throwErrors = true)
}
val from = operator.inputSet.map(_.name).mkString(", ")
......
......@@ -49,7 +49,12 @@ case class UnresolvedRelation(
/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
case class UnresolvedAttribute(nameParts: Seq[String])
extends Attribute with trees.LeafNode[Expression] {
def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
......@@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def newInstance(): UnresolvedAttribute = this
override def withNullability(newNullability: Boolean): UnresolvedAttribute = this
override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name)
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
override def eval(input: Row = null): EvaluatedType =
......@@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def toString: String = s"'$name"
}
object UnresolvedAttribute {
def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\."))
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}
case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
......
......@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
......@@ -111,10 +110,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
def resolveChildren(
name: String,
nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
resolve(name, children.flatMap(_.output), resolver, throwErrors)
resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
......@@ -122,10 +121,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* `[scope].AttributeName.[nested].[fields]...`.
*/
def resolve(
name: String,
nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
resolve(name, output, resolver, throwErrors)
resolve(nameParts, output, resolver, throwErrors)
/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
......@@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsTableColumn(
nameParts: Array[String],
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
assert(nameParts.length > 1)
......@@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* See the comment above `candidates` variable in resolve() for semantics the returned data.
*/
private def resolveAsColumn(
nameParts: Array[String],
nameParts: Seq[String],
resolver: Resolver,
attribute: Attribute): Option[(Attribute, List[String])] = {
if (resolver(attribute.name, nameParts.head)) {
......@@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(
name: String,
nameParts: Seq[String],
input: Seq[Attribute],
resolver: Resolver,
throwErrors: Boolean): Option[NamedExpression] = {
val parts = name.split("\\.")
// A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
// of parts that are to be resolved.
......@@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// and the second element will be List("c").
var candidates: Seq[(Attribute, List[String])] = {
// If the name has 2 or more parts, try to resolve it as `table.column` first.
if (parts.length > 1) {
if (nameParts.length > 1) {
input.flatMap { option =>
resolveAsTableColumn(parts, resolver, option)
resolveAsTableColumn(nameParts, resolver, option)
}
} else {
Seq.empty
......@@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// If none of attributes match `table.column` pattern, we try to resolve it as a column.
if (candidates.isEmpty) {
candidates = input.flatMap { candidate =>
resolveAsColumn(parts, resolver, candidate)
resolveAsColumn(nameParts, resolver, candidate)
}
}
def name = UnresolvedAttribute(nameParts).name
candidates.distinct match {
// One match, no nested fields, use it.
case Seq((a, Nil)) => Some(a)
......
......@@ -27,8 +27,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import scala.collection.immutable
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
......
......@@ -158,7 +158,7 @@ class DataFrame private[sql](
}
protected[sql] def resolve(colName: String): NamedExpression = {
queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse {
throw new AnalysisException(
s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
}
......@@ -166,7 +166,7 @@ class DataFrame private[sql](
protected[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get
}
}
......
......@@ -19,14 +19,13 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
import org.apache.spark.sql.types._
class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
// Make sure the tables are loaded.
TestData
......@@ -1125,7 +1124,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
val data = sparkContext.parallelize(
Seq("""{"key?number1": "value1", "key.number2": "value2"}"""))
jsonRDD(data).registerTempTable("records")
sql("SELECT `key?number1` FROM records")
sql("SELECT `key?number1`, `key.number2` FROM records")
}
test("SPARK-3814 Support Bitwise & operator") {
......@@ -1225,4 +1224,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
jsonRDD(sparkContext.makeRDD(
"""{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil))
.registerTempTable("t")
checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1))
}
}
......@@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
case other => UnresolvedGetField(other, attr)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment