diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 9bea990fcee4e8d0e13d6d2f4eccb3fbe4118df7..4e967713ede64c0267105042c887400022344a6c 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -303,33 +303,73 @@ class SqlParser extends AbstractSparkSQLParser {
     CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
 
   protected lazy val literal: Parser[Literal] =
-    ( numericLit ^^ {
-        case i if i.toLong > Int.MaxValue => Literal(i.toLong)
-        case i => Literal(i.toInt)
-      }
-    | NULL ^^^ Literal(null, NullType)
-    | floatLit ^^ {case f => Literal(f.toDouble) }
+    ( numericLiteral
+    | booleanLiteral
     | stringLit ^^ {case s => Literal(s, StringType) }
+    | NULL ^^^ Literal(null, NullType)
+    )
+
+  protected lazy val booleanLiteral: Parser[Literal] =
+    ( TRUE ^^^ Literal(true, BooleanType)
+    | FALSE ^^^ Literal(false, BooleanType)
+    )
+
+  protected lazy val numericLiteral: Parser[Literal] =
+    signedNumericLiteral | unsignedNumericLiteral
+
+  protected lazy val sign: Parser[String] =
+    "+" | "-"
+
+  protected lazy val signedNumericLiteral: Parser[Literal] =
+    ( sign ~ numericLit ^^ { case s ~ l => Literal(toNarrowestIntegerType(s + l)) }
+    | sign ~ floatLit ^^ { case s ~ f => Literal((s + f).toDouble) }
+    )
+
+  protected lazy val unsignedNumericLiteral: Parser[Literal] =
+    ( numericLit ^^ { n => Literal(toNarrowestIntegerType(n)) }
+    | floatLit ^^ { f => Literal(f.toDouble) }
     )
 
+  private val longMax = BigDecimal(s"${Long.MaxValue}")
+  private val longMin = BigDecimal(s"${Long.MinValue}")
+  private val intMax = BigDecimal(s"${Int.MaxValue}")
+  private val intMin = BigDecimal(s"${Int.MinValue}")
+
+  private def toNarrowestIntegerType(value: String) = {
+    val bigIntValue = BigDecimal(value)
+
+    bigIntValue match {
+      case v if v < longMin || v > longMax => v
+      case v if v < intMin || v > intMax => v.toLong
+      case v => v.toInt
+    }
+  }
+
   protected lazy val floatLit: Parser[String] =
-    elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
+    ( "." ~> unsignedNumericLiteral ^^ { u => "0." + u }
+    | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
+    )
+
+  protected lazy val baseExpression: Parser[Expression] =
+    ( "*" ^^^ Star(None)
+    | primary
+    )
 
-  protected lazy val baseExpression: PackratParser[Expression] =
-    ( expression ~ ("[" ~> expression <~ "]") ^^
+  protected lazy val signedPrimary: Parser[Expression] =
+    sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e}
+
+  protected lazy val primary: PackratParser[Expression] =
+    ( literal
+    | expression ~ ("[" ~> expression <~ "]") ^^
         { case base ~ ordinal => GetItem(base, ordinal) }
     | (expression <~ ".") ~ ident ^^
         { case base ~ fieldName => GetField(base, fieldName) }
-    | TRUE ^^^ Literal(true, BooleanType)
-    | FALSE ^^^ Literal(false, BooleanType)
     | cast
     | "(" ~> expression <~ ")"
     | function
-    | "-" ~> literal ^^ UnaryMinus
     | dotExpressionHeader
     | ident ^^ UnresolvedAttribute
-    | "*" ^^^ Star(None)
-    | literal
+    | signedPrimary
     )
 
   protected lazy val dotExpressionHeader: Parser[Expression] =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ffb504b94992f4124c503b5fe79bee500f5615a4..12e1cfc1cb7eb651391c344bc500099120a67d9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -738,6 +738,135 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
   }
 
+  test("Test to check we can use Long.MinValue") {
+    checkAnswer(
+      sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Long.MinValue
+    )
+
+    checkAnswer(
+      sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), (1 to 100).map(Row(_)).toSeq
+    )
+  }
+
+  test("Floating point number format") {
+    checkAnswer(
+      sql("SELECT 0.3"), 0.3
+    )
+
+    checkAnswer(
+      sql("SELECT -0.8"), -0.8
+    )
+
+    checkAnswer(
+      sql("SELECT .5"), 0.5
+    )
+
+    checkAnswer(
+      sql("SELECT -.18"), -0.18
+    )
+  }
+
+  test("Auto cast integer type") {
+    checkAnswer(
+      sql(s"SELECT ${Int.MaxValue + 1L}"), Int.MaxValue + 1L
+    )
+
+    checkAnswer(
+      sql(s"SELECT ${Int.MinValue - 1L}"), Int.MinValue - 1L
+    )
+
+    checkAnswer(
+      sql("SELECT 9223372036854775808"), BigDecimal("9223372036854775808")
+    )
+
+    checkAnswer(
+      sql("SELECT -9223372036854775809"), BigDecimal("-9223372036854775809")
+    )
+  }
+
+  test("Test to check we can apply sign to expression") {
+
+    checkAnswer(
+      sql("SELECT -100"), -100
+    )
+
+    checkAnswer(
+      sql("SELECT +230"), 230
+    )
+
+    checkAnswer(
+      sql("SELECT -5.2"), -5.2
+    )
+
+    checkAnswer(
+      sql("SELECT +6.8"), 6.8
+    )
+
+    checkAnswer(
+      sql("SELECT -key FROM testData WHERE key = 2"), -2
+    )
+
+    checkAnswer(
+      sql("SELECT +key FROM testData WHERE key = 3"), 3
+    )
+
+    checkAnswer(
+      sql("SELECT -(key + 1) FROM testData WHERE key = 1"), -2
+    )
+
+    checkAnswer(
+      sql("SELECT - key + 1 FROM testData WHERE key = 10"), -9
+    )
+
+    checkAnswer(
+      sql("SELECT +(key + 5) FROM testData WHERE key = 5"), 10
+    )
+
+    checkAnswer(
+      sql("SELECT -MAX(key) FROM testData"), -100
+    )
+
+    checkAnswer(
+      sql("SELECT +MAX(key) FROM testData"), 100
+    )
+
+    checkAnswer(
+      sql("SELECT - (-10)"), 10
+    )
+
+    checkAnswer(
+      sql("SELECT + (-key) FROM testData WHERE key = 32"), -32
+    )
+
+    checkAnswer(
+      sql("SELECT - (+Max(key)) FROM testData"), -100
+    )
+
+    checkAnswer(
+      sql("SELECT - - 3"), 3
+    )
+
+    checkAnswer(
+      sql("SELECT - + 20"), -20
+    )
+
+    checkAnswer(
+      sql("SELEcT - + 45"), -45
+    )
+
+    checkAnswer(
+      sql("SELECT + + 100"), 100
+    )
+
+    checkAnswer(
+      sql("SELECT - - Max(key) FROM testData"), 100
+    )
+
+    checkAnswer(
+      sql("SELECT + - key FROM testData WHERE key = 33"), -33
+    )
+  }
+
   test("Multiple join") {
     checkAnswer(
       sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 41927e83698a5bf952b250c687bc480b833ff667..1ae75546aada14b8b11705c0e018bdf30a42b4fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -380,6 +380,12 @@ class JsonSuite extends QueryTest {
       92233720368547758071.2
     )
 
+    // Number and String conflict: resolve the type as number in this query.
+    checkAnswer(
+      sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
+      BigDecimal("92233720368547758061.2").toDouble
+    )
+
     // String and Boolean conflict: resolve the type as string.
     checkAnswer(
       sql("select * from jsonTable where str_bool = 'str1'"),
@@ -415,13 +421,6 @@ class JsonSuite extends QueryTest {
       false
     )
 
-    // Right now, we have a parsing error.
-    // Number and String conflict: resolve the type as number in this query.
-    checkAnswer(
-      sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"),
-      BigDecimal("92233720368547758061.2")
-    )
-
    // The plan of the following DSL is
    // Project [(CAST(num_str#65:4, DoubleType) + 1.2) AS num#78]
    // Filter (CAST(CAST(num_str#65:4, DoubleType), DecimalType) > 92233720368547758060)
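Reviewer note: below is a minimal, self-contained sketch of the narrowing rule that toNarrowestIntegerType implements in the parser change above. It is not part of the patch; the object name NarrowestTypeDemo and its asserts are invented for illustration.

object NarrowestTypeDemo extends App {
  // Bounds mirrored from SqlParser: compare as BigDecimal so values beyond
  // Long range are handled without overflow.
  private val longMax = BigDecimal(Long.MaxValue)
  private val longMin = BigDecimal(Long.MinValue)
  private val intMax = BigDecimal(Int.MaxValue)
  private val intMin = BigDecimal(Int.MinValue)

  // Pick the narrowest integer representation: Int if the value fits,
  // else Long, else keep the full-precision BigDecimal.
  def toNarrowestIntegerType(value: String): Any = {
    val v = BigDecimal(value)
    if (v < longMin || v > longMax) v
    else if (v < intMin || v > intMax) v.toLong
    else v.toInt
  }

  assert(toNarrowestIntegerType("100") == 100)                // fits in Int
  assert(toNarrowestIntegerType("2147483648") == 2147483648L) // Int.MaxValue + 1 -> Long
  assert(toNarrowestIntegerType("9223372036854775808") ==
    BigDecimal("9223372036854775808"))                        // beyond Long -> BigDecimal
}

Note that signedNumericLiteral concatenates the sign with the digits before narrowing, so "-9223372036854775808" (Long.MinValue) narrows to a Long even though the unsigned digits alone overflow Long; the "Test to check we can use Long.MinValue" test relies on this.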