diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index afecf881c7440e0293cb3f27423d24abade934c1..5eb5b0d176fc1a81b8cdf8049aafe4b7db2e8fce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -67,6 +67,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
       (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
         s"$DoublePrefixCmp.computePrefix((double)$input)")
     case StringType => (0L, s"$input.getPrefix()")
+    case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+      val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+        s"$input.toUnscaledLong()"
+      } else {
+        // reduce the scale to fit in a long
+        val p = Decimal.MAX_LONG_DIGITS
+        val s = p - (dt.precision - dt.scale)
+        s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
+      }
+      (Long.MinValue, prefix)
+    case dt: DecimalType =>
+      (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+        s"$DoublePrefixCmp.computePrefix($input.toDouble())")
     case _ => (0L, "0L")
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 81267dc915c102d285c63d246fdd31405f2185e7..ea1fd23d0dbce2976606a2385746caa3a89f20ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -107,7 +107,10 @@ object RandomDataGenerator {
     case DateType => Some(() => new java.sql.Date(rand.nextInt()))
     case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
     case DecimalType.Fixed(precision, scale) => Some(
-      () => BigDecimal.apply(rand.nextLong(), rand.nextInt(), new MathContext(precision)))
+      () => BigDecimal.apply(
+        rand.nextLong() % math.pow(10, precision).toLong,
+        scale,
+        new MathContext(precision)))
     case DoubleType => randomNumeric[Double](
       rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
       Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
index 0ee9ddac815b81e97952f16ccec2d5cdf5683feb..417df006ab7c2778c3c70245b132c065175cdaaf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
@@ -34,8 +34,9 @@ object DataTypeTestUtils {
   * decimal types.
   */
  val fractionalTypes: Set[FractionalType] = Set(
+    DecimalType.USER_DEFAULT,
+    DecimalType(20, 5),
     DecimalType.SYSTEM_DEFAULT,
-    DecimalType(2, 1),
     DoubleType,
     FloatType
   )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 676656518f4c0491de8491bb905ff2ad0e1c3581..2e870ec8ae965597f1026e11df95d46bc4eeea90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -36,16 +36,16 @@ object SortPrefixUtils {
 
   def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
     sortOrder.dataType match {
-      case StringType if sortOrder.isAscending => PrefixComparators.STRING
-      case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
-        if sortOrder.isAscending =>
-        PrefixComparators.LONG
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType
-        if !sortOrder.isAscending =>
-        PrefixComparators.LONG_DESC
-      case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE
-      case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC
+      case StringType =>
+        if (sortOrder.isAscending) PrefixComparators.STRING else PrefixComparators.STRING_DESC
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | TimestampType =>
+        if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+      case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+        if (sortOrder.isAscending) PrefixComparators.LONG else PrefixComparators.LONG_DESC
+      case FloatType | DoubleType =>
+        if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
+      case dt: DecimalType =>
+        if (sortOrder.isAscending) PrefixComparators.DOUBLE else PrefixComparators.DOUBLE_DESC
       case _ => NoOpPrefixComparator
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index b3f821e0cdd3794577879fcb6948ca4fd77d5336..c7949848513cf404d488dd8ac69a453e751ef612 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -61,8 +61,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   // Test sorting on different data types
   for (
-    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
-      if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
     nullable <- Seq(true, false);
     sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
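Note on the SortOrder.scala hunk above: for a DecimalType whose integral part has at most Decimal.MAX_LONG_DIGITS (18) digits, the generated prefix is the value's unscaled long after the scale has been reduced so the whole number fits in 18 digits; wider decimals fall back to a double-based prefix. The sketch below illustrates that mapping outside of codegen. It is a minimal, standalone approximation using scala.math.BigDecimal rather than Spark's Decimal class; MaxLongDigits, prefixFor, the FLOOR rounding, and the sample values are illustrative assumptions, not part of the patch.

object DecimalPrefixSketch {
  // Mirrors Decimal.MAX_LONG_DIGITS: the most digits whose unscaled value
  // is guaranteed to fit in a signed 64-bit long.
  val MaxLongDigits = 18

  // Rescale a DecimalType(precision, scale) value so its unscaled value fits
  // in a long, then use that long as the sort prefix. Any monotone rounding
  // works here, because equal prefixes fall back to a full record comparison.
  def prefixFor(value: BigDecimal, precision: Int, scale: Int): Long = {
    require(precision - scale <= MaxLongDigits, "integral part too wide for a long prefix")
    val targetScale =
      if (precision <= MaxLongDigits) scale
      else MaxLongDigits - (precision - scale)  // same arithmetic as the generated code
    value.setScale(targetScale, BigDecimal.RoundingMode.FLOOR)
      .bigDecimal.unscaledValue.longValueExact
  }

  def main(args: Array[String]): Unit = {
    // DecimalType(20, 5): targetScale = 18 - (20 - 5) = 3, so 12345.67890 -> 12345678
    val values = Seq("12345.67890", "-0.00042", "99999999999999.99999").map(BigDecimal(_))
    val prefixes = values.map(prefixFor(_, precision = 20, scale = 5))
    // Ordering by the long prefixes agrees with ordering by the decimals themselves.
    assert(prefixes.sorted == values.sorted.map(prefixFor(_, precision = 20, scale = 5)))
    println(prefixes)  // List(12345678, -1, 99999999999999999)
  }
}

Because all values of a given DecimalType end up at the same target scale, their unscaled longs order exactly like the decimals themselves, which is why the SortPrefixUtils.scala hunk can reuse PrefixComparators.LONG / LONG_DESC for such columns and only falls back to the double comparator when precision - scale exceeds Decimal.MAX_LONG_DIGITS.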