diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 44727f9876deb8eaa4a15252c60e782c96e38afd..e4373f79f792238cb0d388cd2a9bb51404653ae8 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -5,8 +5,7 @@ activation-1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -179,7 +178,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 6014d50c6b6fd30a5720e343abedfc19b801d4fb..7478181406d07d249acbd7109b725b5cfa8b08ad 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -170,7 +169,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index f56e6f4393e787b080642575dbfa613cb1495611..faffb8bf398a5c2a1a2f802403537cfadaef304c 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar @@ -171,7 +170,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar 
stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index e37484473db2e892fe808b5e4b5f12feb508314e..e703c7acd38765333be57792ae8cf7d41ce5bd82 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -5,8 +5,7 @@ activation-1.1.1.jar akka-actor_2.10-2.3.11.jar akka-remote_2.10-2.3.11.jar akka-slf4j_2.10-2.3.11.jar -antlr-2.7.7.jar -antlr-runtime-3.4.jar +antlr-runtime-3.5.2.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar @@ -177,7 +176,6 @@ spire_2.10-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar -stringtemplate-3.2.1.jar super-csv-2.2.0.jar tachyon-client-0.8.2.jar tachyon-underfs-hdfs-0.8.2.jar diff --git a/pom.xml b/pom.xml index d0ac1eb39aabefa573fe2c6f544795d19c65b83e..e414a8bfe6ce52a6475a4cb3c42e839b61869ec6 100644 --- a/pom.xml +++ b/pom.xml @@ -183,6 +183,7 @@ <jodd.version>3.5.2</jodd.version> <jsr305.version>1.3.9</jsr305.version> <libthrift.version>0.9.2</libthrift.version> + <antlr.version>3.5.2</antlr.version> <test.java.home>${java.home}</test.java.home> <test.exclude.tags></test.exclude.tags> @@ -1843,6 +1844,11 @@ </exclusion> </exclusions> </dependency> + <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr-runtime</artifactId> + <version>${antlr.version}</version> + </dependency> </dependencies> </dependencyManagement> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index af1d36c6ea57bc0450431ae6849e33a033ae1a14..5d4f19ab14a2988339fbb5353242e8da08dcb55a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -247,6 +247,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst ANTLR generation settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ 
enable(SQL.settings)(sql) @@ -357,6 +360,58 @@ object OldDeps { ) } +object Catalyst { + lazy val settings = Seq( + // ANTLR code-generation step. + // + // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of + // build errors in the current plugin. + // Create Parser from ANTLR grammar files. + sourceGenerators in Compile += Def.task { + val log = streams.value.log + + val grammarFileNames = Seq( + "SparkSqlLexer.g", + "SparkSqlParser.g") + val sourceDir = (sourceDirectory in Compile).value / "antlr3" + val targetDir = (sourceManaged in Compile).value + + // Create default ANTLR Tool. + val antlr = new org.antlr.Tool + + // Setup input and output directories. + antlr.setInputDirectory(sourceDir.getPath) + antlr.setOutputDirectory(targetDir.getPath) + antlr.setForceRelativeOutput(true) + antlr.setMake(true) + + // Add grammar files. + grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => + val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath + log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) + antlr.addGrammarFile(relGFilePath) + // We will set library directory multiple times here. However, only the + // last one has effect. Because the grammar files are located under the same directory, + // We assume there is only one library directory. + antlr.setLibDirectory(gFilePath.getParent) + } + + // Generate the parser. + antlr.process + if (antlr.getNumErrors > 0) { + log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + } + + // Return all generated java files. + (targetDir ** "*.java").get.toSeq + }.taskValue, + // Include ANTLR tokens files. 
+ resourceGenerators in Compile += Def.task { + ((sourceManaged in Compile).value ** "*.tokens").get.toSeq + }.taskValue + ) +} + object SQL { lazy val settings = Seq( initialCommands in console := @@ -414,54 +469,7 @@ object Hive { // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce // in order to generate golden files. This is only required for developers who are adding new // new query tests. - fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") }, - // ANTLR code-generation step. - // - // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of - // build errors in the current plugin. - // Create Parser from ANTLR grammar files. - sourceGenerators in Compile += Def.task { - val log = streams.value.log - - val grammarFileNames = Seq( - "SparkSqlLexer.g", - "SparkSqlParser.g") - val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value - - // Create default ANTLR Tool. - val antlr = new org.antlr.Tool - - // Setup input and output directories. - antlr.setInputDirectory(sourceDir.getPath) - antlr.setOutputDirectory(targetDir.getPath) - antlr.setForceRelativeOutput(true) - antlr.setMake(true) - - // Add grammar files. - grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => - val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath - log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) - antlr.addGrammarFile(relGFilePath) - // We will set library directory multiple times here. However, only the - // last one has effect. Because the grammar files are located under the same directory, - // We assume there is only one library directory. - antlr.setLibDirectory(gFilePath.getParent) - } - - // Generate the parser. 
- antlr.process - if (antlr.getNumErrors > 0) { - log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) - } - - // Return all generated java files. - (targetDir ** "*.java").get.toSeq - }.taskValue, - // Include ANTLR tokens files. - resourceGenerators in Compile += Def.task { - ((sourceManaged in Compile).value ** "*.tokens").get.toSeq - }.taskValue + fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } ) } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index cfa520b7b9db2ceb7bb99292729fe150076f46dc..76ca3f3bb1bfa701994e1596bfe6202a7bfc601a 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -71,6 +71,10 @@ <groupId>org.codehaus.janino</groupId> <artifactId>janino</artifactId> </dependency> + <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr-runtime</artifactId> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> @@ -103,6 +107,24 @@ </execution> </executions> </plugin> + <plugin> + <groupId>org.antlr</groupId> + <artifactId>antlr3-maven-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>antlr</goal> + </goals> + </execution> + </executions> + <configuration> + <sourceDirectory>../catalyst/src/main/antlr3</sourceDirectory> + <includes> + <include>**/SparkSqlLexer.g</include> + <include>**/SparkSqlParser.g</include> + </includes> + </configuration> + </plugin> </plugins> </build> <profiles> diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g similarity index 98% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index e4a80f0ce8ebf64caac9455ac6460c9c936d249e..ba6cfc60f045fddbd7b68828f222b0f11efdcfb3 100644 --- 
a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar. */ parser grammar FromClauseParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g similarity index 99% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g index 9f1e168374f010b6b1f6df81070785e64b80e2a4..86c6bd610f9120af1b63f62266e0f9a20085e817 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/IdentifiersParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. 
*/ parser grammar IdentifiersParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g similarity index 97% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g index 48bc8b0a300af9a756401176b31885055aedb2d9..2d2bafb1ee34fa38e193b938968b0489a4eb1b24 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SelectClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar. 
*/ parser grammar SelectClauseParser; @@ -33,7 +35,7 @@ k=3; @Override public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - gParent.errors.add(new ParseError(gParent, e, tokenNames)); + gParent.displayRecognitionError(tokenNames, e); } protected boolean useSQL11ReservedKeywordsForIdentifier() { return gParent.useSQL11ReservedKeywordsForIdentifier(); diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g similarity index 93% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index ee1b8989b5affeac6d947a0f4020e77c97be1126..e01e7101d0b7e6eda9fedb9a93e394059121349a 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -13,26 +13,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar. 
*/ lexer grammar SparkSqlLexer; @lexer::header { -package org.apache.spark.sql.parser; +package org.apache.spark.sql.catalyst.parser; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; } @lexer::members { - private Configuration hiveConf; + private ParserConf parserConf; + private ParseErrorReporter reporter; - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; } protected boolean allowQuotedId() { - String supportedQIds = HiveConf.getVar(hiveConf, HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT); - return !"none".equals(supportedQIds); + if (parserConf == null) { + return true; + } + return parserConf.supportQuotedId(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } } } diff --git a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g similarity index 99% rename from sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g rename to sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 69574d713d0be02c5988f149fc2eb4877ae7fe18..98b46794a630c50fcd8d750637a9ae0ae55dc23c 100644 --- a/sql/hive/src/main/antlr3/org/apache/spark/sql/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar. 
*/ parser grammar SparkSqlParser; @@ -369,18 +371,15 @@ TOK_SET_AUTOCOMMIT; // Package headers @header { -package org.apache.spark.sql.parser; +package org.apache.spark.sql.catalyst.parser; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.hive.conf.HiveConf; } @members { - ArrayList<ParseError> errors = new ArrayList<ParseError>(); Stack msgs = new Stack<String>(); private static HashMap<String, String> xlateMap; @@ -563,9 +562,10 @@ import org.apache.hadoop.hive.conf.HiveConf; } @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } } @Override @@ -654,15 +654,20 @@ import org.apache.hadoop.hive.conf.HiveConf; private CommonTree throwColumnNameException() throws RecognitionException { throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); } - private Configuration hiveConf; - public void setHiveConf(Configuration hiveConf) { - this.hiveConf = hiveConf; + + private ParserConf parserConf; + private ParseErrorReporter reporter; + + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; } + protected boolean useSQL11ReservedKeywordsForIdentifier() { - if(hiveConf==null){ - return false; + if (parserConf == null) { + return true; } - return !HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS); + return !parserConf.supportSQL11ReservedKeywords(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java new file mode 100644 index 0000000000000000000000000000000000000000..5bc87b680f9ad82851fe9f136c9bb4330e214fec --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser; + +import java.io.UnsupportedEncodingException; + +/** + * A couple of utility methods that help with parsing ASTs. 
+ * + * Both methods in this class were taken from the SemanticAnalyzer in Hive: + * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java + */ +public final class ParseUtils { + private ParseUtils() { + super(); + } + + public static String charSetString(String charSetName, String charSetString) + throws UnsupportedEncodingException { + // The character set name starts with a _, so strip that + charSetName = charSetName.substring(1); + if (charSetString.charAt(0) == '\'') { + return new String(unescapeSQLString(charSetString).getBytes(), charSetName); + } else // hex input is also supported + { + assert charSetString.charAt(0) == '0'; + assert charSetString.charAt(1) == 'x'; + charSetString = charSetString.substring(2); + + byte[] bArray = new byte[charSetString.length() / 2]; + int j = 0; + for (int i = 0; i < charSetString.length(); i += 2) { + int val = Character.digit(charSetString.charAt(i), 16) * 16 + + Character.digit(charSetString.charAt(i + 1), 16); + if (val > 127) { + val = val - 256; + } + bArray[j++] = (byte)val; + } + + return new String(bArray, charSetName); + } + } + + private static final int[] multiplier = new int[] {4096, 256, 16, 1}; + + @SuppressWarnings("nls") + public static String unescapeSQLString(String b) { + Character enclosure = null; + + // Some of the strings can be passed in as unicode. 
For example, the + // delimiter can be passed in as \002 - So, we first check if the + // string is a unicode number, else go back to the old behavior + StringBuilder sb = new StringBuilder(b.length()); + for (int i = 0; i < b.length(); i++) { + + char currentChar = b.charAt(i); + if (enclosure == null) { + if (currentChar == '\'' || b.charAt(i) == '\"') { + enclosure = currentChar; + } + // ignore all other chars outside the enclosure + continue; + } + + if (enclosure.equals(currentChar)) { + enclosure = null; + continue; + } + + if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { + int code = 0; + int base = i + 2; + for (int j = 0; j < 4; j++) { + int digit = Character.digit(b.charAt(j + base), 16); + code += digit * multiplier[j]; + } + sb.append((char)code); + i += 5; + continue; + } + + if (currentChar == '\\' && (i + 4 < b.length())) { + char i1 = b.charAt(i + 1); + char i2 = b.charAt(i + 2); + char i3 = b.charAt(i + 3); + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') + && (i3 >= '0' && i3 <= '7')) { + byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); + byte[] bValArr = new byte[1]; + bValArr[0] = bVal; + String tmp = new String(bValArr); + sb.append(tmp); + i += 3; + continue; + } + } + + if (currentChar == '\\' && (i + 2 < b.length())) { + char n = b.charAt(i + 1); + switch (n) { + case '0': + sb.append("\0"); + break; + case '\'': + sb.append("'"); + break; + case '"': + sb.append("\""); + break; + case 'b': + sb.append("\b"); + break; + case 'n': + sb.append("\n"); + break; + case 'r': + sb.append("\r"); + break; + case 't': + sb.append("\t"); + break; + case 'Z': + sb.append("\u001A"); + break; + case '\\': + sb.append("\\"); + break; + // The following 2 lines are exactly what MySQL does TODO: why do we do this? 
+ case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala new file mode 100644 index 0000000000000000000000000000000000000000..42bdf25b61ea5a8305f27613e7507d338142acfe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -0,0 +1,961 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst + +import java.sql.Date + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. + */ +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { + object Token { + def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { + CurrentOrigin.setPosition(node.line, node.positionInLine) + node.pattern + } + } + + + /** + * Returns the AST for the given SQL string. + */ + protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf) + + /** Creates LogicalPlan for a given HiveQL string. 
*/ + def createPlan(sql: String): LogicalPlan = { + try { + createPlan(sql, ParseDriver.parse(sql, conf)) + } catch { + case e: MatchError => throw e + case e: AnalysisException => throw e + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${getAst(sql).treeString} + |$e + |${e.getStackTrace.head} + """.stripMargin) + } + } + + protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree) + + def parseDdl(ddl: String): Seq[Attribute] = { + val tree = getAst(ddl) + assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.") + val tableOps = tree.children + val colList = tableOps + .find(_.text == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")) + + colList.children.map(nodeToAttribute) + } + + protected def getClauses( + clauseNames: Seq[String], + nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { + var remainingNodes = nodeList + val clauses = clauseNames.map { clauseName => + val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) + remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) + matches.headOption + } + + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. 
+ |You are likely trying to use an unsupported Hive feature.""".stripMargin) + } + clauses + } + + protected def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = + getClauseOption(clauseName, nodeList).getOrElse(sys.error( + s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}")) + + protected def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = { + nodeList.filter { case ast: ASTNode => ast.text == clauseName } match { + case Seq(oneMatch) => Some(oneMatch) + case Seq() => None + case _ => sys.error(s"Found multiple instances of clause $clauseName") + } + } + + protected def nodeToAttribute(node: ASTNode): Attribute = node match { + case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => + AttributeReference(colName, nodeToDataType(dataType), nullable = true)() + case _ => + noParseRule("Attribute", node) + } + + protected def nodeToDataType(node: ASTNode): DataType = node match { + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.text.toInt, scale.text.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.text.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT + case Token("TOK_BIGINT", Nil) => LongType + case Token("TOK_INT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => ByteType + case Token("TOK_SMALLINT", Nil) => ShortType + case Token("TOK_BOOLEAN", Nil) => BooleanType + case Token("TOK_STRING", Nil) => StringType + case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType + case Token("TOK_FLOAT", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_DATE", Nil) => DateType + case Token("TOK_TIMESTAMP", Nil) => TimestampType + case Token("TOK_BINARY", Nil) => BinaryType + case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) + case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => + 
+ StructType(fields.map(nodeToStructField)) + case Token("TOK_MAP", keyType :: valueType :: Nil) => + MapType(nodeToDataType(keyType), nodeToDataType(valueType)) + case _ => + noParseRule("DataType", node) + } + + protected def nodeToStructField(node: ASTNode): StructField = node match { + case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case _ => + noParseRule("StructField", node) + } + + protected def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = { + tableNameParts.children.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { + case Seq(tableOnly) => TableIdentifier(tableOnly) + case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) + case other => sys.error("Hive only supports tables names like 'tableName' " + + s"or 'databaseName.tableName', found '$other'") + } + } + + /** + * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) + * is equivalent to + * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 + * Check the following link for details. + * +https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup + * + * The bitmask denotes which grouping expressions are valid for a grouping set; + * the bitmask is also called the grouping id (`GROUPING__ID`, the virtual column in Hive), + * e.g. in superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of + * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. 
+   */
+  protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
+    val (keyASTs, setASTs) = children.partition {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets
+      case _ => true // grouping keys
+    }
+
+    val keys = keyASTs.map(nodeToExpr)
+    val keyMap = keyASTs.zipWithIndex.toMap
+
+    // Every grouping set becomes one bitmask: bit i is set when the i-th grouping key is
+    // part of that set.
+    val bitmasks: Seq[Int] = setASTs.map {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
+      case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
+        columns.foldLeft(0)((bitmap, col) => {
+          val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
+          bitmap | 1 << keyIndex.getOrElse(
+            throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
+        })
+      case _ => sys.error("Expect GROUPING SETS clause")
+    }
+
+    (keys, bitmasks)
+  }
+
+  /**
+   * Converts a query-level AST node into a [[LogicalPlan]].
+   *
+   * A `TOK_QUERY` node yields one plan per `TOK_INSERT` clause; these plans are combined with
+   * `Union` and, when a `TOK_CTE` clause is present, wrapped in `With`. A `TOK_UNIONALL` node
+   * yields the `Union` of its two children. Any other node is rejected via `noParseRule`.
+   */
+  protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
+    case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
+      val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
+        queryArgs match {
+          case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts =>
+            val cteRelations = ctes.map { node =>
+              val relation = nodeToRelation(node).asInstanceOf[Subquery]
+              relation.alias -> relation
+            }
+            (Some(from.head), inserts, Some(cteRelations.toMap))
+          case Token("TOK_FROM", from) :: inserts =>
+            (Some(from.head), inserts, None)
+          case Token("TOK_INSERT", _) :: Nil =>
+            (None, queryArgs, None)
+        }
+
+      // Return one query for each insert clause.
+      val queries = insertClauses.map {
+        case Token("TOK_INSERT", singleInsert) =>
+          val (
+            intoClause ::
+            destClause ::
+            selectClause ::
+            selectDistinctClause ::
+            whereClause ::
+            groupByClause ::
+            rollupGroupByClause ::
+            cubeGroupByClause ::
+            groupingSetsClause ::
+            orderByClause ::
+            havingClause ::
+            sortByClause ::
+            clusterByClause ::
+            distributeByClause ::
+            limitClause ::
+            lateralViewClause ::
+            windowClause :: Nil) = {
+            getClauses(
+              Seq(
+                "TOK_INSERT_INTO",
+                "TOK_DESTINATION",
+                "TOK_SELECT",
+                "TOK_SELECTDI",
+                "TOK_WHERE",
+                "TOK_GROUPBY",
+                "TOK_ROLLUP_GROUPBY",
+                "TOK_CUBE_GROUPBY",
+                "TOK_GROUPING_SETS",
+                "TOK_ORDERBY",
+                "TOK_HAVING",
+                "TOK_SORTBY",
+                "TOK_CLUSTERBY",
+                "TOK_DISTRIBUTEBY",
+                "TOK_LIMIT",
+                "TOK_LATERAL_VIEW",
+                "WINDOW"),
+              singleInsert)
+          }
+
+          val relations = fromClause match {
+            case Some(f) => nodeToRelation(f)
+            case None => OneRowRelation
+          }
+
+          val withWhere = whereClause.map { whereNode =>
+            val Seq(whereExpr) = whereNode.children
+            Filter(nodeToExpr(whereExpr), relations)
+          }.getOrElse(relations)
+
+          val select = (selectClause orElse selectDistinctClause)
+            .getOrElse(sys.error("No select clause."))
+
+          val transformation = nodeToTransformation(select.children.head, withWhere)
+
+          val withLateralView = lateralViewClause.map { lv =>
+            nodeToGenerate(lv.children.head, outer = false, withWhere)
+          }.getOrElse(withWhere)
+
+          // The projection of the query can either be a normal projection, an aggregation
+          // (if there is a group by) or a script transformation.
+          val withProject: LogicalPlan = transformation.getOrElse {
+            val selectExpressions =
+              select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
+            // The first defined alternative wins; the plain projection is the fallback.
+            Seq(
+              groupByClause.map(e => e match {
+                case Token("TOK_GROUPBY", children) =>
+                  // Not a transformation so must be either project or aggregation.
+                  Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView)
+                case _ => sys.error("Expect GROUP BY")
+              }),
+              groupingSetsClause.map(e => e match {
+                case Token("TOK_GROUPING_SETS", children) =>
+                  val (groupByExprs, masks) = extractGroupingSet(children)
+                  GroupingSets(masks, groupByExprs, withLateralView, selectExpressions)
+                case _ => sys.error("Expect GROUPING SETS")
+              }),
+              rollupGroupByClause.map(e => e match {
+                case Token("TOK_ROLLUP_GROUPBY", children) =>
+                  Aggregate(
+                    Seq(Rollup(children.map(nodeToExpr))),
+                    selectExpressions,
+                    withLateralView)
+                case _ => sys.error("Expect WITH ROLLUP")
+              }),
+              cubeGroupByClause.map(e => e match {
+                case Token("TOK_CUBE_GROUPBY", children) =>
+                  Aggregate(
+                    Seq(Cube(children.map(nodeToExpr))),
+                    selectExpressions,
+                    withLateralView)
+                case _ => sys.error("Expect WITH CUBE")
+              }),
+              Some(Project(selectExpressions, withLateralView))).flatten.head
+          }
+
+          // Handle HAVING clause.
+          val withHaving = havingClause.map { h =>
+            val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) }
+            // Note that we added a cast to boolean. If the expression itself is already boolean,
+            // the optimizer will get rid of the unnecessary cast.
+            Filter(Cast(havingExpr, BooleanType), withProject)
+          }.getOrElse(withProject)
+
+          // Handle SELECT DISTINCT
+          val withDistinct =
+            if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving
+
+          // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
+          val withSort =
+            (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
+              case (Some(totalOrdering), None, None, None) =>
+                Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct)
+              case (None, Some(perPartitionOrdering), None, None) =>
+                Sort(
+                  perPartitionOrdering.children.map(nodeToSortOrder),
+                  global = false, withDistinct)
+              case (None, None, Some(partitionExprs), None) =>
+                RepartitionByExpression(
+                  partitionExprs.children.map(nodeToExpr), withDistinct)
+              case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
+                Sort(
+                  perPartitionOrdering.children.map(nodeToSortOrder), global = false,
+                  RepartitionByExpression(
+                    partitionExprs.children.map(nodeToExpr),
+                    withDistinct))
+              case (None, None, None, Some(clusterExprs)) =>
+                Sort(
+                  clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)),
+                  global = false,
+                  RepartitionByExpression(
+                    clusterExprs.children.map(nodeToExpr),
+                    withDistinct))
+              case (None, None, None, None) => withDistinct
+              case _ => sys.error("Unsupported set of ordering / distribution clauses.")
+            }
+
+          val withLimit =
+            limitClause.map(l => nodeToExpr(l.children.head))
+              .map(Limit(_, withSort))
+              .getOrElse(withSort)
+
+          // Collect all window specifications defined in the WINDOW clause.
+          val windowDefinitions = windowClause.map(_.children.collect {
+            case Token("TOK_WINDOWDEF",
+                   Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
+              windowName -> nodesToWindowSpecification(spec)
+          }.toMap)
+          // Handle cases like
+          // window w1 as (partition by p_mfgr order by p_name
+          //               range between 2 preceding and 2 following),
+          //        w2 as w1
+          val resolvedCrossReference = windowDefinitions.map {
+            windowDefMap => windowDefMap.map {
+              case (windowName, WindowSpecReference(other)) =>
+                (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition])
+              case o => o.asInstanceOf[(String, WindowSpecDefinition)]
+            }
+          }
+
+          val withWindowDefinitions =
+            resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit)
+
+          // TOK_INSERT_INTO means to add files to the table.
+          // TOK_DESTINATION means to overwrite the table.
+          val resultDestination =
+            (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+          val overwrite = intoClause.isEmpty
+          nodeToDest(
+            resultDestination,
+            withWindowDefinitions,
+            overwrite)
+      }
+
+      // If there are multiple INSERTS just UNION them together into one query.
+      val query = queries.reduceLeft(Union)
+
+      // return With plan if there is CTE
+      cteRelations.map(With(query, _)).getOrElse(query)
+
+    // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT
+    case Token("TOK_UNIONALL", left :: right :: Nil) =>
+      Union(nodeToPlan(left), nodeToPlan(right))
+
+    case _ =>
+      noParseRule("Plan", node)
+  }
+
+  val allJoinTokens = "(TOK_.*JOIN)".r
+  val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+
+  /**
+   * Converts a relation-level AST node (sub-query, lateral view, table reference or join) into
+   * a [[LogicalPlan]]. Any other node is rejected via `noParseRule`.
+   */
+  protected def nodeToRelation(node: ASTNode): LogicalPlan = {
+    node match {
+      case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
+        Subquery(cleanIdentifier(alias), nodeToPlan(query))
+
+      case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+        nodeToGenerate(
+          selectClause,
+          outer = isOuter.nonEmpty,
+          nodeToRelation(relationClause))
+
+      /* All relations, possibly with aliases or sampling clauses. */
+      case Token("TOK_TABREF", clauses) =>
+        // If the last clause is not a token then it's the alias of the table.
+        val (nonAliasClauses, aliasClause) =
+          if (clauses.last.text.startsWith("TOK")) {
+            (clauses, None)
+          } else {
+            (clauses.dropRight(1), Some(clauses.last))
+          }
+
+        val (Some(tableNameParts) ::
+          splitSampleClause ::
+          bucketSampleClause :: Nil) = {
+          getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
+            nonAliasClauses)
+        }
+
+        val tableIdent = extractTableIdent(tableNameParts)
+        val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
+        val relation = UnresolvedRelation(tableIdent, alias)
+
+        // Apply sampling if requested.
+        (bucketSampleClause orElse splitSampleClause).map {
+          case Token("TOK_TABLESPLITSAMPLE",
+                 Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) =>
+            Limit(Literal(count.toInt), relation)
+          case Token("TOK_TABLESPLITSAMPLE",
+                 Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) =>
+            // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+            // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+            // adjust the fraction.
+            require(
+              fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
+                && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
+              s"Sampling fraction ($fraction) must be on interval [0, 100]")
+            Sample(0.0, fraction.toDouble / 100, withReplacement = false,
+              (math.random * 1000).toInt,
+              relation)
+          case Token("TOK_TABLEBUCKETSAMPLE",
+                 Token(numerator, Nil) ::
+                 Token(denominator, Nil) :: Nil) =>
+            val fraction = numerator.toDouble / denominator.toDouble
+            Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+          case a =>
+            noParseRule("Sampling", a)
+        }.getOrElse(relation)
+
+      case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) =>
+        // At most one extra child (the join condition) is accepted.
+        if (other.size > 1) {
+          sys.error(s"Unsupported join operation: $other")
+        }
+
+        val joinType = joinToken match {
+          case "TOK_JOIN" => Inner
+          case "TOK_CROSSJOIN" => Inner
+          case "TOK_RIGHTOUTERJOIN" => RightOuter
+          case "TOK_LEFTOUTERJOIN" => LeftOuter
+          case "TOK_FULLOUTERJOIN" => FullOuter
+          case "TOK_LEFTSEMIJOIN" => LeftSemi
+          case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
+          case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+        }
+        Join(nodeToRelation(relation1),
+          nodeToRelation(relation2),
+          joinType,
+          other.headOption.map(nodeToExpr))
+
+      case _ =>
+        noParseRule("Relation", node)
+    }
+  }
+
+  /** Converts an ascending/descending sort-column node into a [[SortOrder]]. */
+  protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
+    case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
+      SortOrder(nodeToExpr(sortExpr), Ascending)
+    case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
+      SortOrder(nodeToExpr(sortExpr), Descending)
+    case _ =>
+      noParseRule("SortOrder", node)
+  }
+
+  val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
+
+  /**
+   * Applies the destination described by `node` to `query`.
+   *
+   * A `TOK_TMP_FILE` directory destination returns `query` unchanged; a `TOK_TAB` destination
+   * (optionally followed by `TOK_IFNOTEXISTS`) wraps it in an `InsertIntoTable`. Any other
+   * destination is rejected via `noParseRule`.
+   */
+  protected def nodeToDest(
+      node: ASTNode,
+      query: LogicalPlan,
+      overwrite: Boolean): LogicalPlan = node match {
+    case Token(destinationToken(),
+           Token("TOK_DIR",
+             Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
+      query
+
+    case Token(destinationToken(),
+           Token("TOK_TAB",
+             tableArgs) :: Nil) =>
+      insertIntoTable(tableArgs, query, overwrite, ifNotExists = false)
+
+    case Token(destinationToken(),
+           Token("TOK_TAB",
+             tableArgs) ::
+           Token("TOK_IFNOTEXISTS",
+             _) :: Nil) =>
+      insertIntoTable(tableArgs, query, overwrite, ifNotExists = true)
+
+    case _ =>
+      noParseRule("Destination", node)
+  }
+
+  /**
+   * Builds the `InsertIntoTable` plan shared by both `TOK_TAB` destination forms.
+   * `tableArgs` are the children of the `TOK_TAB` node: the table name and an optional
+   * partition specification.
+   */
+  private def insertIntoTable(
+      tableArgs: Seq[ASTNode],
+      query: LogicalPlan,
+      overwrite: Boolean,
+      ifNotExists: Boolean): LogicalPlan = {
+    val Some(tableNameParts) :: partitionClause :: Nil =
+      getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+    val tableIdent = extractTableIdent(tableNameParts)
+
+    val partitionKeys = partitionClause.map(_.children.map {
+      // Parse partitions. We also make keys case insensitive.
+      case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+        cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
+      case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+        cleanIdentifier(key.toLowerCase) -> None
+    }.toMap).getOrElse(Map.empty)
+
+    InsertIntoTable(
+      UnresolvedRelation(tableIdent, None),
+      partitionKeys,
+      query,
+      overwrite,
+      ifNotExists = ifNotExists)
+  }
+
+  /**
+   * Converts one child of a SELECT clause into an expression. Hint lists produce `None`;
+   * select expressions produce the expression, aliased when alias names are present.
+   */
+  protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match {
+    case Token("TOK_SELEXPR", e :: Nil) =>
+      Some(nodeToExpr(e))
+
+    case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) =>
+      Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
+
+    case Token("TOK_SELEXPR", e :: aliasChildren) =>
+      val aliasNames = aliasChildren.collect {
+        case Token(name, Nil) => cleanIdentifier(name)
+      }
+      Some(MultiAlias(nodeToExpr(e), aliasNames))
+
+    /* Hints are ignored */
+    case Token("TOK_HINTLIST", _) => None
+
+    case _ =>
+      noParseRule("Select", node)
+  }
+
+  protected val escapedIdentifier = "`([^`]+)`".r
+  protected val doubleQuotedString = "\"([^\"]+)\"".r
+  protected val singleQuotedString = "'([^']+)'".r
+
+  /**
+   * Strips one pair of enclosing single or double quotes from `str`. Strings that do not
+   * match either quoted pattern (including empty quoted strings) are returned unchanged.
+   */
+  protected def unquoteString(str: String): String = str match {
+    case singleQuotedString(s) => s
+    case doubleQuotedString(s) => s
+    case other => other
+  }
+
+  /** Strips backticks from ident if present */
+  protected def cleanIdentifier(ident: String): String = ident match {
+    case escapedIdentifier(i) => i
+    case plainIdent => plainIdent
+  }
+
+  val numericAstTypes = Seq(
+    SparkSqlParser.Number,
+    SparkSqlParser.TinyintLiteral,
+    SparkSqlParser.SmallintLiteral,
+    SparkSqlParser.BigintLiteral,
+    SparkSqlParser.DecimalLiteral)
+
+  /* Case insensitive matches */
+  val COUNT = "(?i)COUNT".r
+  val SUM = "(?i)SUM".r
+  val AND = "(?i)AND".r
+  val OR = "(?i)OR".r
+  val NOT = "(?i)NOT".r
+  val TRUE = "(?i)TRUE".r
+  val FALSE = "(?i)FALSE".r
+  val LIKE = "(?i)LIKE".r
+  val RLIKE = "(?i)RLIKE".r
+  val REGEXP = "(?i)REGEXP".r
+  val IN = "(?i)IN".r
+  val DIV = "(?i)DIV".r
+  val BETWEEN = "(?i)BETWEEN".r
+  val WHEN = "(?i)WHEN".r
+  val CASE = "(?i)CASE".r
+
+  /**
+   * Converts an expression-level AST node into an [[Expression]].
+   *
+   * The cases are tried in order, so specific rules (casts, built-in operators, window
+   * functions) must stay ahead of the generic function rules. Nodes that match no rule are
+   * rejected via `noParseRule`.
+   */
+  protected def nodeToExpr(node: ASTNode): Expression = node match {
+    /* Attribute References */
+    case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
+      UnresolvedAttribute.quoted(cleanIdentifier(name))
+    case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
+      nodeToExpr(qualifier) match {
+        case UnresolvedAttribute(nameParts) =>
+          UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
+        case other => UnresolvedExtractValue(other, Literal(attr))
+      }
+
+    /* Stars (*) */
+    case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
+    // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
+    // have a single child which is tableName.
+    case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
+      UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
+
+    /* Aggregate Functions */
+    case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
+      Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
+    case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
+      Count(Literal(1)).toAggregateExpression()
+
+    /* Casts */
+    case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), IntegerType)
+    case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), LongType)
+    case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), FloatType)
+    case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DoubleType)
+    case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), ShortType)
+    case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), ByteType)
+    case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), BinaryType)
+    case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), BooleanType)
+    // `Nil` (not a lowercase variable pattern) so that only exactly two type arguments match.
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
+    case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), TimestampType)
+    case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DateType)
+
+    /* Arithmetic */
+    case Token("+", child :: Nil) => nodeToExpr(child)
+    case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
+    case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
+    case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
+    case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
+    case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
+    case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+    case Token(DIV(), left :: right:: Nil) =>
+      Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
+    case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+    case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
+    case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
+    case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
+
+    /* Comparisons */
+    case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+    case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+    case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right))
+    case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
+    case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
+    case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
+    case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+    case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
+    case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
+    case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right))
+    case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
+    case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
+    case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
+      IsNotNull(nodeToExpr(child))
+    case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
+      IsNull(nodeToExpr(child))
+    case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) =>
+      In(nodeToExpr(value), list.map(nodeToExpr))
+    case Token("TOK_FUNCTION",
+           Token(BETWEEN(), Nil) ::
+           kw ::
+           target ::
+           minValue ::
+           maxValue :: Nil) =>
+
+      val targetExpression = nodeToExpr(target)
+      val betweenExpr =
+        And(
+          GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)),
+          LessThanOrEqual(targetExpression, nodeToExpr(maxValue)))
+      // KW_TRUE yields the negated range check, KW_FALSE the plain one.
+      kw match {
+        case Token("KW_FALSE", Nil) => betweenExpr
+        case Token("KW_TRUE", Nil) => Not(betweenExpr)
+      }
+
+    /* Boolean Logic */
+    case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right))
+    case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right))
+    case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
+    case Token("!", child :: Nil) => Not(nodeToExpr(child))
+
+    /* Case statements */
+    case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
+      CaseWhen(branches.map(nodeToExpr))
+    case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
+      val keyExpr = nodeToExpr(branches.head)
+      CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
+
+    /* Complex datatype manipulation */
+    case Token("[", child :: ordinal :: Nil) =>
+      UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
+
+    /* Window Functions */
+    case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) =>
+      // Re-parse the node without its trailing window specification to get the function.
+      val function = nodeToExpr(node.copy(children = node.children.init))
+      nodesToWindowSpecification(spec) match {
+        case reference: WindowSpecReference =>
+          UnresolvedWindowExpression(function, reference)
+        case definition: WindowSpecDefinition =>
+          WindowExpression(function, definition)
+      }
+
+    /* UDFs - Must be last otherwise will preempt built in functions */
+    case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
+    // Aggregate function with DISTINCT keyword.
+    case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
+    case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
+      UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
+
+    /* Literals */
+    case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
+    case Token(TRUE(), Nil) => Literal.create(true, BooleanType)
+    case Token(FALSE(), Nil) => Literal.create(false, BooleanType)
+    case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
+      Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
+
+    // This code is adapted from
+    // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
+    case ast: ASTNode if numericAstTypes contains ast.tokenType =>
+      var v: Literal = null
+      try {
+        if (ast.text.endsWith("L")) {
+          // Literal bigint.
+          v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
+        } else if (ast.text.endsWith("S")) {
+          // Literal smallint.
+          v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
+        } else if (ast.text.endsWith("Y")) {
+          // Literal tinyint.
+          v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
+        } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) {
+          // Literal decimal
+          val strVal = ast.text.stripSuffix("D").stripSuffix("B")
+          v = Literal(Decimal(strVal))
+        } else {
+          // Deliberate fall-through: each assignment overwrites the previous one, and the
+          // first NumberFormatException stops the sequence, leaving `v` at the last type
+          // that parsed (Int if possible, otherwise Long, otherwise Double).
+          v = Literal.create(ast.text.toDouble, DoubleType)
+          v = Literal.create(ast.text.toLong, LongType)
+          v = Literal.create(ast.text.toInt, IntegerType)
+        }
+      } catch {
+        case nfe: NumberFormatException => // Do nothing
+      }
+
+      if (v == null) {
+        sys.error(s"Failed to parse number '${ast.text}'.")
+      } else {
+        v
+      }
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral =>
+      Literal(ParseUtils.unescapeSQLString(ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
+      Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL =>
+      Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
+      Literal(CalendarInterval.fromYearMonthString(ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
+      Literal(CalendarInterval.fromDayTimeString(ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("year", ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("month", ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("day", ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("hour", ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("minute", ast.text))
+
+    case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL =>
+      Literal(CalendarInterval.fromSingleUnitString("second", ast.text))
+
+    case _ =>
+      noParseRule("Expression", node)
+  }
+
+  /* Case insensitive matches for Window Specification */
+  val PRECEDING = "(?i)preceding".r
+  val FOLLOWING = "(?i)following".r
+  val CURRENT = "(?i)current".r
+
+  /**
+   * Converts the children of a `TOK_WINDOWSPEC` node into a [[WindowSpec]]: a reference to a
+   * named window, an empty definition for `OVER()`, or a full definition with partitioning,
+   * ordering and frame.
+   */
+  protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
+    case Token(windowName, Nil) :: Nil =>
+      // Refer to a window spec defined in the window clause.
+      WindowSpecReference(windowName)
+    case Nil =>
+      // OVER()
+      WindowSpecDefinition(
+        partitionSpec = Nil,
+        orderSpec = Nil,
+        frameSpecification = UnspecifiedFrame)
+    case spec =>
+      val (partitionClause :: rowFrame :: rangeFrame :: Nil) =
+        getClauses(
+          Seq(
+            "TOK_PARTITIONINGSPEC",
+            "TOK_WINDOWRANGE",
+            "TOK_WINDOWVALUES"),
+          spec)
+
+      // Handle Partition By and Order By.
+      val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering =>
+        val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) =
+          getClauses(
+            Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"),
+            partitionAndOrdering.children)
+
+        (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match {
+          case (Some(partitionByExpr), Some(orderByExpr), None) =>
+            (partitionByExpr.children.map(nodeToExpr),
+              orderByExpr.children.map(nodeToSortOrder))
+          case (Some(partitionByExpr), None, None) =>
+            (partitionByExpr.children.map(nodeToExpr), Nil)
+          case (None, Some(orderByExpr), None) =>
+            (Nil, orderByExpr.children.map(nodeToSortOrder))
+          case (None, None, Some(clusterByExpr)) =>
+            val expressions = clusterByExpr.children.map(nodeToExpr)
+            (expressions, expressions.map(SortOrder(_, Ascending)))
+          case _ =>
+            noParseRule("Partition & Ordering", partitionAndOrdering)
+        }
+      }.getOrElse {
+        (Nil, Nil)
+      }
+
+      // Handle Window Frame
+      val windowFrame =
+        if (rowFrame.isEmpty && rangeFrame.isEmpty) {
+          UnspecifiedFrame
+        } else {
+          val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
+          def nodeToBoundary(node: ASTNode): FrameBoundary = node match {
+            case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
+              if (count.toLowerCase() == "unbounded") {
+                UnboundedPreceding
+              } else {
+                ValuePreceding(count.toInt)
+              }
+            case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
+              if (count.toLowerCase() == "unbounded") {
+                UnboundedFollowing
+              } else {
+                ValueFollowing(count.toInt)
+              }
+            case Token(CURRENT(), Nil) => CurrentRow
+            case _ =>
+              noParseRule("Window Frame Boundary", node)
+          }
+
+          rowFrame.orElse(rangeFrame).map { frame =>
+            frame.children match {
+              case precedingNode :: followingNode :: Nil =>
+                SpecifiedWindowFrame(
+                  frameType,
+                  nodeToBoundary(precedingNode),
+                  nodeToBoundary(followingNode))
+              case precedingNode :: Nil =>
+                // A single boundary implies the frame ends at the current row.
+                SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow)
+              case _ =>
+                noParseRule("Window Frame", frame)
+            }
+          }.getOrElse(sys.error(s"If you see this, please file a bug report with your query."))
+        }
+
+      WindowSpecDefinition(partitionSpec, orderSpec, windowFrame)
+  }
+
+  /** Hook for script transformations; this implementation never produces one. */
+  protected def nodeToTransformation(
+      node: ASTNode,
+      child: LogicalPlan): Option[ScriptTransformation] = None
+
+  val explode = "(?i)explode".r
+  val jsonTuple = "(?i)json_tuple".r
+
+  /**
+   * Converts the `TOK_SELECT` node of a lateral view into a [[Generate]] over `child`.
+   * `explode` and `json_tuple` are handled here; other generator functions are delegated to
+   * `nodeToGenerator`.
+   */
+  protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = {
+    val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node
+
+    val alias = getClause("TOK_TABALIAS", clauses).children.head.text
+
+    val generator = clauses.head match {
+      case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) =>
+        Explode(nodeToExpr(childNode))
+      case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) =>
+        JsonTuple(children.map(nodeToExpr))
+      case other =>
+        nodeToGenerator(other)
+    }
+
+    val attributes = clauses.collect {
+      case Token(a, Nil) => UnresolvedAttribute(a.toLowerCase)
+    }
+
+    Generate(generator, join = true, outer = outer, Some(alias.toLowerCase), attributes, child)
+  }
+
+  /** Hook for additional generator functions; this implementation rejects every node. */
+  protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node)
+
+  /** Fails with a NotImplementedError naming the rule category and dumping the offending tree. */
+  protected def noParseRule(msg: String, node: ASTNode): Nothing = throw new NotImplementedError(
+    s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ec5e71042d4beb1603399df0609991e7a124eb98
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.antlr.runtime.{Token, TokenRewriteStream}
+
+import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
+
+/**
+ * Immutable AST node: an ANTLR token, the start/stop token indexes it covers, its child
+ * nodes, and the token stream used to recover the source text.
+ */
+case class ASTNode(
+    token: Token,
+    startIndex: Int,
+    stopIndex: Int,
+    children: List[ASTNode],
+    stream: TokenRewriteStream) extends TreeNode[ASTNode] {
+  /** Cache the number of children. */
+  val numChildren = children.size
+
+  /** tuple used in pattern matching. */
+  val pattern = Some((token.getText, children))
+
+  /** Line in which the ASTNode starts. */
+  lazy val line: Int = {
+    val tokenLine = token.getLine
+    if (tokenLine == 0) {
+      // The token reports no line: fall back to the first child, or 0 for a leaf.
+      if (children.nonEmpty) children.head.line
+      else 0
+    } else {
+      tokenLine
+    }
+  }
+
+  /** Position of the Character at which ASTNode starts. */
+  lazy val positionInLine: Int = {
+    val tokenPosition = token.getCharPositionInLine
+    if (tokenPosition == -1) {
+      // The token reports no position: fall back to the first child, or 0 for a leaf.
+      if (children.nonEmpty) children.head.positionInLine
+      else 0
+    } else {
+      tokenPosition
+    }
+  }
+
+  /** Origin of the ASTNode. */
+  override val origin = Origin(Some(line), Some(positionInLine))
+
+  /** Source text. */
+  lazy val source = stream.toString(startIndex, stopIndex)
+
+  def text: String = token.getText
+
+  def tokenType: Int = token.getType
+
+  /**
+   * Checks if this node is equal to another node.
+   *
+   * Right now this function only checks the name, type, text and children of the node
+   * for equality.
+   */
+  def treeEquals(other: ASTNode): Boolean = {
+    // `==` is null-safe on both sides. The previous `l.equals(r)` threw a
+    // NullPointerException when this node's value was null and the other's was not.
+    def sameValue(f: ASTNode => Any): Boolean = f(this) == f(other)
+
+    other != null &&
+      sameValue(_.token.getType) &&
+      sameValue(_.token.getText) &&
+      sameValue(_.numChildren) &&
+      children.zip(other.children).forall {
+        case (l, r) => l treeEquals r
+      }
+  }
+
+  override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine "
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0e93af8b92cd2f2e6184acf616b19fbcf85dd00f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.antlr.runtime._
+import org.antlr.runtime.tree.CommonTree
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
+
+/**
+ * Turns a SQL command string into an immutable [[ASTNode]] tree.
+ *
+ * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
+ */
+object ParseDriver extends Logging {
+  def parse(command: String, conf: ParserConf): ASTNode = {
+    logInfo(s"Parsing command: $command")
+
+    // Lexer and parser share one reporter so that their errors are collected together.
+    val reporter = new ParseErrorReporter()
+    val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command))
+    val tokens = new TokenRewriteStream(lexer)
+    lexer.configure(conf, reporter)
+    val parser = new SparkSqlParser(tokens)
+    parser.configure(conf, reporter)
+
+    // Walks down the first-child chain until a node with a token (or a leaf) is found.
+    def firstTokenTree(candidate: CommonTree): CommonTree = {
+      if (candidate.token == null && candidate.getChildCount != 0) {
+        firstTokenTree(candidate.getChild(0).asInstanceOf[CommonTree])
+      } else {
+        candidate
+      }
+    }
+
+    // Recursively copies the ANTLR tree into the immutable AST.
+    def toASTNode(antlrTree: CommonTree): ASTNode = {
+      val childNodes = List.tabulate(antlrTree.getChildCount) { idx =>
+        toASTNode(antlrTree.getChild(idx).asInstanceOf[CommonTree])
+      }
+      ASTNode(
+        antlrTree.token,
+        antlrTree.getTokenStartIndex,
+        antlrTree.getTokenStopIndex,
+        childNodes,
+        tokens)
+    }
+
+    try {
+      val parsed = parser.statement()
+
+      // Surface any errors the lexer or parser recorded.
+      reporter.checkForErrors()
+      logInfo(s"Parse completed.")
+
+      val root = firstTokenTree(parsed.getTree)
+
+      // Make sure all boundaries are set before the indexes are read.
+      root.setUnknownTokenBoundaries()
+      toASTNode(root)
+    } catch {
+      case e: RecognitionException =>
+        logInfo(s"Parse failed.")
+        reporter.throwError(e)
+    }
+  }
+}
+
+/**
+ * This string stream provides the lexer with upper case characters only. This greatly simplifies
+ * lexing the stream, while we can maintain the original command.
+ * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides an implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to a combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == CharStream.EOF) la // 0 and EOF are returned unchanged + else Character.toUpperCase(la) + } +} + +/** + * Utility used by the Parser and the Lexer for error collection and reporting.
+ */ +private[parser] class ParseErrorReporter { + val errors = scala.collection.mutable.Buffer.empty[ParseError] + + def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { + errors += ParseError(br, re, tokenNames) + } + + def checkForErrors(): Unit = { + if (errors.nonEmpty) { + val first = errors.head // supplies the position and the leading message + val e = first.re + throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) + } + } + + def throwError(e: RecognitionException): Nothing = { + throwError(e.line, e.charPositionInLine, e.toString, errors) + } + + private def throwError( + line: Int, + startPosition: Int, + msg: String, + errors: Seq[ParseError]): Nothing = { + val b = new StringBuilder + b.append(msg).append("\n") + errors.foreach(error => error.buildMessage(b).append("\n")) // one line per error + throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + } +} + +/** + * Error collected during the parsing process. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + */ +private[parser] case class ParseError( + br: BaseRecognizer, + re: RecognitionException, + tokenNames: Array[String]) { + def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { // returns s + s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala new file mode 100644 index 0000000000000000000000000000000000000000..ce449b11431a56e76eba70b2c6ecb0fb3b1c8268 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +trait ParserConf { + def supportQuotedId: Boolean // whether quoted identifiers are supported + def supportSQL11ReservedKeywords: Boolean // whether SQL2011 reserved keywords are supported +} + +case class SimpleParserConf( + supportQuotedId: Boolean = true, + supportSQL11ReservedKeywords: Boolean = false) extends ParserConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b58a3739912bcbbadf6a115e34fc936e537dd493..26c00dc250b4b33cec09703aacb451dc9995ada7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,6 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.parser.ParserConf //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -451,6 +452,19 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", + defaultValue = Some(true), + isPublic = false, + doc = "Whether to use quoted identifier.\n false: default(past) behavior.
Implies only " + + "alphaNumeric and underscore are valid characters in identifiers.\n" + + " true: implies column names can contain any character.") + + val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf( + "spark.sql.parser.supportSQL11ReservedKeywords", + defaultValue = Some(false), + isPublic = false, + doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -471,7 +485,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -569,6 +583,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) + + def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala new file mode 100644 index 0000000000000000000000000000000000000000..a322688a259e261f355577b8c4fe20dfa79a4c3c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} + +private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { + /** Returns true for commands whose EXPLAIN output is faked rather than planned. */ + protected def isNoExplainCommand(command: String): Boolean = "TOK_DESCTABLE" == command + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Fake the explain (ExplainCommand over OneRowRelation) when isNoExplainCommand matches. + case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => + ExplainCommand(OneRowRelation) + + case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text => + val Some(crtTbl) :: _ :: extended :: Nil = + getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined) + + case Token("TOK_EXPLAIN", explainArgs) => + // Ignore FORMATTED if present.
+ val Some(query) :: _ :: extended :: Nil = + getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported here, so the statement is handed to + // nodeToDescribeFallback. + nodeToDescribeFallback(node) + } else { + tableType match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + nameParts match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this + // issue. + val tableIdent = extractTableIdent(nameParts) + datasources.DescribeCommand( + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + nodeToDescribeFallback(node) + case tableName => + // It is describing a table with the format like "describe table". + datasources.DescribeCommand( + UnresolvedRelation(TableIdentifier(tableName.text), None), + isExtended = extended.isDefined) + } + // All other cases.
+ case _ => nodeToDescribeFallback(node) + } + } + + case _ => + super.nodeToPlan(node) + } + } + + protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node) +} diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ffabb92179a18c6ea488050133f1eb7ac8d61908..cd0c2aeb93a9fb78ff919ef37acc75616b84a126 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -262,26 +262,6 @@ </executions> </plugin> - - <plugin> - <groupId>org.antlr</groupId> - <artifactId>antlr3-maven-plugin</artifactId> - <executions> - <execution> - <goals> - <goal>antlr</goal> - </goals> - </execution> - </executions> - <configuration> - <sourceDirectory>${basedir}/src/main/antlr3</sourceDirectory> - <includes> - <include>**/SparkSqlLexer.g</include> - <include>**/SparkSqlParser.g</include> - </includes> - </configuration> - </plugin> - </plugins> </build> </project> diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java deleted file mode 100644 index 35ecdc5ad10a917858dae62ba7c04bd9f1775f90..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTErrorNode.java +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonErrorNode; - -public class ASTErrorNode extends ASTNode { - - /** - * - */ - private static final long serialVersionUID = 1L; - CommonErrorNode delegate; - - public ASTErrorNode(TokenStream input, Token start, Token stop, - RecognitionException e){ - delegate = new CommonErrorNode(input,start,stop,e); - } - - @Override - public boolean isNil() { return delegate.isNil(); } - - @Override - public int getType() { return delegate.getType(); } - - @Override - public String getText() { return delegate.getText(); } - @Override - public String toString() { return delegate.toString(); } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java deleted file mode 100644 index 33d9322b628ecdf41c97a820dc7984d9b8529dbe..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ASTNode.java +++ /dev/null @@ -1,245 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -import org.antlr.runtime.Token; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.Tree; -import org.apache.hadoop.hive.ql.lib.Node; - -public class ASTNode extends CommonTree implements Node, Serializable { - private static final long serialVersionUID = 1L; - private transient StringBuffer astStr; - private transient int startIndx = -1; - private transient int endIndx = -1; - private transient ASTNode rootNode; - private transient boolean isValidASTStr; - - public ASTNode() { - } - - /** - * Constructor. - * - * @param t - * Token for the CommonTree Node - */ - public ASTNode(Token t) { - super(t); - } - - public ASTNode(ASTNode node) { - super(node); - } - - @Override - public Tree dupNode() { - return new ASTNode(this); - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getChildren() - */ - @Override - public ArrayList<Node> getChildren() { - if (super.getChildCount() == 0) { - return null; - } - - ArrayList<Node> ret_vec = new ArrayList<Node>(); - for (int i = 0; i < super.getChildCount(); ++i) { - ret_vec.add((Node) super.getChild(i)); - } - - return ret_vec; - } - - /* - * (non-Javadoc) - * - * @see org.apache.hadoop.hive.ql.lib.Node#getName() - */ - @Override - public String getName() { - return (Integer.valueOf(super.getToken().getType())).toString(); - } - - public String dump() { - StringBuilder sb = new StringBuilder("\n"); - dump(sb, ""); - return sb.toString(); - } - - private StringBuilder dump(StringBuilder sb, String ws) { - sb.append(ws); - sb.append(toString()); - sb.append("\n"); - - ArrayList<Node> children = getChildren(); - if (children != null) { - for (Node node : getChildren()) { - if (node instanceof ASTNode) { - ((ASTNode) node).dump(sb, ws + " "); - } 
else { - sb.append(ws); - sb.append(" NON-ASTNODE!!"); - sb.append("\n"); - } - } - } - return sb; - } - - private ASTNode getRootNodeWithValidASTStr(boolean useMemoizedRoot) { - if (useMemoizedRoot && rootNode != null && rootNode.parent == null && - rootNode.hasValidMemoizedString()) { - return rootNode; - } - ASTNode retNode = this; - while (retNode.parent != null) { - retNode = (ASTNode) retNode.parent; - } - rootNode=retNode; - if (!rootNode.isValidASTStr) { - rootNode.astStr = new StringBuffer(); - rootNode.toStringTree(rootNode); - rootNode.isValidASTStr = true; - } - return retNode; - } - - private boolean hasValidMemoizedString() { - return isValidASTStr && astStr != null; - } - - private void resetRootInformation() { - // Reset the previously stored rootNode string - if (rootNode != null) { - rootNode.astStr = null; - rootNode.isValidASTStr = false; - } - } - - private int getMemoizedStringLen() { - return astStr == null ? 0 : astStr.length(); - } - - private String getMemoizedSubString(int start, int end) { - return (astStr == null || start < 0 || end > astStr.length() || start >= end) ? 
null : - astStr.subSequence(start, end).toString(); - } - - private void addtoMemoizedString(String string) { - if (astStr == null) { - astStr = new StringBuffer(); - } - astStr.append(string); - } - - @Override - public void setParent(Tree t) { - super.setParent(t); - resetRootInformation(); - } - - @Override - public void addChild(Tree t) { - super.addChild(t); - resetRootInformation(); - } - - @Override - public void addChildren(List kids) { - super.addChildren(kids); - resetRootInformation(); - } - - @Override - public void setChild(int i, Tree t) { - super.setChild(i, t); - resetRootInformation(); - } - - @Override - public void insertChild(int i, Object t) { - super.insertChild(i, t); - resetRootInformation(); - } - - @Override - public Object deleteChild(int i) { - Object ret = super.deleteChild(i); - resetRootInformation(); - return ret; - } - - @Override - public void replaceChildren(int startChildIndex, int stopChildIndex, Object t) { - super.replaceChildren(startChildIndex, stopChildIndex, t); - resetRootInformation(); - } - - @Override - public String toStringTree() { - - // The root might have changed because of tree modifications. - // Compute the new root for this tree and set the astStr. - getRootNodeWithValidASTStr(true); - - // If rootNotModified is false, then startIndx and endIndx will be stale. 
- if (startIndx >= 0 && endIndx <= rootNode.getMemoizedStringLen()) { - return rootNode.getMemoizedSubString(startIndx, endIndx); - } - return toStringTree(rootNode); - } - - private String toStringTree(ASTNode rootNode) { - this.rootNode = rootNode; - startIndx = rootNode.getMemoizedStringLen(); - // Leaf node - if ( children==null || children.size()==0 ) { - rootNode.addtoMemoizedString(this.toString()); - endIndx = rootNode.getMemoizedStringLen(); - return this.toString(); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString("("); - rootNode.addtoMemoizedString(this.toString()); - rootNode.addtoMemoizedString(" "); - } - for (int i = 0; children!=null && i < children.size(); i++) { - ASTNode t = (ASTNode)children.get(i); - if ( i>0 ) { - rootNode.addtoMemoizedString(" "); - } - t.toStringTree(rootNode); - } - if ( !isNil() ) { - rootNode.addtoMemoizedString(")"); - } - endIndx = rootNode.getMemoizedStringLen(); - return rootNode.getMemoizedSubString(startIndx, endIndx); - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java deleted file mode 100644 index c77198b087cbdbd89312bd3fff61b1a6224a1e64..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseDriver.java +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; -import org.antlr.runtime.ANTLRStringStream; -import org.antlr.runtime.CharStream; -import org.antlr.runtime.NoViableAltException; -import org.antlr.runtime.RecognitionException; -import org.antlr.runtime.Token; -import org.antlr.runtime.TokenRewriteStream; -import org.antlr.runtime.TokenStream; -import org.antlr.runtime.tree.CommonTree; -import org.antlr.runtime.tree.CommonTreeAdaptor; -import org.antlr.runtime.tree.TreeAdaptor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.apache.hadoop.hive.ql.Context; - -/** - * ParseDriver. - * - */ -public class ParseDriver { - - private static final Logger LOG = LoggerFactory.getLogger("hive.ql.parse.ParseDriver"); - - /** - * ANTLRNoCaseStringStream. - * - */ - //This class provides and implementation for a case insensitive token checker - //for the lexical analysis part of antlr. By converting the token stream into - //upper case at the time when lexical rules are checked, this class ensures that the - //lexical rules need to just match the token with upper case letters as opposed to - //combination of upper case and lower case characters. This is purely used for matching lexical - //rules. The actual token text is stored in the same way as the user input without - //actually converting it into an upper case. The token values are generated by the consume() - //function of the super class ANTLRStringStream. The LA() function is the lookahead function - //and is purely used for matching lexical rules. 
This also means that the grammar will only - //accept capitalized tokens in case it is run from other tools like antlrworks which - //do not have the ANTLRNoCaseStringStream implementation. - public class ANTLRNoCaseStringStream extends ANTLRStringStream { - - public ANTLRNoCaseStringStream(String input) { - super(input); - } - - @Override - public int LA(int i) { - - int returnChar = super.LA(i); - if (returnChar == CharStream.EOF) { - return returnChar; - } else if (returnChar == 0) { - return returnChar; - } - - return Character.toUpperCase((char) returnChar); - } - } - - /** - * HiveLexerX. - * - */ - public class HiveLexerX extends SparkSqlLexer { - - private final ArrayList<ParseError> errors; - - public HiveLexerX(CharStream input) { - super(input); - errors = new ArrayList<ParseError>(); - } - - @Override - public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - errors.add(new ParseError(this, e, tokenNames)); - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - if (e instanceof NoViableAltException) { - // @SuppressWarnings("unused") - // NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "character " + getCharErrorDisplay(e.c) + " not supported here"; - } else { - msg = super.getErrorMessage(e, tokenNames); - } - - return msg; - } - - public ArrayList<ParseError> getErrors() { - return errors; - } - - } - - /** - * Tree adaptor for making antlr return ASTNodes instead of CommonTree nodes - * so that the graph walking algorithms and the rules framework defined in - * ql.lib can be used with the AST Nodes. - */ - public static final TreeAdaptor adaptor = new CommonTreeAdaptor() { - /** - * Creates an ASTNode for the given token. 
The ASTNode is a wrapper around - * antlr's CommonTree class that implements the Node interface. - * - * @param payload - * The token. - * @return Object (which is actually an ASTNode) for the token. - */ - @Override - public Object create(Token payload) { - return new ASTNode(payload); - } - - @Override - public Object dupNode(Object t) { - - return create(((CommonTree)t).token); - }; - - @Override - public Object errorNode(TokenStream input, Token start, Token stop, RecognitionException e) { - return new ASTErrorNode(input, start, stop, e); - }; - }; - - public ASTNode parse(String command) throws ParseException { - return parse(command, null); - } - - public ASTNode parse(String command, Context ctx) - throws ParseException { - return parse(command, ctx, true); - } - - /** - * Parses a command, optionally assigning the parser's token stream to the - * given context. - * - * @param command - * command to parse - * - * @param ctx - * context with which to associate this parser's token stream, or - * null if either no context is available or the context already has - * an existing stream - * - * @return parsed AST - */ - public ASTNode parse(String command, Context ctx, boolean setTokenRewriteStream) - throws ParseException { - LOG.info("Parsing command: " + command); - - HiveLexerX lexer = new HiveLexerX(new ANTLRNoCaseStringStream(command)); - TokenRewriteStream tokens = new TokenRewriteStream(lexer); - if (ctx != null) { - if ( setTokenRewriteStream) { - ctx.setTokenRewriteStream(tokens); - } - lexer.setHiveConf(ctx.getConf()); - } - SparkSqlParser parser = new SparkSqlParser(tokens); - if (ctx != null) { - parser.setHiveConf(ctx.getConf()); - } - parser.setTreeAdaptor(adaptor); - SparkSqlParser.statement_return r = null; - try { - r = parser.statement(); - } catch (RecognitionException e) { - e.printStackTrace(); - throw new ParseException(parser.errors); - } - - if (lexer.getErrors().size() == 0 && parser.errors.size() == 0) { - LOG.info("Parse Completed"); - 
} else if (lexer.getErrors().size() != 0) { - throw new ParseException(lexer.getErrors()); - } else { - throw new ParseException(parser.errors); - } - - ASTNode tree = (ASTNode) r.getTree(); - tree.setUnknownTokenBoundaries(); - return tree; - } -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java deleted file mode 100644 index b47bcfb2914dfb89b349443276f074bd36697641..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseError.java +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parser; - -import org.antlr.runtime.BaseRecognizer; -import org.antlr.runtime.RecognitionException; - -/** - * - */ -public class ParseError { - private final BaseRecognizer br; - private final RecognitionException re; - private final String[] tokenNames; - - ParseError(BaseRecognizer br, RecognitionException re, String[] tokenNames) { - this.br = br; - this.re = re; - this.tokenNames = tokenNames; - } - - BaseRecognizer getBaseRecognizer() { - return br; - } - - RecognitionException getRecognitionException() { - return re; - } - - String[] getTokenNames() { - return tokenNames; - } - - String getMessage() { - return br.getErrorHeader(re) + " " + br.getErrorMessage(re, tokenNames); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java deleted file mode 100644 index fff891ced5550fa8ad894ae4cba358a359c906ce..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseException.java +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parser; - -import java.util.ArrayList; - -/** - * ParseException. - * - */ -public class ParseException extends Exception { - - private static final long serialVersionUID = 1L; - ArrayList<ParseError> errors; - - public ParseException(ArrayList<ParseError> errors) { - super(); - this.errors = errors; - } - - @Override - public String getMessage() { - - StringBuilder sb = new StringBuilder(); - for (ParseError err : errors) { - if (sb.length() > 0) { - sb.append('\n'); - } - sb.append(err.getMessage()); - } - - return sb.toString(); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java deleted file mode 100644 index a5c2998f86cc1751330ce98e7a3d019370722955..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/ParseUtils.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parser; - -import org.apache.hadoop.hive.common.type.HiveDecimal; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - - -/** - * Library of utility functions used in the parse code. - * - */ -public final class ParseUtils { - /** - * Performs a descent of the leftmost branch of a tree, stopping when either a - * node with a non-null token is found or the leaf level is encountered. - * - * @param tree - * candidate node from which to start searching - * - * @return node at which descent stopped - */ - public static ASTNode findRootNonNullToken(ASTNode tree) { - while ((tree.getToken() == null) && (tree.getChildCount() > 0)) { - tree = (org.apache.spark.sql.parser.ASTNode) tree.getChild(0); - } - return tree; - } - - private ParseUtils() { - // prevent instantiation - } - - public static VarcharTypeInfo getVarcharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type varchar"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getVarcharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static CharTypeInfo getCharTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() != 1) { - throw new SemanticException("Bad params for type char"); - } - - String lengthStr = node.getChild(0).getText(); - return TypeInfoFactory.getCharTypeInfo(Integer.valueOf(lengthStr)); - } - - public static DecimalTypeInfo getDecimalTypeTypeInfo(ASTNode node) - throws SemanticException { - if (node.getChildCount() > 2) { - throw new SemanticException("Bad params for type decimal"); - } - - int precision = HiveDecimal.USER_DEFAULT_PRECISION; - int scale = 
HiveDecimal.USER_DEFAULT_SCALE; - - if (node.getChildCount() >= 1) { - String precStr = node.getChild(0).getText(); - precision = Integer.valueOf(precStr); - } - - if (node.getChildCount() == 2) { - String scaleStr = node.getChild(1).getText(); - scale = Integer.valueOf(scaleStr); - } - - return TypeInfoFactory.getDecimalTypeInfo(precision, scale); - } - -} diff --git a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java b/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java deleted file mode 100644 index 4b2015e0df84eaa8a631d2d61c4b535a96106e60..0000000000000000000000000000000000000000 --- a/sql/hive/src/main/java/org/apache/spark/sql/parser/SemanticAnalyzer.java +++ /dev/null @@ -1,406 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parser; - -import java.io.UnsupportedEncodingException; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.antlr.runtime.tree.Tree; -import org.apache.commons.lang.StringUtils; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.hive.conf.HiveConf; -import org.apache.hadoop.hive.metastore.api.FieldSchema; -import org.apache.hadoop.hive.ql.ErrorMsg; -import org.apache.hadoop.hive.ql.parse.SemanticException; -import org.apache.hadoop.hive.serde.serdeConstants; -import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo; -import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo; - -/** - * SemanticAnalyzer. - * - */ -public abstract class SemanticAnalyzer { - public static String charSetString(String charSetName, String charSetString) - throws SemanticException { - try { - // The character set name starts with a _, so strip that - charSetName = charSetName.substring(1); - if (charSetString.charAt(0) == '\'') { - return new String(unescapeSQLString(charSetString).getBytes(), - charSetName); - } else // hex input is also supported - { - assert charSetString.charAt(0) == '0'; - assert charSetString.charAt(1) == 'x'; - charSetString = charSetString.substring(2); - - byte[] bArray = new byte[charSetString.length() / 2]; - int j = 0; - for (int i = 0; i < charSetString.length(); i += 2) { - int val = Character.digit(charSetString.charAt(i), 16) * 16 - + Character.digit(charSetString.charAt(i + 1), 16); - if (val > 127) { - val = val - 256; - } - bArray[j++] = (byte)val; - } - - String res = new String(bArray, charSetName); - return res; - } - } catch (UnsupportedEncodingException e) { - throw new SemanticException(e); - } - } - - /** - * Remove the encapsulating "`" pair from the identifier. 
We allow users to - * use "`" to escape identifier for table names, column names and aliases, in - * case that coincide with Hive language keywords. - */ - public static String unescapeIdentifier(String val) { - if (val == null) { - return null; - } - if (val.charAt(0) == '`' && val.charAt(val.length() - 1) == '`') { - val = val.substring(1, val.length() - 1); - } - return val; - } - - /** - * Converts parsed key/value properties pairs into a map. - * - * @param prop ASTNode parent of the key/value pairs - * - * @param mapProp property map which receives the mappings - */ - public static void readProps( - ASTNode prop, Map<String, String> mapProp) { - - for (int propChild = 0; propChild < prop.getChildCount(); propChild++) { - String key = unescapeSQLString(prop.getChild(propChild).getChild(0) - .getText()); - String value = null; - if (prop.getChild(propChild).getChild(1) != null) { - value = unescapeSQLString(prop.getChild(propChild).getChild(1).getText()); - } - mapProp.put(key, value); - } - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. 
For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? 
- case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } - - /** - * Get the list of FieldSchema out of the ASTNode. - */ - public static List<FieldSchema> getColumns(ASTNode ast, boolean lowerCase) throws SemanticException { - List<FieldSchema> colList = new ArrayList<FieldSchema>(); - int numCh = ast.getChildCount(); - for (int i = 0; i < numCh; i++) { - FieldSchema col = new FieldSchema(); - ASTNode child = (ASTNode) ast.getChild(i); - Tree grandChild = child.getChild(0); - if(grandChild != null) { - String name = grandChild.getText(); - if(lowerCase) { - name = name.toLowerCase(); - } - // child 0 is the name of the column - col.setName(unescapeIdentifier(name)); - // child 1 is the type of the column - ASTNode typeChild = (ASTNode) (child.getChild(1)); - col.setType(getTypeStringFromAST(typeChild)); - - // child 2 is the optional comment of the column - if (child.getChildCount() == 3) { - col.setComment(unescapeSQLString(child.getChild(2).getText())); - } - } - colList.add(col); - } - return colList; - } - - protected static String getTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - switch (typeNode.getType()) { - case SparkSqlParser.TOK_LIST: - return serdeConstants.LIST_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + ">"; - case SparkSqlParser.TOK_MAP: - return serdeConstants.MAP_TYPE_NAME + "<" - + getTypeStringFromAST((ASTNode) typeNode.getChild(0)) + "," - + getTypeStringFromAST((ASTNode) typeNode.getChild(1)) + ">"; - case SparkSqlParser.TOK_STRUCT: - return getStructTypeStringFromAST(typeNode); - case SparkSqlParser.TOK_UNIONTYPE: - return getUnionTypeStringFromAST(typeNode); - default: - return getTypeName(typeNode); - } - } - - private static String getStructTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.STRUCT_TYPE_NAME + 
"<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty struct not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - ASTNode child = (ASTNode) typeNode.getChild(i); - buffer.append(unescapeIdentifier(child.getChild(0).getText())).append(":"); - buffer.append(getTypeStringFromAST((ASTNode) child.getChild(1))); - if (i < children - 1) { - buffer.append(","); - } - } - - buffer.append(">"); - return buffer.toString(); - } - - private static String getUnionTypeStringFromAST(ASTNode typeNode) - throws SemanticException { - String typeStr = serdeConstants.UNION_TYPE_NAME + "<"; - typeNode = (ASTNode) typeNode.getChild(0); - int children = typeNode.getChildCount(); - if (children <= 0) { - throw new SemanticException("empty union not allowed."); - } - StringBuilder buffer = new StringBuilder(typeStr); - for (int i = 0; i < children; i++) { - buffer.append(getTypeStringFromAST((ASTNode) typeNode.getChild(i))); - if (i < children - 1) { - buffer.append(","); - } - } - buffer.append(">"); - typeStr = buffer.toString(); - return typeStr; - } - - public static String getAstNodeText(ASTNode tree) { - return tree.getChildCount() == 0?tree.getText() : - getAstNodeText((ASTNode)tree.getChild(tree.getChildCount() - 1)); - } - - public static String generateErrorMessage(ASTNode ast, String message) { - StringBuilder sb = new StringBuilder(); - if (ast == null) { - sb.append(message).append(". Cannot tell the position of null AST."); - return sb.toString(); - } - sb.append(ast.getLine()); - sb.append(":"); - sb.append(ast.getCharPositionInLine()); - sb.append(" "); - sb.append(message); - sb.append(". 
Error encountered near token '"); - sb.append(getAstNodeText(ast)); - sb.append("'"); - return sb.toString(); - } - - private static final Map<Integer, String> TokenToTypeName = new HashMap<Integer, String>(); - - static { - TokenToTypeName.put(SparkSqlParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TINYINT, serdeConstants.TINYINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_SMALLINT, serdeConstants.SMALLINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INT, serdeConstants.INT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BIGINT, serdeConstants.BIGINT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_FLOAT, serdeConstants.FLOAT_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DOUBLE, serdeConstants.DOUBLE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_STRING, serdeConstants.STRING_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_CHAR, serdeConstants.CHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_VARCHAR, serdeConstants.VARCHAR_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_BINARY, serdeConstants.BINARY_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATE, serdeConstants.DATE_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DATETIME, serdeConstants.DATETIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_TIMESTAMP, serdeConstants.TIMESTAMP_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_YEAR_MONTH, serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_INTERVAL_DAY_TIME, serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME); - TokenToTypeName.put(SparkSqlParser.TOK_DECIMAL, serdeConstants.DECIMAL_TYPE_NAME); - } - - public static String getTypeName(ASTNode node) throws SemanticException { - int token = node.getType(); - String typeName; - - // datetime type isn't currently supported - if (token == SparkSqlParser.TOK_DATETIME) { - throw new SemanticException(ErrorMsg.UNSUPPORTED_TYPE.getMsg()); - } - - switch 
(token) { - case SparkSqlParser.TOK_CHAR: - CharTypeInfo charTypeInfo = ParseUtils.getCharTypeInfo(node); - typeName = charTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_VARCHAR: - VarcharTypeInfo varcharTypeInfo = ParseUtils.getVarcharTypeInfo(node); - typeName = varcharTypeInfo.getQualifiedName(); - break; - case SparkSqlParser.TOK_DECIMAL: - DecimalTypeInfo decTypeInfo = ParseUtils.getDecimalTypeTypeInfo(node); - typeName = decTypeInfo.getQualifiedName(); - break; - default: - typeName = TokenToTypeName.get(token); - } - return typeName; - } - - public static String relativeToAbsolutePath(HiveConf conf, String location) throws SemanticException { - boolean testMode = conf.getBoolVar(HiveConf.ConfVars.HIVETESTMODE); - if (testMode) { - URI uri = new Path(location).toUri(); - String scheme = uri.getScheme(); - String authority = uri.getAuthority(); - String path = uri.getPath(); - if (!path.startsWith("/")) { - path = (new Path(System.getProperty("test.tmp.dir"), - path)).toUri().getPath(); - } - if (StringUtils.isEmpty(scheme)) { - scheme = "pfile"; - } - try { - uri = new URI(scheme, authority, path, null, null); - } catch (URISyntaxException e) { - throw new SemanticException(ErrorMsg.INVALID_PATH.getMsg(), e); - } - return uri.toString(); - } else { - //no-op for non-test mode for now - return location; - } - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 31d82eb20f6e4a82ab07b8ec675c4e2793468422..bf3fe12d5c5d2980ce37c3cdb7b696f6ba807407 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,41 +17,30 @@ package org.apache.spark.sql.hive -import java.sql.Date import java.util.Locale import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.hive.common.`type`.HiveDecimal import 
org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse.SemanticException -import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.ParseUtils._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.execution.SparkQl +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} -import org.apache.spark.sql.parser._ +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler +import org.apache.spark.sql.AnalysisException /** * Used when we need to 
start parsing the AST before deciding that we are going to pass the command @@ -71,7 +60,7 @@ private[hive] case class CreateTableAsSelect( override def output: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = tableDesc.specifiedDatabase.isDefined && - tableDesc.schema.size > 0 && + tableDesc.schema.nonEmpty && tableDesc.serde.isDefined && tableDesc.inputFormat.isDefined && tableDesc.outputFormat.isDefined && @@ -89,7 +78,7 @@ private[hive] case class CreateViewAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] object HiveQl extends Logging { +private[hive] object HiveQl extends SparkQl with Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -180,103 +169,6 @@ private[hive] object HiveQl extends Logging { protected val hqlParser = new ExtendedHiveQlParser - /** - * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations - * similar to [[catalyst.trees.TreeNode]]. - * - * Note that this should be considered very experimental and is not indented as a replacement - * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to - * have clean copy semantics. Therefore, users of this class should take care when - * copying/modifying trees that might be used elsewhere. - */ - implicit class TransformableNode(n: ASTNode) { - /** - * Returns a copy of this node where `rule` has been recursively applied to it and all of its - * children. When `rule` does not apply to a given node it is left unchanged. 
- * @param rule the function use to transform this nodes children - */ - def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { - try { - val afterRule = rule.applyOrElse(n, identity[ASTNode]) - afterRule.withChildren( - nilIfEmpty(afterRule.getChildren) - .asInstanceOf[Seq[ASTNode]] - .map(ast => Option(ast).map(_.transform(rule)).orNull)) - } catch { - case e: Exception => - logError(dumpTree(n).toString) - throw e - } - } - - /** - * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. - */ - private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.asScala).getOrElse(Nil) - - /** - * Returns this ASTNode with the text changed to `newText`. - */ - def withText(newText: String): ASTNode = { - n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) - n - } - - /** - * Returns this ASTNode with the children changed to `newChildren`. - */ - def withChildren(newChildren: Seq[ASTNode]): ASTNode = { - (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - newChildren.foreach(n.addChild(_)) - n - } - - /** - * Throws an error if this is not equal to other. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. - */ - def checkEquals(other: ASTNode): Unit = { - def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { - sys.error(s"$field does not match for trees. " + - s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") - } - check("name", _.getName) - check("type", _.getType) - check("text", _.getText) - check("numChildren", n => nilIfEmpty(n.getChildren).size) - - val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] - val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] - leftChildren zip rightChildren foreach { - case (l, r) => l checkEquals r - } - } - } - - /** - * Returns the AST for the given SQL string. 
- */ - def getAst(sql: String): ASTNode = { - /* - * Context has to be passed in hive0.13.1. - * Otherwise, there will be Null pointer exception, - * when retrieving properties form HiveConf. - */ - val hContext = createContext() - val node = getAst(sql, hContext) - hContext.clear() - node - } - - private def createContext(): Context = new Context(hiveConf) - - private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(sql, context)) - /** * Returns the HiveConf */ @@ -296,226 +188,16 @@ private[hive] object HiveQl extends Logging { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) - val errorRegEx = "line (\\d+):(\\d+) (.*)".r - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { - try { - val context = createContext() - val tree = getAst(sql, context) - val plan = if (nativeCommands contains tree.getText) { - HiveNativeCommand(sql) - } else { - nodeToPlan(tree, context) match { - case NativePlaceholder => HiveNativeCommand(sql) - case other => other - } - } - context.clear() - plan - } catch { - case pe: ParseException => - pe.getMessage match { - case errorRegEx(line, start, message) => - throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) - case otherMessage => - throw new AnalysisException(otherMessage) - } - case e: MatchError => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) - } - } - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = - try { - ParseUtils.findRootNonNullToken( - (new ParseDriver) - .parse(ddl, null /* no context required for parsing alone */)) - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - 
throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) - } - assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.getChildren - val colList = - tableOps.asScala - .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")).getChildren - - colList.asScala.map(nodeToAttribute) - } - - /** Extractor for matching Hive's AST Tokens. */ - private[hive] case class Token(name: String, children: Seq[ASTNode]) extends Node { - def getName(): String = name - def getChildren(): java.util.List[Node] = { - val col = new java.util.ArrayList[Node](children.size) - children.foreach(col.add(_)) - col - } - } - object Token { - /** @return matches of the form (tokenName, children). */ - def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { - case t: ASTNode => - CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) - Some((t.getText, - Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) - case t: Token => Some((t.name, t.children)) - case _ => None - } - } - - protected def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } - - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. 
- |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses - } - - def getClause(clauseName: String, nodeList: Seq[Node]): Node = - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) - - def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { - nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } - - protected def nodeToAttribute(node: Node): Attribute = node match { - case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => - AttributeReference(colName, nodeToDataType(dataType), true)() - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", - Token("TOK_TABCOLLIST", fields) :: Nil) => 
- StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", - keyType :: - valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") - } - - protected def nodeToStructField(node: Node): StructField = node match { - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: - _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") - } - - protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } - } - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. 
- */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition( n => n match { - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets - case _ => true // grouping keys - }) - - val keys = keyASTs.map(nodeToExpr).toSeq - val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap - - val bitmasks: Seq[Int] = setASTs.map(set => set match { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => - children.foldLeft(0)((bitmap, col) => { - val colString = col.asInstanceOf[ASTNode].toStringTree() - require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") - bitmap | 1 << keyMap(colString) - }) - case _ => sys.error("Expect GROUPING SETS clause") - }) - - (keys, bitmasks) - } - - protected def getProperties(node: Node): Seq[(String, String)] = node match { + protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { case Token("TOK_TABLEPROPLIST", list) => list.map { case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - (unquoteString(key) -> unquoteString(value)) + unquoteString(key) -> unquoteString(value) } } private def createView( view: ASTNode, - context: Context, viewNameParts: ASTNode, query: ASTNode, schema: Seq[HiveColumn], @@ -524,8 +206,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C replace: Boolean): CreateViewAsSelect = { val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) - val originalText = context.getTokenRewriteStream - .toString(query.getTokenStartIndex, query.getTokenStopIndex) + val originalText = query.source val tableDesc = HiveTable( specifiedDatabase = dbName, @@ -544,104 +225,67 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // We need to keep the original SQL string so that if `spark.sql.nativeView` is // 
false, we can fall back to use hive native command later. // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = context.getTokenRewriteStream - .toString(view.getTokenStartIndex, view.getTokenStopIndex) - CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) + val sql = view.source + CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) } - protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { - // Special drop table that also uncaches. - case Token("TOK_DROPTABLE", - Token("TOK_TABNAME", tableNameParts) :: - ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: - isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. - NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) + protected override def createPlan( + sql: String, + node: ASTNode): LogicalPlan = { + if (nativeCommands.contains(node.text)) { + HiveNativeCommand(sql) + } else { + nodeToPlan(node) match { + case NativePlaceholder => HiveNativeCommand(sql) + case plan => plan } - // Just fake explain for any of the native commands. 
- case Token("TOK_EXPLAIN", explainArgs) - if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(OneRowRelation) - case Token("TOK_EXPLAIN", explainArgs) - if "TOK_CREATETABLE" == explainArgs.head.getText => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(crtTbl, context), - extended = extended.isDefined) - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(query, context), - extended = extended.isDefined) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - NativePlaceholder - } else { - tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => { - nameParts match { - case Token(".", dbName :: tableName :: Nil) => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts) - DescribeCommand( - UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => - // It is describing a column with the format like "describe db.table column". - NativePlaceholder - case tableName => - // It is describing a table with the format like "describe table". 
- DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.getText), None), - isExtended = extended.isDefined) - } - } - // All other cases. - case _ => NativePlaceholder + } + } + + protected override def isNoExplainCommand(command: String): Boolean = + noExplainCommands.contains(command) + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Special drop table that also uncaches. + case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) + + // Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) } - } - case view @ Token("TOK_ALTERVIEW", children) => - val Some(viewNameParts) :: maybeQuery :: ignores = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) + case view @ Token("TOK_ALTERVIEW", children) => + val Some(nameParts) :: maybeQuery :: _ = + getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME"), children) - // if ALTER VIEW doesn't have query part, let hive to handle it.
- maybeQuery.map { query => - createView(view, context, viewNameParts, query, Nil, Map(), false, true) - }.getOrElse(NativePlaceholder) + // if ALTER VIEW doesn't have query part, let Hive handle it. + maybeQuery.map { query => + createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true) + }.getOrElse(NativePlaceholder) - case view @ Token("TOK_CREATEVIEW", children) + case view @ Token("TOK_CREATEVIEW", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( + val Seq( Some(viewNameParts), Some(query), maybeComment, @@ -650,1224 +294,466 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeProperties, maybeColumns, maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - SemanticAnalyzer.getColumns(cols, true).asScala.map { field => + ) = getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_ORREPLACE", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), children) + + // If the view is partitioned, we let hive handle it. + if (maybePartCols.isDefined) { + NativePlaceholder + } else { + val schema = maybeColumns.map { cols => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later.
- HiveColumn(field.getName, null, field.getComment) - } - }.getOrElse(Seq.empty[HiveColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) - if (comment ne null) { - properties += ("comment" -> comment) - } - } - - createView(view, context, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - ignores) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: HiveTable = HiveTable( - specifiedDatabase = dbName, - name = tblName, - schema = Seq.empty[HiveColumn], - partitionColumns = Seq.empty[HiveColumn], - properties = Map[String, String](), - serdeProperties = Map[String, String](), - tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = None) 
- - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) - val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } + nodeToColumns(cols, lowerCase = true).map(_.copy(hiveType = null)) + }.getOrElse(Seq.empty[HiveColumn]) - hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) - hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) - hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) + val properties = scala.collection.mutable.Map.empty[String, String] - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = SemanticAnalyzer.getColumns(list, true) - if (cols != null) { - tableDesc = tableDesc.copy( - schema = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = SemanticAnalyzer.unescapeSQLString(child.getText) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = SemanticAnalyzer.getColumns(list(0), false) - if (cols != null) { - tableDesc = tableDesc.copy( - partitionColumns = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case 
Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = SemanticAnalyzer.unescapeSQLString (rowChild1.getText()) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = SemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - SemanticAnalyzer.generateErrorMessage( - rowChild, - ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg)) - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = SemanticAnalyzer.unescapeSQLString(rowChild.getText) - // TODO support the nullFormat - case _ => assert(false) - } - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - var location = SemanticAnalyzer.unescapeSQLString(child.getText) - location = SemanticAnalyzer.relativeToAbsolutePath(hiveConf, location) - tableDesc = tableDesc.copy(location = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.copy( - serde = Option(SemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) - if (child.getChildCount == 2) { - val serdeParams = new 
java.util.HashMap[String, String]() - SemanticAnalyzer.readProps( - (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.getText().toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - case "parquet" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - case "rcfile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - case "textfile" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - 
Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + if (comment ne null) { + properties += ("comment" -> comment) } - - case _ => - throw new SemanticException( - s"Unrecognized file format in STORED AS clause: ${child.getText}") - } - - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) - case Nil => } - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", children) => - tableDesc = tableDesc.copy( - inputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), - outputFormat = - Option(SemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) - case _ => // Unsupport features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder - - case Token("TOK_QUERY", queryArgs) - if Seq("TOK_CTE", "TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - - val (fromClause: Option[ASTNode], insertClauses, 
cteRelations) = - queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - relation.alias -> relation - } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) + createView(view, viewNameParts, query, schema, properties.toMap, + allowExisting.isDefined, replace.isDefined) } - // Return one query for each insert clause. - val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => + case Token("TOK_CREATETABLE", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { + Some(tableNameParts) :: + _ /* likeTable */ :: + externalTable :: + Some(query) :: + allowExisting +: + _) = getClauses( Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) + "TOK_TABNAME", + "TOK_LIKETABLE", + "EXTERNAL", + "TOK_QUERY", + "TOK_IFNOTEXISTS", + "TOK_TABLECOMMENT", + "TOK_TABCOLLIST", + "TOK_TABLEPARTCOLS", // Partitioned by + "TOK_TABLEBUCKETS", // Clustered by + "TOK_TABLESKEWED", // Skewed by + "TOK_TABLEROWFORMAT", + "TOK_TABLESERIALIZER", + 
"TOK_FILEFORMAT_GENERIC", + "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat + "TOK_STORAGEHANDLER", // Storage handler + "TOK_TABLELOCATION", + "TOK_TABLEPROPERTIES"), + children) + val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) + + // TODO add bucket support + var tableDesc: HiveTable = HiveTable( + specifiedDatabase = dbName, + name = tblName, + schema = Seq.empty[HiveColumn], + partitionColumns = Seq.empty[HiveColumn], + properties = Map[String, String](), + serdeProperties = Map[String, String](), + tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, + location = None, + inputFormat = None, + outputFormat = None, + serde = None, + viewText = None) + + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) + val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) } - val relations = fromClause match { - case Some(f) => nodeToRelation(f, context) - case None => OneRowRelation - } + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.asScala - Filter(nodeToExpr(whereExpr), relations) - }.getOrElse(relations) - - val select = - (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) - - // Script transformations are expressed as a select clause with a single expression of type - // TOK_TRANSFORM - val transformation = 
select.getChildren.iterator().next() match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, - false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(name, nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) + children.collect { + case list @ Token("TOK_TABCOLLIST", _) => + val cols = nodeToColumns(list, lowerCase = true) + if (cols != null) { + tableDesc = tableDesc.copy(schema = cols) } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + // TODO support the sql text + tableDesc = tableDesc.copy(viewText = Option(comment)) + case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => + val cols = nodeToColumns(list.head, lowerCase = false) + if (cols != null) { + tableDesc = tableDesc.copy(partitionColumns = cols) + } + case 
Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => + val serdeParams = new java.util.HashMap[String, String]() + child match { + case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => + val fieldDelim = unescapeSQLString (rowChild1.text) + serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) + serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) + if (rowChild2.length > 1) { + val fieldEscape = unescapeSQLString (rowChild2.head.text) + serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(SemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (SemanticAnalyzer.unescapeSQLString(name), - SemanticAnalyzer.unescapeSQLString(value)) + case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => + val collItemDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) + case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => + val mapKeyDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) + case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => + val lineDelim = unescapeSQLString(rowChild.text) + if (!(lineDelim == "\n") && !(lineDelim == "10")) { + throw new AnalysisException( + s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild") } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = SemanticAnalyzer.unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == 
classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = SemanticAnalyzer.unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None + serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) + case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => + val nullFormat = unescapeSQLString(rowChild.text) + // TODO support the nullFormat + case _ => assert(false) } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + case Token("TOK_TABLELOCATION", child :: Nil) => + val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) + tableDesc = tableDesc.copy(location = Option(location)) + case Token("TOK_TABLESERIALIZER", child :: Nil) => + tableDesc = tableDesc.copy( + serde = Option(unescapeSQLString(child.children.head.text))) + if (child.numChildren == 2) { + // This is based on the readProps(..) 
method in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + val serdeParams = child.children(1).children.head.children.map { + case Token(_, Token(prop, Nil) :: valueNode) => + val value = valueNode.headOption + .map(_.text) + .map(unescapeSQLString) + .orNull + (unescapeSQLString(prop), value) + }.toMap + tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) } + case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => + child.text.toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - withWhere, schema)) - case _ => None - } - - val withLateralView = lateralViewClause.map { lv => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = false, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - withWhere) - }.getOrElse(withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. 
- val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Aggregate(Seq(Rollup(children.map(nodeToExpr))), selectExpressions, withLateralView) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Aggregate(Seq(Cube(children.map(nodeToExpr))), selectExpressions, withLateralView) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withLateralView))).flatten.head - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. 
- Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), - false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), - false, - RepartitionByExpression( - clusterExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy(serde = + Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. 
- val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - // If there are multiple INSERTS just UNION them together into on query. 
- val query = queries.reduceLeft(Union) + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) + case _ => + throw new AnalysisException( + s"Unrecognized file format in STORED AS clause: ${child.text}") + } - // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left, context), nodeToPlan(right, context)) + case Token("TOK_TABLESERIALIZER", + Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => + tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") - } + otherProps match { + case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) + case _ => + } - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { - case Token("TOK_SUBQUERY", - query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = isOuter.nonEmpty, - 
Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause, context)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.getText.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) + case list @ Token("TOK_TABLEFILEFORMAT", _) => + tableDesc = tableDesc.copy( + inputFormat = + Option(unescapeSQLString(list.children.head.text)), + outputFormat = + Option(unescapeSQLString(list.children(1).text))) + case Token("TOK_STORAGEHANDLER", _) => + throw new AnalysisException( + "CREATE TABLE AS SELECT cannot be used for a non-native table") + case _ => // Unsupport features } - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: - Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: - Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. 
- require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, - relation) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - }.getOrElse(relation) - - case Token("TOK_UNIQUEJOIN", joinArgs) => - val tableOrdinals = - joinArgs.zipWithIndex.filter { - case (arg, i) => arg.getText == "TOK_TABREF" - }.map(_._2) - - val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) - val joinExpressions = - tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) - - val joinConditions = joinExpressions.sliding(2).map { - case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } - predicates.reduceLeft(And) - }.toBuffer - - val joinType = isPreserved.sliding(2).map { - case Seq(true, true) => FullOuter - case Seq(true, false) => LeftOuter - case Seq(false, true) => RightOuter - case Seq(false, false) => Inner - }.toBuffer - - val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) - - // Must be transform down. 
- val joinedResult = joinedTables transform { - case j: Join => - j.copy( - condition = Some(joinConditions.remove(joinConditions.length - 1)), - joinType = joinType.remove(joinType.length - 1)) - } - - val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) - - // Unique join is not really the same as an outer join so we must group together results where - // the joinExpressions are the same, taking the First of each value is only okay because the - // user of a unique join is implicitly promising that there is only one result. - // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. - // instead we should figure out how important supporting this feature is and whether it is - // worth the number of hacks that will be required to implement it. Namely, we need to add - // some sort of mapped star expansion that would expand all child output row to be similarly - // named output expressions where some aggregate expression has been applied (i.e. First). 
- // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) - throw new UnsupportedOperationException - - case Token(allJoinTokens(joinToken), - relation1 :: - relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_ANTIJOIN" => throw new NotImplementedError("Anti join not supported") - } - Join(nodeToRelation(relation1, context), - nodeToRelation(relation2, context), - joinType, - other.headOption.map(nodeToExpr)) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) - def nodeToSortOrder(node: Node): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) + // If its not a "CTAS" like above then take it as a native command + case Token("TOK_CREATETABLE", _) => + NativePlaceholder - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" + case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) => + NativePlaceholder - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: Node, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: 
partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. 
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName}:" + - s"\n ${dumpTree(a).toString} ") + case _ => + super.nodeToPlan(node) + } } - protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - var aliasNames = ArrayBuffer[String]() - aliasChildren.foreach { _ match { - case Token(name, Nil) => aliasNames += cleanIdentifier(name) + protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder + + protected override def nodeToTransformation( + node: ASTNode, + child: LogicalPlan): Option[ScriptTransformation] = node match { + case Token("TOK_SELEXPR", + Token("TOK_TRANSFORM", + Token("TOK_EXPLIST", inputExprs) :: + Token("TOK_SERDE", inputSerdeClause) :: + Token("TOK_RECORDWRITER", writerClause) :: + // TODO: Need to support other types of (in/out)put + Token(script, Nil) :: + Token("TOK_SERDE", outputSerdeClause) :: + Token("TOK_RECORDREADER", readerClause) :: + outputClause) :: Nil) => + + val (output, schemaLess) = outputClause match { + case Token("TOK_ALIASLIST", aliases) :: Nil => + (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, + false) + case Token("TOK_TABCOLLIST", attributes) :: Nil => + (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() }, false) 
+ case Nil => + (List(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) case _ => - } + noParseRule("Transform", node) } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName }:" + - s"\n ${dumpTree(a).toString } ") - } - - protected val escapedIdentifier = "`([^`]+)`".r - protected val doubleQuotedString = "\"([^\"]+)\"".r - protected val singleQuotedString = "'([^']+)'".r + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { + case Token("TOK_SERDEPROPS", propsClause) :: Nil => + val rowFormat = propsClause.map { + case Token(name, Token(value, Nil) :: Nil) => (name, value) + } + (rowFormat, None, Nil, false) - protected def unquoteString(str: String) = str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => + (Nil, Some(unescapeSQLString(serdeClass)), Nil, false) - /** Strips backticks from ident if present */ - protected def cleanIdentifier(ident: String): String = ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: + Token("TOK_TABLEPROPERTIES", + Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => + val serdeProps = propsClause.map { + case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => + (unescapeSQLString(name), unescapeSQLString(value)) + } - val numericAstTypes = Seq( - SparkSqlParser.Number, - SparkSqlParser.TinyintLiteral, - SparkSqlParser.SmallintLiteral, - 
SparkSqlParser.BigintLiteral, - SparkSqlParser.DecimalLiteral) - - /* Case insensitive matches */ - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - - protected def nodeToExpr(node: Node): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", - Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only - // has a single child which is tableName. 
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) - case Token("TOK_FUNCTION", 
Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case 
Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) - case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - 
UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token(name, args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = nodeToExpr(Token(name, args)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. - case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => SemanticAnalyzer.unescapeSQLString(s.getText)).mkString) - - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.getType => - var v: Literal = null - try { - if (ast.getText.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) - } else if (ast.getText.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) - } else if (ast.getText.endsWith("Y")) { - // Literal tinyint. 
- v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) - } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { - // Literal decimal - val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.getText.toDouble, DoubleType) - v = Literal.create(ast.getText.toLong, LongType) - v = Literal.create(ast.getText.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } + val unescapedScript = unescapeSQLString(script) - if (v == null) { - sys.error(s"Failed to parse number '${ast.getText}'.") + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) } else { - v + None } - case ast: ASTNode if ast.getType == SparkSqlParser.StringLiteral => - Literal(SemanticAnalyzer.unescapeSQLString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_CHARSETLITERAL => - Literal(SemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => - 
Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - - case ast: ASTNode if ast.getType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. 
- val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), - orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - throw new NotImplementedError( - s"""No parse rules for Node ${partitionAndOrdering.getName} - """.stripMargin) - } - }.getOrElse { - (Nil, Nil) + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None } - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame Boundary based on Node 
${node.getName} - """.stripMargin) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.asScala.toList match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame based on Node ${frame.getName} - """.stripMargin) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + val schema = HiveScriptIOSchema( + inRowFormat, outRowFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) + + Some( + ScriptTransformation( + inputExprs.map(nodeToExpr), + unescapedScript, + output, + child, schema)) + case _ => None } - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { - val function = nodes.head - - val attributes = nodes.flatMap { - case Token(a, Nil) => a.toLowerCase :: Nil - case _ => Nil - } - - function match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - (Explode(nodeToExpr(child)), attributes) - - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - (JsonTuple(children.map(nodeToExpr)), attributes) - - case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - - (HiveGenericUDTF( - new HiveFunctionWrapper(functionClassName), - children.map(nodeToExpr)), attributes) - - case a: ASTNode => - throw new NotImplementedError( - 
s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: - |${dumpTree(a).toString} - """.stripMargin) - } + protected override def nodeToGenerator(node: ASTNode): Generator = node match { + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + case other => super.nodeToGenerator(node) } - def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) - : StringBuilder = { - node match { - case a: ASTNode => builder.append( - (" " * indent) + a.getText + " " + - a.getLine + ", " + - a.getTokenStartIndex + "," + - a.getTokenStopIndex + ", " + - a.getCharPositionInLine + "\n") - case other => sys.error(s"Non ASTNode encountered: $other") + // This is based the getColumns methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java + protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[HiveColumn] = { + node.children.map(_.children).collect { + case Token(rawColName, Nil) :: colTypeNode :: comment => + val colName = if (!lowerCase) rawColName + else rawColName.toLowerCase + HiveColumn( + cleanIdentifier(colName), + nodeToTypeString(colTypeNode), + comment.headOption.map(n => unescapeSQLString(n.text)).orNull) } + } - Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) - builder + // This is based on the following methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + // getTypeStringFromAST + // getStructTypeStringFromAST + // getUnionTypeStringFromAST + protected def nodeToTypeString(node: ASTNode): String = node.tokenType match { + case SparkSqlParser.TOK_LIST => + val 
listType :: Nil = node.children + val listTypeString = nodeToTypeString(listType) + s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>" + + case SparkSqlParser.TOK_MAP => + val keyType :: valueType :: Nil = node.children + val keyTypeString = nodeToTypeString(keyType) + val valueTypeString = nodeToTypeString(valueType) + s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>" + + case SparkSqlParser.TOK_STRUCT => + val typeNode = node.children.head + require(typeNode.children.nonEmpty, "Struct must have one or more columns.") + val structColStrings = typeNode.children.map { columnNode => + val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children + cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode) + } + s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>" + + case SparkSqlParser.TOK_UNIONTYPE => + val typeNode = node.children.head + val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",") + s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>" + + case SparkSqlParser.TOK_CHAR => + val Token(size, Nil) :: Nil = node.children + s"${serdeConstants.CHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_VARCHAR => + val Token(size, Nil) :: Nil = node.children + s"${serdeConstants.VARCHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_DECIMAL => + val precisionAndScale = node.children match { + case Token(precision, Nil) :: Token(scale, Nil) :: Nil => + precision + "," + scale + case Token(precision, Nil) :: Nil => + precision + "," + HiveDecimal.USER_DEFAULT_SCALE + case Nil => + HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE + case _ => + noParseRule("Decimal", node) + } + s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)" + + // Simple data types. 
+ case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME + case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME + case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME + case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME + case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME + case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME + case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME + case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME + case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME + case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME + case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME + case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME + case _ => null } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 400f7f3708cf4c596a46735b6f74746c4c431bc3..a2d283622ca5210967814a4d58af551bc2b03a8c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,6 +21,7 @@ import scala.util.Try import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -116,8 +117,9 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: 
String): Unit = { + def ast = ParseDriver.parse(query, hiveContext.conf) def parseTree = - Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("<failed to parse>") + Try(quietly(ast.treeString)).getOrElse("<failed to parse>") test(name) { val error = intercept[AnalysisException] { @@ -139,10 +141,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd val expectedStart = line.indexOf(token) val actualStart = error.startPosition.getOrElse { - fail( - s"start not returned for error on token $token\n" + - HiveQl.dumpTree(HiveQl.getAst(query)) - ) + fail(s"start not returned for error on token $token\n${ast.treeString}") } assert(expectedStart === actualStart, s"""Incorrect start position.