diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b423f0fa04f6978b09dadd5f05aba6136555b239..e5f115f74bf3b96974f86caa77d7c6ff60d42e9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } | sign.? ~ unsignedFloat ^^ { - // TODO(davies): some precisions may loss, we should create decimal literal - case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue()) + case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) @@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + private def toDecimalOrDouble(value: String): Any = { + val decimal = BigDecimal(value) + // follow the behavior in MS SQL Server + // https://msdn.microsoft.com/en-us/library/ms179899.aspx + if (value.contains('E') || value.contains('e')) { + decimal.doubleValue() + } else { + decimal.underlying() + } + } + protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e0527503442f04d2c897e9cf8684594a069ca5fa..ecc48986e35d8d3de3c59c33fae76ab7c169742a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -109,13 +109,35 @@ object HiveTypeCoercion { * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. */ - private def findTightestCommonType(types: Seq[DataType]) = { + private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } + private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(t1, t2) + } + + private def findWiderCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeForTwo(d, c) + case None => None + }) + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -182,20 +204,7 @@ object HiveTypeCoercion { val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (t: FractionalType, d: DecimalType) => - Some(DoubleType) - case (d: DecimalType, t: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(lhs.dataType, rhs.dataType) - } + findWiderTypeForTwo(lhs.dataType, rhs.dataType) case other => None } @@ -236,8 +245,13 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), r) => - a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => + a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) + case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) + + case a @ BinaryArithmetic(left @ StringType(), right) => + a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) @@ -543,7 +557,7 @@ object HiveTypeCoercion { // compatible with every child column. case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 4589facb49b766973528c2854c0193012ba30fbb..221b4e92f086cce0d6f431b9234eb1c10e037d74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -145,11 +145,11 @@ class AnalysisSuite extends AnalysisTest { 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList - // StringType will be promoted into Double assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DoubleType) + // StringType will be promoted into Decimal(38, 18) + assert(pl(3).dataType == DecimalType(38, 29)) assert(pl(4).dataType == DoubleType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 21256704a5b165190047bbe086ed9734cc689c27..8cf2ef5957d8d72b9edd130c705babd20925c0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), - Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 42724ed766af56eb34e97f4bf33c560742cabe58..d13dde1cdc8b21b576d6f1f0cde2aeb8df4c0a7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -368,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -1234,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1279,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3ac312d6f4c506de8bd9af7f9fe73753991b9404..f19f22fca7d54c7b65b67b489d0df99026d5b0a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData { Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) - // Widening to DoubleType + // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), - Row(21474836472.2) :: - Row(92233720368547758071.3) :: Nil + Row(BigDecimal("21474836472.2")) :: + Row(BigDecimal("92233720368547758071.3")) :: Nil ) - // Widening to DoubleType + // Widening to Double checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil @@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) + Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) + Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. @@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) }