diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088d1d665fe11a5a49ab44c166d30aaa35de..8dd4f2c59243e6db263a82f1e37d64435f108207 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -57,7 +57,7 @@ case class Percentile(
     child: Expression,
     percentageExpression: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, 0, 0)
@@ -123,13 +123,18 @@ case class Percentile(
     }
   }
 
-  override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+  private def toDoubleValue(d: Any): Double = d match {
+    case d: Decimal => d.toDouble
+    case n: Number => n.doubleValue
+  }
+
+  override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
     // Initialize new counts map instance here.
-    new OpenHashMap[Number, Long]()
+    new OpenHashMap[AnyRef, Long]()
   }
 
-  override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
-    val key = child.eval(input).asInstanceOf[Number]
+  override def update(buffer: OpenHashMap[AnyRef, Long], input: InternalRow): Unit = {
+    val key = child.eval(input).asInstanceOf[AnyRef]
 
     // Null values are ignored in counts map.
     if (key != null) {
@@ -137,30 +142,30 @@ case class Percentile(
     }
   }
 
-  override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+  override def merge(buffer: OpenHashMap[AnyRef, Long], other: OpenHashMap[AnyRef, Long]): Unit = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
   }
 
-  override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+  override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
     generateOutput(getPercentiles(buffer))
   }
 
-  private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+  private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
     if (buffer.isEmpty) {
       return Seq.empty
     }
 
     val sortedCounts = buffer.toSeq.sortBy(_._1)(
-      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+      child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
     val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
     }.tail
     val maxPosition = accumlatedCounts.last._2 - 1
 
     percentages.map { percentile =>
-      getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+      getPercentile(accumlatedCounts, maxPosition * percentile)
     }
   }
 
@@ -180,7 +185,7 @@ case class Percentile(
    * This function has been based upon similar function from HIVE
    * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
    */
-  private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+  private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
     // We may need to do linear interpolation to get the exact percentile
     val lower = position.floor.toLong
     val higher = position.ceil.toLong
@@ -193,18 +198,17 @@ case class Percentile(
     val lowerKey = aggreCounts(lowerIndex)._1
     if (higher == lower) {
       // no interpolation needed because position does not have a fraction
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     val higherKey = aggreCounts(higherIndex)._1
     if (higherKey == lowerKey) {
       // no interpolation needed because lower position and higher position has the same key
-      return lowerKey
+      return toDoubleValue(lowerKey)
     }
 
     // Linear interpolation to get the exact percentile
-    return (higher - position) * lowerKey.doubleValue() +
-      (position - lower) * higherKey.doubleValue()
+    (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
   }
 
   /**
@@ -218,7 +222,7 @@ case class Percentile(
     }
   }
 
-  override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+  override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
     val buffer = new Array[Byte](4 << 10)  // 4K
     val bos = new ByteArrayOutputStream()
     val out = new DataOutputStream(bos)
@@ -241,11 +245,11 @@ case class Percentile(
     }
   }
 
-  override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+  override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
     val bis = new ByteArrayInputStream(bytes)
     val ins = new DataInputStream(bis)
     try {
-      val counts = new OpenHashMap[Number, Long]
+      val counts = new OpenHashMap[AnyRef, Long]
       // Read unsafeRow size and content in bytes.
       var sizeOfNextRow = ins.readInt()
       while (sizeOfNextRow >= 0) {
@@ -254,7 +258,7 @@ case class Percentile(
         val row = new UnsafeRow(2)
         row.pointTo(bs, sizeOfNextRow)
         // Insert the pairs into counts map.
-        val key = row.get(0, child.dataType).asInstanceOf[Number]
+        val key = row.get(0, child.dataType)
         val count = row.get(1, LongType).asInstanceOf[Long]
         counts.update(key, count)
         sizeOfNextRow = ins.readInt()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc18426ab37d51c102a212b30b9d3292491..d7c25271f3567881ede19c7be5e13ec928a1d9c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -38,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
     val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
 
     // Check empty serialize and deserialize
-    val buffer = new OpenHashMap[Number, Long]()
+    val buffer = new OpenHashMap[AnyRef, Long]()
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
 
     // Check non-empty buffer serializa and deserialize.
     data.foreach { key =>
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(new Integer(key), 1L, _ + 1L)
     }
     assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
   }
@@ -233,7 +233,7 @@ class PercentileSuite extends SparkFunSuite {
   }
 
   private def compareEquals(
-      left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+      left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
     left.size == right.size && left.forall { case (key, count) =>
       right.apply(key) == count
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 312cd17c26d609b7392d8b05908ff68039166c6a..22dfc46acfc0f9277eba2565c133797dc34dbf61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1734,4 +1734,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema)
     assert(df.filter($"array1" === $"array2").count() == 1)
   }
+
+  test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+    val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+    checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+  }
 }
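
The gist of the change: a `DecimalType` column evaluates to `org.apache.spark.sql.types.Decimal`, which does not extend `java.lang.Number`, so the old `asInstanceOf[Number]` cast in `update()` threw a `ClassCastException`. The patch widens the aggregation-buffer key type to `AnyRef` and defers numeric conversion to evaluation time via `toDoubleValue`, which handles both `Decimal` and `Number` keys. A minimal sketch of how to exercise the fix from spark-shell (it assumes a running `SparkSession` named `spark`; the column name `x` is illustrative):

    // Before this patch, the query below failed with a ClassCastException
    // (org.apache.spark.sql.types.Decimal cannot be cast to java.lang.Number).
    val df = spark.range(10).selectExpr("CAST(id AS DECIMAL(10, 0)) AS x")

    // With the fix, decimal keys are converted to doubles before interpolation:
    // maxPosition = 9, position = 9 * 0.5 = 4.5, so the result should be
    // (5 - 4.5) * 4 + (4.5 - 4) * 5 = 4.5.
    df.selectExpr("percentile(x, 0.5)").show()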