diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e396cf41f2f7bff1183c0b49fe7d272b013b965d..c03cb9338ae685d72883ccce508054d3acf61c7f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1081,8 +1081,7 @@ class SQLTests(ReusedPySparkTestCase): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - # RuntimeException should not be captured - self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g index aabb5d49582c89e56d4f8cb498bdb6802e14fa23..047a7e56cb57736b3490c2bc718fdd003aa13e6e 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -123,7 +123,6 @@ constant | SmallintLiteral | TinyintLiteral | DecimalLiteral - | charSetStringLiteral | booleanValue ; @@ -132,13 +131,6 @@ stringLiteralSequence StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) ; -charSetStringLiteral -@init { gParent.pushMsg("character string literal", state); } -@after { gParent.popMsg(state); } - : - csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) - ; - dateLiteral : KW_DATE StringLiteral -> @@ -163,22 +155,38 @@ timestampLiteral intervalLiteral : - KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> - { - adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH + -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant) + | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND + -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant) + | KW_INTERVAL + ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)? + ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)? + ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)? + ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)? + ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? + ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? + ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? + (millisecond=intervalConstant KW_MILLISECOND)? + (microsecond=intervalConstant KW_MICROSECOND)? + -> ^(TOK_INTERVAL + ^(TOK_INTERVAL_YEAR_LITERAL $year?) + ^(TOK_INTERVAL_MONTH_LITERAL $month?) + ^(TOK_INTERVAL_WEEK_LITERAL $week?) + ^(TOK_INTERVAL_DAY_LITERAL $day?) + ^(TOK_INTERVAL_HOUR_LITERAL $hour?) + ^(TOK_INTERVAL_MINUTE_LITERAL $minute?) + ^(TOK_INTERVAL_SECOND_LITERAL $second?) + ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?) + ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?)) + ; + +intervalConstant + : + sign=(MINUS|PLUS)? value=Number -> { + adaptor.create(Number, ($sign != null ? 
$sign.getText() : "") + $value.getText()) } - ; - -intervalQualifiers - : - KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL - | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL - | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL - | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL - | KW_DAY -> TOK_INTERVAL_DAY_LITERAL - | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL - | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL - | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + | StringLiteral ; expression @@ -219,7 +227,8 @@ nullCondition precedenceUnaryPrefixExpression : - (precedenceUnaryOperator^)* precedenceFieldExpression + (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression + | precedenceFieldExpression ; precedenceUnarySuffixExpression diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g index 972c52e3ffcec356883690028211d96783a70305..6d76afcd4ac0727ea7349ae76fdd6e4528c5ea70 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -206,11 +206,8 @@ tableName @init { gParent.pushMsg("table name", state); } @after { gParent.popMsg(state); } : - db=identifier DOT tab=identifier - -> ^(TOK_TABNAME $db $tab) - | - tab=identifier - -> ^(TOK_TABNAME $tab) + id1=identifier (DOT id2=identifier)? + -> ^(TOK_TABNAME $id1 $id2?) ; viewName diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g index 44a63fbef258c44b6465eb1a6bd3c31f6191b937..ee2882e51c4506ab36f72d771fb9f02a372a3361 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -307,12 +307,12 @@ KW_AUTHORIZATION: 'AUTHORIZATION'; KW_CONF: 'CONF'; KW_VALUES: 'VALUES'; KW_RELOAD: 'RELOAD'; -KW_YEAR: 'YEAR'; -KW_MONTH: 'MONTH'; -KW_DAY: 'DAY'; -KW_HOUR: 'HOUR'; -KW_MINUTE: 'MINUTE'; -KW_SECOND: 'SECOND'; +KW_YEAR: 'YEAR'|'YEARS'; +KW_MONTH: 'MONTH'|'MONTHS'; +KW_DAY: 'DAY'|'DAYS'; +KW_HOUR: 'HOUR'|'HOURS'; +KW_MINUTE: 'MINUTE'|'MINUTES'; +KW_SECOND: 'SECOND'|'SECONDS'; KW_START: 'START'; KW_TRANSACTION: 'TRANSACTION'; KW_COMMIT: 'COMMIT'; @@ -324,6 +324,9 @@ KW_ISOLATION: 'ISOLATION'; KW_LEVEL: 'LEVEL'; KW_SNAPSHOT: 'SNAPSHOT'; KW_AUTOCOMMIT: 'AUTOCOMMIT'; +KW_WEEK: 'WEEK'|'WEEKS'; +KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; +KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; // Operators // NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. @@ -400,12 +403,6 @@ StringLiteral )+ ; -CharSetLiteral - : - StringLiteral - | '0' 'X' (HexDigit|Digit)+ - ; - BigintLiteral : (Digit)+ 'L' @@ -433,7 +430,7 @@ ByteLengthLiteral Number : - (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent? ; /* @@ -456,10 +453,10 @@ An Identifier can be: - macro name - hint name - window name -*/ +*/ Identifier : - (Letter | Digit) (Letter | Digit | '_')* + (Letter | Digit | '_')+ | {allowQuotedId()}? 
QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; at the API level only columns are allowed to be of this form */ | '`' RegexComponent+ '`' @@ -471,11 +468,6 @@ QuotedIdentifier '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } ; -CharSetName - : - '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ - ; - WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g index 2c13d3056f468a4917760f1adfe2ff17941d1ab5..c146ca591488464c7c90595922c8c4c53e8d1bf8 100644 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -116,16 +116,20 @@ TOK_DATELITERAL; TOK_DATETIME; TOK_TIMESTAMP; TOK_TIMESTAMPLITERAL; +TOK_INTERVAL; TOK_INTERVAL_YEAR_MONTH; TOK_INTERVAL_YEAR_MONTH_LITERAL; TOK_INTERVAL_DAY_TIME; TOK_INTERVAL_DAY_TIME_LITERAL; TOK_INTERVAL_YEAR_LITERAL; TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_WEEK_LITERAL; TOK_INTERVAL_DAY_LITERAL; TOK_INTERVAL_HOUR_LITERAL; TOK_INTERVAL_MINUTE_LITERAL; TOK_INTERVAL_SECOND_LITERAL; +TOK_INTERVAL_MILLISECOND_LITERAL; +TOK_INTERVAL_MICROSECOND_LITERAL; TOK_STRING; TOK_CHAR; TOK_VARCHAR; @@ -228,7 +232,6 @@ TOK_TMP_FILE; TOK_TABSORTCOLNAMEASC; TOK_TABSORTCOLNAMEDESC; TOK_STRINGLITERALSEQUENCE; -TOK_CHARSETLITERAL; TOK_CREATEFUNCTION; TOK_DROPFUNCTION; TOK_RELOADFUNCTION; @@ -509,7 +512,9 @@ import java.util.HashMap; xlateMap.put("KW_UPDATE", "UPDATE"); xlateMap.put("KW_VALUES", "VALUES"); xlateMap.put("KW_PURGE", "PURGE"); - + xlateMap.put("KW_WEEK", "WEEK"); + xlateMap.put("KW_MILLISECOND", "MILLISECOND"); + xlateMap.put("KW_MICROSECOND", "MICROSECOND"); // Operators xlateMap.put("DOT", "."); @@ -2078,6 +2083,7 @@ primitiveType | KW_SMALLINT -> TOK_SMALLINT | KW_INT -> TOK_INT | KW_BIGINT -> TOK_BIGINT + | KW_LONG -> TOK_BIGINT | KW_BOOLEAN -> TOK_BOOLEAN | KW_FLOAT -> TOK_FLOAT | KW_DOUBLE -> TOK_DOUBLE diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java index 5bc87b680f9ad82851fe9f136c9bb4330e214fec..2520c7bb8dae4d6a8464fb736857270f093a9209 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java @@ -18,12 +18,10 @@ package org.apache.spark.sql.catalyst.parser; -import java.io.UnsupportedEncodingException; - /** * A couple of utility methods that help with parsing ASTs. 
* - * Both methods in this class were take from the SemanticAnalyzer in Hive: + * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java */ public final class ParseUtils { @@ -31,33 +29,6 @@ public final class ParseUtils { super(); } - public static String charSetString(String charSetName, String charSetString) - throws UnsupportedEncodingException { - // The character set name starts with a _, so strip that - charSetName = charSetName.substring(1); - if (charSetString.charAt(0) == '\'') { - return new String(unescapeSQLString(charSetString).getBytes(), charSetName); - } else // hex input is also supported - { - assert charSetString.charAt(0) == '0'; - assert charSetString.charAt(1) == 'x'; - charSetString = charSetString.substring(2); - - byte[] bArray = new byte[charSetString.length() / 2]; - int j = 0; - for (int i = 0; i < charSetString.length(); i += 2) { - int val = Character.digit(charSetString.charAt(i), 16) * 16 - + Character.digit(charSetString.charAt(i + 1), 16); - if (val > 127) { - val = val - 256; - } - bArray[j++] = (byte)val; - } - - return new String(bArray, charSetName); - } - } - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; @SuppressWarnings("nls") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index bdc52c08acb66b45b652ba27be7506fde0c26f16..9443369808984c9291db037e5ec5a5e0549fb506 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -26,9 +26,9 @@ import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers { + extends StandardTokenParsers with PackratParsers with ParserDialect { - def parse(input: String): LogicalPlan = synchronized { + def parsePlan(input: String): LogicalPlan = synchronized { // Initialize the Keywords. initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index d0fbdacf6eafdc8a0320f58920c3141f657612d3..c1591ecfe2b4d1b3038f29c5872d93d3a0c94425 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -30,16 +30,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -private[sql] object CatalystQl { - val parser = new CatalystQl - def parseExpression(sql: String): Expression = parser.parseExpression(sql) - def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql) -} - /** * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. 
*/ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect { object Token { def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { CurrentOrigin.setPosition(node.line, node.positionInLine) @@ -611,13 +605,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case plainIdent => plainIdent } - val numericAstTypes = Seq( - SparkSqlParser.Number, - SparkSqlParser.TinyintLiteral, - SparkSqlParser.SmallintLiteral, - SparkSqlParser.BigintLiteral, - SparkSqlParser.DecimalLiteral) - /* Case insensitive matches */ val COUNT = "(?i)COUNT".r val SUM = "(?i)SUM".r @@ -635,6 +622,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r + val INTEGRAL = "[+-]?\\d+".r + protected def nodeToExpr(node: ASTNode): Expression = node match { /* Attribute References */ case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => @@ -650,8 +639,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) + case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => + UnresolvedStar(Some(target.map(_.text))) /* Aggregate Functions */ case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => @@ -787,71 +776,71 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_STRINGLITERALSEQUENCE", strings) => Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.tokenType => - var v: Literal = null - try { - if (ast.text.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - } else if (ast.text.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - } else if (ast.text.endsWith("Y")) { - // Literal tinyint. 
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) { - // Literal decimal - val strVal = ast.text.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.text.toDouble, DoubleType) - v = Literal.create(ast.text.toLong, LongType) - v = Literal.create(ast.text.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } - - if (v == null) { - sys.error(s"Failed to parse number '${ast.text}'.") - } else { - v - } - - case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral => - Literal(ParseUtils.unescapeSQLString(ast.text)) + case ast if ast.tokenType == SparkSqlParser.TinyintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) + case ast if ast.tokenType == SparkSqlParser.SmallintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL => - Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text)) + case ast if ast.tokenType == SparkSqlParser.BigintLiteral => + Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.text)) + case ast if ast.tokenType == SparkSqlParser.DecimalLiteral => + Literal(Decimal(ast.text.substring(0, ast.text.length() - 2))) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("day", ast.text)) - - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.text)) + case ast if ast.tokenType == SparkSqlParser.Number => + val text = ast.text + text match { + case INTEGRAL() => + BigDecimal(text) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + case _ => + Literal(text.toDouble) + } + case ast if ast.tokenType == SparkSqlParser.StringLiteral => + Literal(ParseUtils.unescapeSQLString(ast.text)) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.text)) + case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) - case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.text)) + case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.children.head.text)) + + case ast if ast.tokenType == 
SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.children.head.text)) + + case Token("TOK_INTERVAL", elements) => + var interval = new CalendarInterval(0, 0) + var updated = false + elements.foreach { + // The interval node will always contain children for all possible time units. A child node + // is only useful when it contains exactly one (numeric) child. + case e @ Token(name, Token(value, Nil) :: Nil) => + val unit = name match { + case "TOK_INTERVAL_YEAR_LITERAL" => "year" + case "TOK_INTERVAL_MONTH_LITERAL" => "month" + case "TOK_INTERVAL_WEEK_LITERAL" => "week" + case "TOK_INTERVAL_DAY_LITERAL" => "day" + case "TOK_INTERVAL_HOUR_LITERAL" => "hour" + case "TOK_INTERVAL_MINUTE_LITERAL" => "minute" + case "TOK_INTERVAL_SECOND_LITERAL" => "second" + case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond" + case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond" + case _ => noParseRule(s"Interval($name)", e) + } + interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value)) + updated = true + case _ => + } + if (!updated) { + throw new AnalysisException("at least one time unit should be given for interval literal") + } + Literal(interval) case _ => noParseRule("Expression", node) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index e21d3c05464b6325a6a20fda5aca61353fa573a4..7d9fbf2f12ee61c0fbe40e5e89efee52505f3d54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -18,52 +18,22 @@ package org.apache.spark.sql.catalyst import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * Root class of SQL Parser Dialect, and we don't guarantee the binary * compatibility for the future release, let's keep it as the internal * interface for advanced user. - * */ @DeveloperApi -abstract class ParserDialect { - // this is the main function that will be implemented by sql parser. - def parse(sqlText: String): LogicalPlan -} +trait ParserDialect { + /** Creates LogicalPlan for a given SQL string. */ + def parsePlan(sqlText: String): LogicalPlan -/** - * Currently we support the default dialect named "sql", associated with the class - * [[DefaultParserDialect]] - * - * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: - * {{{ - *-- switch to "hiveql" dialect - * spark-sql>SET spark.sql.dialect=hiveql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- switch to "sql" dialect - * spark-sql>SET spark.sql.dialect=sql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- register the new SQL dialect - * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- register the non-exist SQL dialect - * spark-sql> SET spark.sql.dialect=NotExistedClass; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- Exception will be thrown and switch to dialect - *-- "sql" (for SQLContext) or - *-- "hiveql" (for HiveContext) - * }}} - */ -private[spark] class DefaultParserDialect extends ParserDialect { - @transient - protected val sqlParser = SqlParser + /** Creates Expression for a given SQL string. 
*/ + def parseExpression(sqlText: String): Expression - override def parse(sqlText: String): LogicalPlan = { - sqlParser.parse(sqlText) - } + /** Creates TableIdentifier for a given SQL string. */ + def parseTableIdentifier(sqlText: String): TableIdentifier } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala deleted file mode 100644 index 85ff4ea0c946b696b888d0bb90b4bd8a3257407b..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ /dev/null @@ -1,509 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst - -import scala.language.implicitConversions - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DataTypeParser -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * A very simple SQL parser. Based loosely on: - * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala - * - * Limitations: - * - Only supports a very limited subset of SQL. - * - * This is currently included mostly for illustrative purposes. Users wanting more complete support - * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. - */ -object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - - def parseExpression(input: String): Expression = synchronized { - // Initialize the Keywords. - initLexical - phrase(projection)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - - def parseTableIdentifier(input: String): TableIdentifier = synchronized { - // Initialize the Keywords. 
- initLexical - phrase(tableIdentifier)(new lexical.Scanner(input)) match { - case Success(ident, _) => ident - case failureOrError => sys.error(failureOrError.toString) - } - } - - // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` - // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ALL = Keyword("ALL") - protected val AND = Keyword("AND") - protected val APPROXIMATE = Keyword("APPROXIMATE") - protected val AS = Keyword("AS") - protected val ASC = Keyword("ASC") - protected val BETWEEN = Keyword("BETWEEN") - protected val BY = Keyword("BY") - protected val CASE = Keyword("CASE") - protected val CAST = Keyword("CAST") - protected val DESC = Keyword("DESC") - protected val DISTINCT = Keyword("DISTINCT") - protected val ELSE = Keyword("ELSE") - protected val END = Keyword("END") - protected val EXCEPT = Keyword("EXCEPT") - protected val FALSE = Keyword("FALSE") - protected val FROM = Keyword("FROM") - protected val FULL = Keyword("FULL") - protected val GROUP = Keyword("GROUP") - protected val HAVING = Keyword("HAVING") - protected val IN = Keyword("IN") - protected val INNER = Keyword("INNER") - protected val INSERT = Keyword("INSERT") - protected val INTERSECT = Keyword("INTERSECT") - protected val INTERVAL = Keyword("INTERVAL") - protected val INTO = Keyword("INTO") - protected val IS = Keyword("IS") - protected val JOIN = Keyword("JOIN") - protected val LEFT = Keyword("LEFT") - protected val LIKE = Keyword("LIKE") - protected val LIMIT = Keyword("LIMIT") - protected val NOT = Keyword("NOT") - protected val NULL = Keyword("NULL") - protected val ON = Keyword("ON") - protected val OR = Keyword("OR") - protected val ORDER = Keyword("ORDER") - protected val SORT = Keyword("SORT") - protected val OUTER = Keyword("OUTER") - protected val OVERWRITE = Keyword("OVERWRITE") - protected val REGEXP = Keyword("REGEXP") - protected val RIGHT = Keyword("RIGHT") - protected val RLIKE = Keyword("RLIKE") - protected val SELECT = Keyword("SELECT") - protected val SEMI = Keyword("SEMI") - protected val TABLE = Keyword("TABLE") - protected val THEN = Keyword("THEN") - protected val TRUE = Keyword("TRUE") - protected val UNION = Keyword("UNION") - protected val WHEN = Keyword("WHEN") - protected val WHERE = Keyword("WHERE") - protected val WITH = Keyword("WITH") - - protected lazy val start: Parser[LogicalPlan] = - start1 | insert | cte - - protected lazy val start1: Parser[LogicalPlan] = - (select | ("(" ~> select <~ ")")) * - ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } - | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } - | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} - | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) - - protected lazy val select: Parser[LogicalPlan] = - SELECT ~> DISTINCT.? ~ - repsep(projection, ",") ~ - (FROM ~> relations).? ~ - (WHERE ~> expression).? ~ - (GROUP ~ BY ~> rep1sep(expression, ",")).? ~ - (HAVING ~> expression).? ~ - sortType.? ~ - (LIMIT ~> expression).? 
^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => - val base = r.getOrElse(OneRowRelation) - val withFilter = f.map(Filter(_, base)).getOrElse(base) - val withProjection = g - .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) - .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) - val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) - val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) - val withOrder = o.map(_(withHaving)).getOrElse(withHaving) - val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder) - withLimit - } - - protected lazy val insert: Parser[LogicalPlan] = - INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ { - case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false) - } - - protected lazy val cte: Parser[LogicalPlan] = - WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ { - case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap) - } - - protected lazy val projection: Parser[Expression] = - expression ~ (AS.? ~> ident.?) ^^ { - case e ~ a => a.fold(e)(Alias(e, _)()) - } - - // Based very loosely on the MySQL Grammar. - // http://dev.mysql.com/doc/refman/5.0/en/join.html - protected lazy val relations: Parser[LogicalPlan] = - ( relation ~ rep1("," ~> relation) ^^ { - case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } } - | relation - ) - - protected lazy val relation: Parser[LogicalPlan] = - joinedRelation | relationFactor - - protected lazy val relationFactor: Parser[LogicalPlan] = - ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ { - case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias) - } - | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) } - ) - - protected lazy val joinedRelation: Parser[LogicalPlan] = - relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?) ^^ { - case r1 ~ joins => - joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) => - Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond) - } - } - - protected lazy val joinConditions: Parser[Expression] = - ON ~> expression - - protected lazy val joinType: Parser[JoinType] = - ( INNER ^^^ Inner - | LEFT ~ SEMI ^^^ LeftSemi - | LEFT ~ OUTER.? ^^^ LeftOuter - | RIGHT ~ OUTER.? ^^^ RightOuter - | FULL ~ OUTER.? ^^^ FullOuter - ) - - protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] = - ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) } - | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) } - ) - - protected lazy val ordering: Parser[Seq[SortOrder]] = - ( rep1sep(expression ~ direction.?, ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) - } - ) - - protected lazy val direction: Parser[SortDirection] = - ( ASC ^^^ Ascending - | DESC ^^^ Descending - ) - - protected lazy val expression: Parser[Expression] = - orExpression - - protected lazy val orExpression: Parser[Expression] = - andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) }) - - protected lazy val andExpression: Parser[Expression] = - notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) }) - - protected lazy val notExpression: Parser[Expression] = - NOT.? 
~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) } - - protected lazy val comparisonExpression: Parser[Expression] = - ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) } - | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) } - | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) } - | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) } - | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) } - | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) } - | termExpression ~ ("<=>" ~> termExpression) ^^ { case e1 ~ e2 => EqualNullSafe(e1, e2) } - | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { - case e ~ not ~ el ~ eu => - val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f => Not(betweenExpr)) - } - | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } - | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) } - | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) } - | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => In(e1, e2) - } - | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ { - case e1 ~ e2 => Not(In(e1, e2)) - } - | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } - | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } - | termExpression - ) - - protected lazy val termExpression: Parser[Expression] = - productExpression * - ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) } - | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) } - ) - - protected lazy val productExpression: Parser[Expression] = - baseExpression * - ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) } - | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) } - | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) } - | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) } - | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) } - | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) } - ) - - protected lazy val function: Parser[Expression] = - ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) - } else { - throw new AnalysisException(s"invalid expression $udfName(*)") - } - } - | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ - { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } - | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => - lexical.normalizeKeyword(udfName) match { - case "count" => - aggregate.Count(exprs).toAggregateExpression(isDistinct = true) - case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) - } - } - | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, 
isDistinct = false) - } else { - throw new AnalysisException(s"invalid function approximate $udfName") - } - } - | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => - if (lexical.normalizeKeyword(udfName) == "count") { - AggregateExpression( - HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), - mode = Complete, - isDistinct = false) - } else { - throw new AnalysisException(s"invalid function approximate($s) $udfName") - } - } - | CASE ~> whenThenElse ^^ - { case branches => CaseWhen.createFromParser(branches) } - | CASE ~> expression ~ whenThenElse ^^ - { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) } - ) - - protected lazy val whenThenElse: Parser[List[Expression]] = - rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { - case altPart ~ elsePart => - altPart.flatMap { case whenExpr ~ thenExpr => - Seq(whenExpr, thenExpr) - } ++ elsePart - } - - protected lazy val cast: Parser[Expression] = - CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { - case exp ~ t => Cast(exp, t) - } - - protected lazy val literal: Parser[Literal] = - ( numericLiteral - | booleanLiteral - | stringLit ^^ { case s => Literal.create(s, StringType) } - | intervalLiteral - | NULL ^^^ Literal.create(null, NullType) - ) - - protected lazy val booleanLiteral: Parser[Literal] = - ( TRUE ^^^ Literal.create(true, BooleanType) - | FALSE ^^^ Literal.create(false, BooleanType) - ) - - protected lazy val numericLiteral: Parser[Literal] = - ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ - { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } - ) - - protected lazy val unsignedFloat: Parser[String] = - ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - ) - - protected lazy val sign: Parser[String] = ("+" | "-") - - protected lazy val integral: Parser[String] = - sign.? 
~ numericLit ^^ { case s ~ n => s.getOrElse("") + n } - - private def intervalUnit(unitName: String) = acceptIf { - case lexical.Identifier(str) => - val normalized = lexical.normalizeKeyword(str) - normalized == unitName || normalized == unitName + "s" - case _ => false - } {_ => "wrong interval unit"} - - protected lazy val month: Parser[Int] = - integral <~ intervalUnit("month") ^^ { case num => num.toInt } - - protected lazy val year: Parser[Int] = - integral <~ intervalUnit("year") ^^ { case num => num.toInt * 12 } - - protected lazy val microsecond: Parser[Long] = - integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong } - - protected lazy val millisecond: Parser[Long] = - integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MILLI - } - - protected lazy val second: Parser[Long] = - integral <~ intervalUnit("second") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_SECOND - } - - protected lazy val minute: Parser[Long] = - integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE - } - - protected lazy val hour: Parser[Long] = - integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_HOUR - } - - protected lazy val day: Parser[Long] = - integral <~ intervalUnit("day") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_DAY - } - - protected lazy val week: Parser[Long] = - integral <~ intervalUnit("week") ^^ { - case num => num.toLong * CalendarInterval.MICROS_PER_WEEK - } - - private def intervalKeyword(keyword: String) = acceptIf { - case lexical.Identifier(str) => - lexical.normalizeKeyword(str) == keyword - case _ => false - } {_ => "wrong interval keyword"} - - protected lazy val intervalLiteral: Parser[Literal] = - ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~ - intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromYearMonthString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~ - intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromDayTimeString(s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("year", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("month", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("day", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("hour", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("minute", s)) - } - | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s => - Literal(CalendarInterval.fromSingleUnitString("second", s)) - } - | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~ - millisecond.? ~ microsecond.? 
^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~ - millisecond ~ microsecond => - if (!Seq(year, month, week, day, hour, minute, second, - millisecond, microsecond).exists(_.isDefined)) { - throw new AnalysisException( - "at least one time unit should be given for interval literal") - } - val months = Seq(year, month).map(_.getOrElse(0)).sum - val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) - .map(_.getOrElse(0L)).sum - Literal(new CalendarInterval(months, microseconds)) - } - ) - - private def toNarrowestIntegerType(value: String): Any = { - val bigIntValue = BigDecimal(value) - - bigIntValue match { - case v if bigIntValue.isValidInt => v.toIntExact - case v if bigIntValue.isValidLong => v.toLongExact - case v => v.underlying() - } - } - - private def toDecimalOrDouble(value: String): Any = { - val decimal = BigDecimal(value) - // follow the behavior in MS SQL Server - // https://msdn.microsoft.com/en-us/library/ms179899.aspx - if (value.contains('E') || value.contains('e')) { - decimal.doubleValue() - } else { - decimal.underlying() - } - } - - protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ UnresolvedStar(None) - | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))} - | primary - ) - - protected lazy val signedPrimary: Parser[Expression] = - sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e } - - protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", { - case lexical.Identifier(str) => str - case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str - }) - - protected lazy val primary: PackratParser[Expression] = - ( literal - | expression ~ ("[" ~> expression <~ "]") ^^ - { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) } - | (expression <~ ".") ~ ident ^^ - { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) } - | cast - | "(" ~> expression <~ ")" - | function - | dotExpressionHeader - | signedPrimary - | "~" ~> expression ^^ BitwiseNot - | attributeName ^^ UnresolvedAttribute.quoted - ) - - protected lazy val dotExpressionHeader: Parser[Expression] = - (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) - } - - protected lazy val tableIdentifier: Parser[TableIdentifier] = - (ident <~ ".").? 
~ ident ^^ { - case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index e1fd22e36764e7b8575f6c16e1c4317c6cfa06e2..ec833d6789e85c48cdd8c36358ca474ced428115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -447,6 +447,7 @@ object HyperLogLogPlusPlus { private def validateDoubleLiteral(exp: Expression): Double = exp match { case Literal(d: Double, DoubleType) => d + case Literal(dec: Decimal, _) => dec.toDouble case _ => throw new AnalysisException("The second argument should be a double literal.") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala index ba9d2524a95513021d0f405b468928d97bdb0931..6d25de98cebc4773874dc281a35d8bd5459f368d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -108,6 +108,7 @@ class CatalystQlSuite extends PlanTest { } assertRight("9.0e1", 90) + assertRight(".9e+2", 90) assertRight("0.9e+2", 90) assertRight("900e-1", 90) assertRight("900.0E-1", 90) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala deleted file mode 100644 index b0884f528742fd9d9d2f67fa0fbbd233c8484e33..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst - -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not} -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project} -import org.apache.spark.unsafe.types.CalendarInterval - -private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} - -private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") - - override protected lazy val start: Parser[LogicalPlan] = set - - private lazy val set: Parser[LogicalPlan] = - EXECUTE ~> ident ^^ { - case fileName => TestCommand(fileName) - } -} - -class SqlParserSuite extends PlanTest { - - test("test long keyword") { - val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === - parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) - } - - test("test case insensitive") { - val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) - } - - test("test NOT operator with comparison operations") { - val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - test("support hive interval literal") { - def checkInterval(sql: String, result: CalendarInterval): Unit = { - val parsed = SqlParser.parse(sql) - val expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - def checkYearMonth(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' YEAR TO MONTH", - CalendarInterval.fromYearMonthString(lit)) - } - - def checkDayTime(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' DAY TO SECOND", - CalendarInterval.fromDayTimeString(lit)) - } - - def checkSingleUnit(lit: String, unit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' $unit", - CalendarInterval.fromSingleUnitString(unit, lit)) - } - - checkYearMonth("123-10") - checkYearMonth("496-0") - checkYearMonth("-2-3") - checkYearMonth("-123-0") - - checkDayTime("99 11:22:33.123456789") - checkDayTime("-99 11:22:33.123456789") - checkDayTime("10 9:8:7.123456789") - checkDayTime("1 0:0:0") - checkDayTime("-1 0:0:0") - checkDayTime("1 0:0:1") - - for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { - checkSingleUnit("7", unit) - checkSingleUnit("-7", unit) - checkSingleUnit("0", unit) - } - - checkSingleUnit("13.123456789", "second") - checkSingleUnit("-13.123456789", "second") - } - - test("support scientific notation") { - def assertRight(input: String, output: Double): Unit = { - val parsed = SqlParser.parse("SELECT " 
+ input) - val expected = Project( - UnresolvedAlias( - Literal(output) - ) :: Nil, - OneRowRelation) - comparePlans(parsed, expected) - } - - assertRight("9.0e1", 90) - assertRight(".9e+2", 90) - assertRight("0.9e+2", 90) - assertRight("900e-1", 90) - assertRight("900.0E-1", 90) - assertRight("9.e+1", 90) - - intercept[RuntimeException](SqlParser.parse("SELECT .e3")) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6a020f9f2883e5c174ccad06338880e8fba01622..97bf7a0cc4514d7eb5babf99711f28f1266fd011 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.SqlParser._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 91bf2f8ce4d2f6da44db9960991f99afeaacfd91..3422d0ead4fc12b57ef0604b82070594e7411b87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -737,7 +737,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(SqlParser.parseExpression(expr)) + Column(sqlContext.sqlParser.parseExpression(expr)) }: _*) } @@ -764,7 +764,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } /** @@ -788,7 +788,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(SqlParser.parseExpression(conditionExpr))) + filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index d948e4894253c9f1f838d2b54e9fec2d79bab993..8f852e521668a8d8d7592e9b2d038fc2e51048c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.SqlParser +import org.apache.spark.sql.catalyst.{CatalystQl} import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import 
org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation @@ -337,7 +337,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { */ def table(tableName: String): DataFrame = { DataFrame(sqlContext, - sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName))) + sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 00f9817b53976ae37ad74d59b5a06e6f43212a8f..ab63fe4aa88b732ac2f1939faad4f911372928aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} @@ -192,7 +192,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(SqlParser.parseTableIdentifier(tableName)) + insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -282,7 +282,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(SqlParser.parseTableIdentifier(tableName)) + saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b909765a7c6dd92531c4aa95504352bbe9bb3f00..a0939adb6d5ae256f386f2f2d5fa62f1de5bea3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ @@ -205,15 +206,17 @@ class SQLContext private[sql]( protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this) @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_)) + protected[sql] val ddlParser = new DDLParser(sqlParser) @transient - protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_)) + protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect()) protected[sql] def getSQLDialect(): ParserDialect = { try { val clazz = Utils.classForName(dialectClassName) - clazz.newInstance().asInstanceOf[ParserDialect] + 
clazz.getConstructor(classOf[ParserConf])
+        .newInstance(conf)
+        .asInstanceOf[ParserDialect]
     } catch {
       case NonFatal(e) =>
         // Since we didn't find the available SQL Dialect, it will fail even for SET command:
@@ -237,7 +240,7 @@ class SQLContext private[sql](
     new sparkexecution.QueryExecution(this, plan)
 
   protected[sql] def dialectClassName = if (conf.dialect == "sql") {
-    classOf[DefaultParserDialect].getCanonicalName
+    classOf[SparkQl].getCanonicalName
   } else {
     conf.dialect
   }
@@ -682,7 +685,7 @@ class SQLContext private[sql](
       tableName: String,
       source: String,
       options: Map[String, String]): DataFrame = {
-    val tableIdent = SqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sqlParser.parseTableIdentifier(tableName)
     val cmd =
       CreateTableUsing(
         tableIdent,
@@ -728,7 +731,7 @@ class SQLContext private[sql](
       source: String,
       schema: StructType,
       options: Map[String, String]): DataFrame = {
-    val tableIdent = SqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sqlParser.parseTableIdentifier(tableName)
     val cmd =
       CreateTableUsing(
         tableIdent,
@@ -833,7 +836,7 @@ class SQLContext private[sql](
    * @since 1.3.0
    */
   def table(tableName: String): DataFrame = {
-    table(SqlParser.parseTableIdentifier(tableName))
+    table(sqlParser.parseTableIdentifier(tableName))
   }
 
   private def table(tableIdent: TableIdentifier): DataFrame = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
index b3e8d0d84937e60530b19cc21642ae7ce9cbd999..1af2c756cdc5a0c26af454ee673466c387f05f0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution
 
 import scala.util.parsing.combinator.RegexParsers
 
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types.StringType
@@ -29,9 +29,16 @@ import org.apache.spark.sql.types.StringType
  * The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
  * dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
  *
- * @param fallback A function that parses an input string to a logical plan
+ * @param fallback A function that returns the next parser in the chain. This is a call-by-name
+ *                 parameter because this allows us to return a different dialect if we
+ *                 have to.
  */
-class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+class SparkSQLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser {
+
+  override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+  override def parseTableIdentifier(sql: String): TableIdentifier =
+    fallback.parseTableIdentifier(sql)
 
   // A parser for the key-value part of the "SET [key = [value ]]" syntax
   private object SetCommandParser extends RegexParsers {
@@ -74,7 +81,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
   private lazy val cache: Parser[LogicalPlan] =
     CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
       case isLazy ~ tableName ~ plan =>
-        CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+        CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined)
     }
 
   private lazy val uncache: Parser[LogicalPlan] =
@@ -111,7 +118,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
 
   private lazy val others: Parser[LogicalPlan] =
     wholeInput ^^ {
-      case input => fallback(input)
+      case input => fallback.parsePlan(input)
     }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
index d8d21b06b8b359e33c14c1054328fb0989d21079..10655a85ccf893e5d84ae29c504428c7cddc83d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
@@ -22,25 +22,30 @@ import scala.util.matching.Regex
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.DataTypeParser
 import org.apache.spark.sql.types._
 
-
 /**
  * A parser for foreign DDL commands.
  */
-class DDLParser(parseQuery: String => LogicalPlan)
+class DDLParser(fallback: => ParserDialect)
   extends AbstractSparkSQLParser with DataTypeParser with Logging {
 
+  override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+  override def parseTableIdentifier(sql: String): TableIdentifier =
+
+    fallback.parseTableIdentifier(sql)
   def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
     try {
-      parse(input)
+      parsePlan(input)
     } catch {
       case ddlException: DDLException => throw ddlException
-      case _ if !exceptionOnError => parseQuery(input)
+      case _ if !exceptionOnError => fallback.parsePlan(input)
       case x: Throwable => throw x
     }
   }
@@ -104,7 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
           SaveMode.ErrorIfExists
         }
 
-        val queryPlan = parseQuery(query.get)
+        val queryPlan = fallback.parsePlan(query.get)
         CreateTableUsingAsSelect(tableIdent,
           provider,
           temp.isDefined,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b8ea2261e94e263644bedd819be37fd1930eae76..8c2530fd684a744495a1b27fac087232c1cee871 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
 import scala.util.Try
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{CatalystQl, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -1063,7 +1063,10 @@ object functions extends LegacyFunctions {
    *
    * @group normal_funcs
    */
-  def expr(expr: String): Column = Column(SqlParser.parseExpression(expr))
+  def expr(expr: String): Column = {
+    val parser = SQLContext.getActive().map(_.getSQLDialect()).getOrElse(new CatalystQl())
+    Column(parser.parseExpression(expr))
+  }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Math Functions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 58f982c2bc932ebdb722dc039ee82f15dad382af..aec450e0a6084f03f04e2996421c612287031ba1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
     )
 
-    val pi = 3.1415
+    val pi = "3.1415BD"
     checkAnswer(
       sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
         s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03d67c4e91f7f8f983c14fea30902af74441def8..75e81b9c9174df0c998d6c4e82f96073c8b978ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,11 @@ import java.math.MathContext
 import java.sql.Timestamp
 
 import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.DefaultParserDialect
+import org.apache.spark.sql.catalyst.CatalystQl
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.{aggregate, SparkQl}
 import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -32,7 +33,7 @@ import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
 
 /** A SQL Dialect for testing purpose, and it can not be nested type */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends CatalystQl(conf)
 
 class SQLQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -161,7 +162,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       newContext.sql("SELECT 1")
     }
     // test if the dialect set back to DefaultSQLDialect
-    assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect])
+    assert(newContext.getSQLDialect().getClass === classOf[SparkQl])
   }
 
   test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
@@ -586,7 +587,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("Allow only a single WITH clause per query") {
-    intercept[RuntimeException] {
+    intercept[AnalysisException] {
      sql(
        "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2")
    }
@@ -602,8 +603,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("from follow multiple brackets") {
     checkAnswer(sql(
       """
-        |select key from ((select * from testData limit 1)
-        |  union all (select * from testData limit 1)) x limit 1
+        |select key from ((select * from testData)
+        |  union all (select * from testData)) x limit 1
       """.stripMargin),
       Row(1)
     )
@@ -616,7 +617,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(sql(
       """
         |select key from
-        |  (select * from testData limit 1 union all select * from testData limit 1) x
+        |  (select * from testData union all select * from testData) x
        |  limit 1
      """.stripMargin),
      Row(1)
@@ -649,13 +650,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("approximate count distinct") {
     checkAnswer(
-      sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+      sql("SELECT APPROX_COUNT_DISTINCT(a) FROM testData2"),
       Row(3))
   }
 
   test("approximate count distinct with user provided standard deviation") {
     checkAnswer(
-      sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+      sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2"),
       Row(3))
   }
 
@@ -1192,19 +1193,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("Floating point number format") {
     checkAnswer(
-      sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
+      sql("SELECT 0.3"), Row(0.3)
     )
 
     checkAnswer(
-      sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
+      sql("SELECT -0.8"), Row(-0.8)
     )
 
     checkAnswer(
-      sql("SELECT .5"), Row(BigDecimal(0.5))
+      sql("SELECT .5"), Row(0.5)
     )
 
     checkAnswer(
-      sql("SELECT -.18"), Row(BigDecimal(-0.18))
+      sql("SELECT -.18"), Row(-0.18)
     )
   }
 
@@ -1218,11 +1219,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     )
 
     checkAnswer(
-      sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
+      sql("SELECT 9223372036854775808BD"), Row(new java.math.BigDecimal("9223372036854775808"))
     )
 
     checkAnswer(
-      sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
+      sql("SELECT -9223372036854775809BD"), Row(new java.math.BigDecimal("-9223372036854775809"))
     )
   }
 
@@ -1237,11 +1238,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     )
 
     checkAnswer(
-      sql("SELECT -5.2"), Row(BigDecimal(-5.2))
+      sql("SELECT -5.2BD"), Row(BigDecimal(-5.2))
     )
 
     checkAnswer(
-      sql("SELECT +6.8"), Row(BigDecimal(6.8))
+      sql("SELECT +6.8"), Row(6.8d)
     )
 
     checkAnswer(
@@ -1616,20 +1617,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("decimal precision with multiply/division") {
-    checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90")))
-    checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
-    checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
-    checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
+    checkAnswer(sql("select 10.3BD * 3.0BD"), Row(BigDecimal("30.90")))
+    checkAnswer(sql("select 10.3000BD * 3.0BD"), Row(BigDecimal("30.90000")))
+    checkAnswer(sql("select 10.30000BD * 30.0BD"), Row(BigDecimal("309.000000")))
+    checkAnswer(sql("select 10.300000000000000000BD * 3.000000000000000000BD"),
       Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
-    checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
+    checkAnswer(sql("select 10.300000000000000000BD * 3.0000000000000000000BD"),
       Row(null))
-    checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
-    checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
-    checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
-    checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
+    checkAnswer(sql("select 10.3BD / 3.0BD"), Row(BigDecimal("3.433333")))
+    checkAnswer(sql("select 10.3000BD / 3.0BD"), Row(BigDecimal("3.4333333")))
+    checkAnswer(sql("select 10.30000BD / 30.0BD"), Row(BigDecimal("0.343333333")))
+    checkAnswer(sql("select 10.300000000000000000BD / 3.00000000000000000BD"),
       Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
-    checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
+    checkAnswer(sql("select 10.3000000000000000000BD / 3.00000000000000000BD"),
       Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
   }
 
@@ -1655,13 +1656,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("precision smaller than scale") {
-    checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00")))
-    checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00")))
-    checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10")))
-    checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01")))
-    checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001")))
-    checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01")))
-    checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
+    checkAnswer(sql("select 10.00BD"), Row(BigDecimal("10.00")))
+    checkAnswer(sql("select 1.00BD"), Row(BigDecimal("1.00")))
+    checkAnswer(sql("select 0.10BD"), Row(BigDecimal("0.10")))
+    checkAnswer(sql("select 0.01BD"), Row(BigDecimal("0.01")))
+    checkAnswer(sql("select 0.001BD"), Row(BigDecimal("0.001")))
+    checkAnswer(sql("select -0.01BD"), Row(BigDecimal("-0.01")))
+    checkAnswer(sql("select -0.001BD"), Row(BigDecimal("-0.001")))
   }
 
   test("external sorting updates peak execution memory") {
@@ -1750,7 +1751,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     assert(e1.message.contains("Table not found"))
 
     val e2 = intercept[AnalysisException] {
-      sql("select * from no_db.no_table")
+      sql("select * from no_db.no_table").show()
     }
     assert(e2.message.contains("Table not found"))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 860e07c68cef133f0ed53e4f31ec44bd72f97c97..e70eb2a060309eed9264062b29d28283a110ffec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -442,13 +442,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
     // Number and String conflict: resolve the type as number in this query.
     checkAnswer(
-      sql("select num_str + 1.2 from jsonTable where num_str > 14"),
+      sql("select num_str + 1.2BD from jsonTable where num_str > 14"),
       Row(BigDecimal("92233720368547758071.2"))
     )
 
     // Number and String conflict: resolve the type as number in this query.
     checkAnswer(
-      sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
+      sql("select num_str + 1.2BD from jsonTable where num_str >= 92233720368547758060BD"),
       Row(new java.math.BigDecimal("92233720368547758071.2"))
     )
 
@@ -856,7 +856,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
 
     checkAnswer(
-      sql("select map from jsonWithSimpleMap"),
+      sql("select `map` from jsonWithSimpleMap"),
       Row(Map("a" -> 1)) ::
       Row(Map("b" -> 2)) ::
       Row(Map("c" -> 3)) ::
@@ -865,7 +865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     )
 
     checkAnswer(
-      sql("select map['c'] from jsonWithSimpleMap"),
+      sql("select `map`['c'] from jsonWithSimpleMap"),
       Row(null) ::
       Row(null) ::
       Row(3) ::
@@ -884,7 +884,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
 
     checkAnswer(
-      sql("select map from jsonWithComplexMap"),
+      sql("select `map` from jsonWithComplexMap"),
       Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) ::
       Row(Map("b" -> Row(null, 2))) ::
       Row(Map("c" -> Row(Seq(), 4))) ::
@@ -894,7 +894,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     )
 
     checkAnswer(
-      sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"),
+      sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"),
       Row(Seq(1, 2, 3, null), null) ::
       Row(null, null) ::
       Row(null, 4) ::
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index afd2f611580fc855f861c664c84227e64c678d60..828ec9710550cc349835434fe547b9226cf0adb2 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -296,6 +296,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     // Odd changes to output
     "merge4",
 
+    // Unsupported underscore syntax.
+    "inputddl5",
+
     // Thift is broken...
     "inputddl8",
 
@@ -603,7 +606,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "inputddl2",
     "inputddl3",
     "inputddl4",
-    "inputddl5",
     "inputddl6",
     "inputddl7",
     "inputddl8",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index b22f4249813254ce37b5e8fc6ccce71df685814a..313ba18f6aef0f22e1561edfab30a9ad55449064 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -19,14 +19,23 @@ package org.apache.spark.sql.hive
 
 import scala.language.implicitConversions
 
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
 import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand}
 
 /**
  * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
  */
-private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
+private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser {
+
+  val parser = new HiveQl(sqlContext.conf)
+
+  override def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+
+  override def parseTableIdentifier(sql: String): TableIdentifier =
+    parser.parseTableIdentifier(sql)
+
   // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
   // properties via reflection the class in runtime for constructing the SqlLexical object
   protected val ADD = Keyword("ADD")
@@ -38,7 +47,10 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
 
   protected lazy val hiveQl: Parser[LogicalPlan] =
     restInput ^^ {
-      case statement => HiveQl.parsePlan(statement.trim)
+      case statement =>
+        sqlContext.executionHive.withHiveState {
+          parser.parsePlan(statement.trim)
+        }
     }
 
   protected lazy val dfs: Parser[LogicalPlan] =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cbaf00603e18908f0f4e9cc6cd08ba64cc7e8c23..7bdca522003055f532c5217e94f0c2944cc7d7d1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -42,7 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.sql._
 import org.apache.spark.sql.SQLConf.SQLConfEntry
 import org.apache.spark.sql.SQLConf.SQLConfEntry._
-import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser}
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -56,17 +56,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-/**
- * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
- */
-private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
-  override def parse(sqlText: String): LogicalPlan = {
-    sqlContext.executionHive.withHiveState {
-      HiveQl.parseSql(sqlText)
-    }
-  }
-}
-
 /**
  * Returns the current database of metadataHive.
  */
@@ -342,12 +331,12 @@ class HiveContext private[hive](
    * @since 1.3.0
    */
   def refreshTable(tableName: String): Unit = {
-    val tableIdent = SqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sqlParser.parseTableIdentifier(tableName)
     catalog.refreshTable(tableIdent)
   }
 
   protected[hive] def invalidateTable(tableName: String): Unit = {
-    val tableIdent = SqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sqlParser.parseTableIdentifier(tableName)
     catalog.invalidateTable(tableIdent)
   }
 
@@ -361,7 +350,7 @@ class HiveContext private[hive](
    * @since 1.2.0
    */
   def analyze(tableName: String) {
-    val tableIdent = SqlParser.parseTableIdentifier(tableName)
+    val tableIdent = sqlParser.parseTableIdentifier(tableName)
     val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent))
 
     relation match {
@@ -559,7 +548,7 @@ class HiveContext private[hive](
 
   protected[sql] override def getSQLDialect(): ParserDialect = {
     if (conf.dialect == "hiveql") {
-      new HiveQLDialect(this)
+      new ExtendedHiveQlParser(this)
     } else {
       super.getSQLDialect()
     }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index daaa5a5709bdc68b7fff4c1179f685542aeea7ee..3d54048c24782c33d0f972bdaa3b53f1be7b9b17 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -416,8 +416,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
       alias match {
         // because hive use things like `_c0` to build the expanded text
         // currently we cannot support view from "create view v1(c1) as ..."
-        case None => Subquery(table.name, HiveQl.parsePlan(viewText))
-        case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
+        case None => Subquery(table.name, hive.parseSql(viewText))
+        case Some(aliasText) => Subquery(aliasText, hive.parseSql(viewText))
       }
     } else {
       MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca9ddf94c11a70be6dbefbe19d297f2b85bcfe1e..46246f8191db10aed0b1e28699197b8523362020 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -79,7 +79,7 @@ private[hive] case class CreateViewAsSelect(
 }
 
 /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
-private[hive] object HiveQl extends SparkQl with Logging {
+private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging {
   protected val nativeCommands = Seq(
     "TOK_ALTERDATABASE_OWNER",
     "TOK_ALTERDATABASE_PROPERTIES",
@@ -168,8 +168,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
     "TOK_TRUNCATETABLE"     // truncate table" is a NativeCommand, does not need to explain.
   ) ++ nativeCommands
 
-  protected val hqlParser = new ExtendedHiveQlParser
-
   /**
    * Returns the HiveConf
    */
@@ -186,9 +184,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
     ss.getConf
   }
 
-  /** Returns a LogicalPlan for a given HiveQL string. */
-  def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql)
-
   protected def getProperties(node: ASTNode): Seq[(String, String)] = node match {
     case Token("TOK_TABLEPROPLIST", list) =>
       list.map {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 53d15c14cb3d5ba562fcce9a3c5418bffcaa663b..137dadd6c6bb365d8e2f28f9bb30f68325735153 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -23,12 +23,15 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.JsonTuple
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
 import org.apache.spark.sql.catalyst.plans.logical.Generate
 import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable}
 
 class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
+  val parser = new HiveQl(SimpleParserConf())
+
   private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
-    HiveQl.parsePlan(sql).collect {
+    parser.parsePlan(sql).collect {
       case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
     }.head
   }
@@ -173,7 +176,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
   test("Invalid interval term should throw AnalysisException") {
     def assertError(sql: String, errorMessage: String): Unit = {
       val e = intercept[AnalysisException] {
-        HiveQl.parseSql(sql)
+        parser.parsePlan(sql)
       }
       assert(e.getMessage.contains(errorMessage))
     }
@@ -186,7 +189,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
   }
 
   test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
-    val plan = HiveQl.parseSql(
+    val plan = parser.parsePlan(
       """
         |SELECT *
        |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 78f74cdc19ddbfb2d2539fbfaa7e86b9eab6ef7d..91bedf9c5af5af93f4320a3c850333cf444312dc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.{QueryTest, Row, SQLConf}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -28,9 +29,11 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
 class StatisticsSuite extends QueryTest with TestHiveSingleton {
   import hiveContext.sql
 
+  val parser = new HiveQl(SimpleParserConf())
+
   test("parse analyze commands") {
     def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
-      val parsed = HiveQl.parseSql(analyzeCommand)
+      val parsed = parser.parsePlan(analyzeCommand)
       val operators = parsed.collect {
         case a: AnalyzeTable => a
         case o => o
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index f6c687aab7a1beb2826691631eeb3dfe804835f0..61d5aa7ae6b31da6b2c8bdbf291bd440f17fbbc9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -22,12 +22,14 @@ import java.sql.{Date, Timestamp}
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry}
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.SparkQl
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
-import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation}
+import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
@@ -56,7 +58,7 @@ case class WindowData(
     area: String,
     product: Int)
 
 /** A SQL Dialect for testing purpose, and it can not be nested type */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends HiveQl(conf)
 
 /**
  * A collection of hive query tests where we generate the answers ourselves instead of depending on
@@ -339,20 +341,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     val hiveContext = new HiveContext(sqlContext.sparkContext)
     val dialectConf = "spark.sql.dialect"
     checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql"))
-    assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect])
+    assert(hiveContext.getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
   }
 
   test("SQL Dialect Switching") {
-    assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+    assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
     setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
     assert(getSQLDialect().getClass === classOf[MyDialect])
     assert(sql("SELECT 1").collect() === Array(Row(1)))
 
     // set the dialect back to the DefaultSQLDialect
     sql("SET spark.sql.dialect=sql")
-    assert(getSQLDialect().getClass === classOf[DefaultParserDialect])
+    assert(getSQLDialect().getClass === classOf[SparkQl])
     sql("SET spark.sql.dialect=hiveql")
-    assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+    assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
 
     // set invalid dialect
     sql("SET spark.sql.dialect.abc=MyTestClass")
@@ -361,14 +363,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       sql("SELECT 1")
     }
     // test if the dialect set back to HiveQLDialect
-    getSQLDialect().getClass === classOf[HiveQLDialect]
+    getSQLDialect().getClass === classOf[ExtendedHiveQlParser]
 
     sql("SET spark.sql.dialect=MyTestClass")
     intercept[DialectException] {
       sql("SELECT 1")
     }
     // test if the dialect set back to HiveQLDialect
-    assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+    assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
   }
 
   test("CTAS with serde") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 30e1758076361cf451119324d3ecf188faba87a3..62edf6c64bbc77158020d8dcce30752444ae00c4 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -188,6 +188,11 @@ public final class CalendarInterval implements Serializable {
           Integer.MIN_VALUE, Integer.MAX_VALUE);
         result = new CalendarInterval(month, 0L);
 
+      } else if (unit.equals("week")) {
+        long week = toLongWithRange("week", m.group(1),
+          Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
+        result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+
       } else if (unit.equals("day")) {
         long day = toLongWithRange("day", m.group(1),
           Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
@@ -206,6 +211,15 @@ public final class CalendarInterval implements Serializable {
       } else if (unit.equals("second")) {
         long micros = parseSecondNano(m.group(1));
         result = new CalendarInterval(0, micros);
+
+      } else if (unit.equals("millisecond")) {
+        long millisecond = toLongWithRange("millisecond", m.group(1),
+          Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
+        result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+
+      } else if (unit.equals("microsecond")) {
+        long micros = Long.valueOf(m.group(1));
+        result = new CalendarInterval(0, micros);
       }
     } catch (Exception e) {
       throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);