From 16b928c5436b9b500d25b49bf3670bc50ddafbf9 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Sat, 1 Aug 2015 23:36:06 -0700
Subject: [PATCH] [SPARK-9529] [SQL] improve TungstenSort on DecimalType

Generate prefix for DecimalType, fix the random generator of decimal

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #7857 from davies/sort_decimal and squashes the following commits:

2433959 [Davies Liu] Merge branch 'master' of github.com:apache/spark into sort_decimal
de24253 [Davies Liu] fix style
0a54c1a [Davies Liu] sort decimal
---
 .../sql/catalyst/expressions/SortOrder.scala  | 13 ++++++++++++
 .../spark/sql/RandomDataGenerator.scala       |  5 ++++-
 .../spark/sql/types/DataTypeTestUtils.scala   |  3 ++-
 .../spark/sql/execution/SortPrefixUtils.scala | 20 +++++++++----------
 .../sql/execution/TungstenSortSuite.scala     |  3 +--
 5 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index afecf881c7..5eb5b0d176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -67,6 +67,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
         (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
           s"$DoublePrefixCmp.computePrefix((double)$input)")
       case StringType => (0L, s"$input.getPrefix()")
+      case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+        val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+          s"$input.toUnscaledLong()"
+        } else {
+          // reduce the scale to fit in a long
+          val p = Decimal.MAX_LONG_DIGITS
+          val s = p - (dt.precision - dt.scale)
+          s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
+        }
+        (Long.MinValue, prefix)
+      case dt: DecimalType =>
+        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+          s"$DoublePrefixCmp.computePrefix($input.toDouble())")
       case _ => (0L, "0L")
     }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 81267dc915..ea1fd23d0d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -107,7 +107,10 @@ object RandomDataGenerator {
       case DateType => Some(() => new java.sql.Date(rand.nextInt()))
       case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
       case DecimalType.Fixed(precision, scale) => Some(
-        () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new MathContext(precision)))
+        () => BigDecimal.apply(
+          rand.nextLong() % math.pow(10, precision).toLong,
+          scale,
+          new MathContext(precision)))
       case DoubleType => randomNumeric[Double](
         rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
           Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index 0ee9ddac81..417df006ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -34,8 +34,9 @@ object DataTypeTestUtils {
    * decimal types.
    */
   val fractionalTypes: Set[FractionalType] = Set(
+    DecimalType.USER_DEFAULT,
+    DecimalType(20, 5),
     DecimalType.SYSTEM_DEFAULT,
-    DecimalType(2, 1),
     DoubleType,
     FloatType
   )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 676656518f..2e870ec8ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -36,16 +36,16 @@ object SortPrefixUtils {
 
   def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
     sortOrder.dataType match {
-      case StringType if sortOrder.isAscending => PrefixComparators.STRING
-      case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
-          if sortOrder.isAscending =>
-        PrefixComparators.LONG
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
-        if !sortOrder.isAscending =>
-        PrefixComparators.LONG_DESC
-      case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
-      case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
+      case StringType =>
+        if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
+        if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+      case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+        if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+      case FloatType | DoubleType =>
+        if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
+      case dt: DecimalType =>
+        if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
       case _ => NoOpPrefixComparator
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index b3f821e0cd..c794984851 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -61,8 +61,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   // Test sorting on different data types
   for (
-    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
-    if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
     nullable <- Seq(true, false);
     sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
-- 
GitLab