diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g index 2d2bafb1ee34fa38e193b938968b0489a4eb1b24..f18b6ec496f8fe17fae0de42c6d5a9ec1c7d00d7 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g @@ -131,6 +131,13 @@ selectItem : (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) | + namedExpression + ; + +namedExpression +@init { gParent.pushMsg("select named expression", state); } +@after { gParent.popMsg(state); } + : ( expression ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? ) -> ^(TOK_SELEXPR expression identifier*) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index 2e3cc0bfde7c76468c756893a9c74007f7445b24..c87b6c8e9543605149a4ca7e25b4c3d5a3199924 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -30,6 +30,12 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler +private[sql] object CatalystQl { + val parser = new CatalystQl + def parseExpression(sql: String): Expression = parser.parseExpression(sql) + def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql) +} + /** * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. */ @@ -41,16 +47,13 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { } } - /** - * Returns the AST for the given SQL string. + * The safeParse method allows a user to focus on the parsing/AST transformation logic. This + * method will take care of possible errors during the parsing process. */ - protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf) - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { + protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = { try { - createPlan(sql, ParseDriver.parse(sql, conf)) + toResult(ast) } catch { case e: MatchError => throw e case e: AnalysisException => throw e @@ -58,26 +61,39 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { throw new AnalysisException(e.getMessage) case e: NotImplementedError => throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${getAst(sql).treeString} + s"""Unsupported language features in query + |== SQL == + |$sql + |== AST == + |${ast.treeString} + |== Error == |$e + |== Stacktrace == |${e.getStackTrace.head} """.stripMargin) } } - protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree) - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = getAst(ddl) - assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.children - val colList = tableOps - .find(_.text == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")) - - colList.children.map(nodeToAttribute) + /** Creates LogicalPlan for a given SQL string. */ + def parsePlan(sql: String): LogicalPlan = + safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan) + + /** Creates Expression for a given SQL string. */ + def parseExpression(sql: String): Expression = + safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get) + + /** Creates TableIdentifier for a given SQL string. */ + def parseTableIdentifier(sql: String): TableIdentifier = + safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent) + + def parseDdl(sql: String): Seq[Attribute] = { + safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast => + val Token("TOK_CREATETABLE", children) = ast + children + .find(_.text == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")) + .flatMap(_.children.map(nodeToAttribute)) + } } protected def getClauses( @@ -187,7 +203,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val keyMap = keyASTs.zipWithIndex.toMap val bitmasks: Seq[Int] = setASTs.map { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => columns.foldLeft(0)((bitmap, col) => { val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 0e93af8b92cd2f2e6184acf616b19fbcf85dd00f..f8e4f21451192808f86a327e34457e5f139f8fbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -28,7 +28,25 @@ import org.apache.spark.sql.AnalysisException * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver */ object ParseDriver extends Logging { - def parse(command: String, conf: ParserConf): ASTNode = { + /** Create an LogicalPlan ASTNode from a SQL command. */ + def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.statement().getTree + } + + /** Create an Expression ASTNode from a SQL command. */ + def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.namedExpression().getTree + } + + /** Create an TableIdentifier ASTNode from a SQL command. */ + def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => + parser.tableName().getTree + } + + private def parse( + command: String, + conf: ParserConf)( + toTree: SparkSqlParser => CommonTree): ASTNode = { logInfo(s"Parsing command: $command") // Setup error collection. @@ -44,7 +62,7 @@ object ParseDriver extends Logging { parser.configure(conf, reporter) try { - val result = parser.statement() + val result = toTree(parser) // Check errors. reporter.checkForErrors() @@ -57,7 +75,7 @@ object ParseDriver extends Logging { if (tree.token != null || tree.getChildCount == 0) tree else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) } - val tree = nonNullToken(result.getTree) + val tree = nonNullToken(result) // Make sure all boundaries are set. tree.setUnknownTokenBoundaries() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index d7204c3488313d7ad31adf3da46ff673abfab44d..ba9d2524a95513021d0f405b468928d97bdb0931 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -17,36 +17,157 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { val parser = new CatalystQl() + test("test case insensitive") { + val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation) + assert(result === parser.parsePlan("seLect 1")) + assert(result === parser.parsePlan("select 1")) + assert(result === parser.parsePlan("SELECT 1")) + } + + test("test NOT operator with comparison operations") { + val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") + val expected = Project( + UnresolvedAlias( + Not( + GreaterThan(Literal(true), Literal(true))) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + test("support hive interval literal") { + def checkInterval(sql: String, result: CalendarInterval): Unit = { + val parsed = parser.parsePlan(sql) + val expected = Project( + UnresolvedAlias( + Literal(result) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + def checkYearMonth(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' YEAR TO MONTH", + CalendarInterval.fromYearMonthString(lit)) + } + + def checkDayTime(lit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' DAY TO SECOND", + CalendarInterval.fromDayTimeString(lit)) + } + + def checkSingleUnit(lit: String, unit: String): Unit = { + checkInterval( + s"SELECT INTERVAL '$lit' $unit", + CalendarInterval.fromSingleUnitString(unit, lit)) + } + + checkYearMonth("123-10") + checkYearMonth("496-0") + checkYearMonth("-2-3") + checkYearMonth("-123-0") + + checkDayTime("99 11:22:33.123456789") + checkDayTime("-99 11:22:33.123456789") + checkDayTime("10 9:8:7.123456789") + checkDayTime("1 0:0:0") + checkDayTime("-1 0:0:0") + checkDayTime("1 0:0:1") + + for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { + checkSingleUnit("7", unit) + checkSingleUnit("-7", unit) + checkSingleUnit("0", unit) + } + + checkSingleUnit("13.123456789", "second") + checkSingleUnit("-13.123456789", "second") + } + + test("support scientific notation") { + def assertRight(input: String, output: Double): Unit = { + val parsed = parser.parsePlan("SELECT " + input) + val expected = Project( + UnresolvedAlias( + Literal(output) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + assertRight("9.0e1", 90) + assertRight("0.9e+2", 90) + assertRight("900e-1", 90) + assertRight("900.0E-1", 90) + assertRight("9.e+1", 90) + + intercept[AnalysisException](parser.parsePlan("SELECT .e3")) + } + + test("parse expressions") { + compareExpressions( + parser.parseExpression("prinln('hello', 'world')"), + UnresolvedFunction( + "prinln", Literal("hello") :: Literal("world") :: Nil, false)) + + compareExpressions( + parser.parseExpression("1 + r.r As q"), + Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")()) + + compareExpressions( + parser.parseExpression("1 - f('o', o(bar))"), + Subtract(Literal(1), + UnresolvedFunction("f", + Literal("o") :: + UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) :: + Nil, false))) + } + + test("table identifier") { + assert(TableIdentifier("q") === parser.parseTableIdentifier("q")) + assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q")) + intercept[AnalysisException](parser.parseTableIdentifier("")) + // TODO parser swallows third identifier. + // intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) + } + test("parse union/except/intersect") { - parser.createPlan("select * from t1 union all select * from t2") - parser.createPlan("select * from t1 union distinct select * from t2") - parser.createPlan("select * from t1 union select * from t2") - parser.createPlan("select * from t1 except select * from t2") - parser.createPlan("select * from t1 intersect select * from t2") - parser.createPlan("(select * from t1) union all (select * from t2)") - parser.createPlan("(select * from t1) union distinct (select * from t2)") - parser.createPlan("(select * from t1) union (select * from t2)") - parser.createPlan("select * from ((select * from t1) union (select * from t2)) t") + parser.parsePlan("select * from t1 union all select * from t2") + parser.parsePlan("select * from t1 union distinct select * from t2") + parser.parsePlan("select * from t1 union select * from t2") + parser.parsePlan("select * from t1 except select * from t2") + parser.parsePlan("select * from t1 intersect select * from t2") + parser.parsePlan("(select * from t1) union all (select * from t2)") + parser.parsePlan("(select * from t1) union distinct (select * from t2)") + parser.parsePlan("(select * from t1) union (select * from t2)") + parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t") } test("window function: better support of parentheses") { - parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + + parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + + parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + "order by 2) from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + "from windowData") - parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + "from windowData") - parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + + parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + "from windowData") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 395c8bff53f471cd07c3f65bd86435e5a07a6188..b22f4249813254ce37b5e8fc6ccce71df685814a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { protected lazy val hiveQl: Parser[LogicalPlan] = restInput ^^ { - case statement => HiveQl.createPlan(statement.trim) + case statement => HiveQl.parsePlan(statement.trim) } protected lazy val dfs: Parser[LogicalPlan] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 43d84d507b20ea9294bb14a29a0877559be043c9..67228f3f3c9c91c52b4da0bb8b6090e10d387b63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -414,8 +414,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => Subquery(table.name, HiveQl.createPlan(viewText)) - case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText)) + case None => Subquery(table.name, HiveQl.parsePlan(viewText)) + case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText)) } } else { MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d1b1c0d8d8bc223665eb55f3ecf6d23867ac1070..ca9ddf94c11a70be6dbefbe19d297f2b85bcfe1e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -230,15 +230,16 @@ private[hive] object HiveQl extends SparkQl with Logging { CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) } - protected override def createPlan( - sql: String, - node: ASTNode): LogicalPlan = { - if (nativeCommands.contains(node.text)) { - HiveNativeCommand(sql) - } else { - nodeToPlan(node) match { - case NativePlaceholder => HiveNativeCommand(sql) - case plan => plan + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sql: String): LogicalPlan = { + safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast => + if (nativeCommands.contains(ast.text)) { + HiveNativeCommand(sql) + } else { + nodeToPlan(ast) match { + case NativePlaceholder => HiveNativeCommand(sql) + case plan => plan + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index e72a18a716b5ced16644fd1a1f5e15668d85f8a8..14a466cfe94864cd54c37d43ca0e24d348f92a55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = ParseDriver.parse(query, hiveContext.conf) - def parseTree = - Try(quietly(ast.treeString)).getOrElse("<failed to parse>") + def ast = ParseDriver.parsePlan(query, hiveContext.conf) + def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>") test(name) { val error = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index f4a1a174224832e4b74de25f85af74aa3a5d2f15..53d15c14cb3d5ba562fcce9a3c5418bffcaa663b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, M class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { private def extractTableDesc(sql: String): (HiveTable, Boolean) = { - HiveQl.createPlan(sql).collect { + HiveQl.parsePlan(sql).collect { case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) }.head }