Skip to content
Snippets Groups Projects
Commit e99e34d0 authored by gagan taneja's avatar gagan taneja Committed by Herman van Hovell
Browse files

[SPARK-19118][SQL] Percentile support for frequency distribution table

## What changes were proposed in this pull request?

I have a frequency distribution table with the following entries:
Age, No. of persons
21, 10
22, 15
23, 18
..
..
30, 14
Moreover, it is common to have data in frequency distribution format and to need to calculate the percentile or median from it. With the current implementation,
it is very difficult and complex to find the percentile of such data.
Therefore, I am proposing an enhancement to the current Percentile and Approx Percentile implementations to take a frequency distribution column into consideration.

## How was this patch tested?
1) Enhanced /sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala to cover the additional functionality.
2) Ran performance benchmark tests with 20 million rows in a local environment and did not see any performance degradation.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: gagan taneja <tanejagagan@gagans-MacBook-Pro.local>

Closes #16497 from tanejagagan/branch-18940.
parent 3d314d08
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.SparkException
/**
* The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
......@@ -44,22 +45,30 @@ import org.apache.spark.util.collection.OpenHashMap
@ExpressionDescription(
usage =
"""
_FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
given percentage. The value of percentage must be between 0.0 and 1.0.
_FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
`col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
value of frequency should be positive integral
_FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
of numeric column `col` at the given percentage(s). Each value of the percentage array must
be between 0.0 and 1.0.
""")
_FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
percentile value array of numeric column `col` at the given percentage(s). Each value
of the percentage array must be between 0.0 and 1.0. The value of frequency should be
positive integral
""")
case class Percentile(
child: Expression,
percentageExpression: Expression,
frequencyExpression : Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, 0, 0)
this(child, percentageExpression, Literal(1L), 0, 0)
}
def this(child: Expression, percentageExpression: Expression, frequency: Expression) = {
this(child, percentageExpression, frequency, 0, 0)
}
override def prettyName: String = "percentile"
......@@ -80,7 +89,9 @@ case class Percentile(
case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
}
override def children: Seq[Expression] = child :: percentageExpression :: Nil
override def children: Seq[Expression] = {
child :: percentageExpression ::frequencyExpression :: Nil
}
// Returns null for empty inputs
override def nullable: Boolean = true
......@@ -90,9 +101,12 @@ case class Percentile(
case _ => DoubleType
}
override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
case _ => Seq(NumericType, DoubleType)
override def inputTypes: Seq[AbstractDataType] = {
val percentageExpType = percentageExpression.dataType match {
case _: ArrayType => ArrayType(DoubleType)
case _ => DoubleType
}
Seq(NumericType, percentageExpType, IntegralType)
}
// Check the inputTypes are valid, and the percentageExpression satisfies:
......@@ -125,10 +139,17 @@ case class Percentile(
buffer: OpenHashMap[Number, Long],
input: InternalRow): OpenHashMap[Number, Long] = {
val key = child.eval(input).asInstanceOf[Number]
val frqValue = frequencyExpression.eval(input)
// Null values are ignored in counts map.
if (key != null) {
buffer.changeValue(key, 1L, _ + 1L)
if (key != null && frqValue != null) {
val frqLong = frqValue.asInstanceOf[Number].longValue()
// add only when frequency is positive
if (frqLong > 0) {
buffer.changeValue(key, frqLong, _ + frqLong)
} else if (frqLong < 0) {
throw new SparkException(s"Negative values found in ${frequencyExpression.sql}")
}
}
buffer
}
......
......@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
......@@ -50,25 +51,50 @@ class PercentileSuite extends SparkFunSuite {
test("class Percentile, high level interface, update, merge, eval...") {
// Drives `runTest` over three equivalent views of the data: plain rows (no frequency
// argument), rows carrying an explicit frequency column (Int and Long variants), and the
// frequency table expanded into one row per occurrence.
val count = 10000
// NOTE(review): `data` is not referenced again inside this test body — it looks like a
// leftover from the pre-change version; confirm and remove.
val data = (1 to count)
val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
// The value is read from column 0 as an Int and cast to Double.
val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
val agg = new Percentile(childExpression, percentageExpression)
// Test with rows without frequency
val rows = (1 to count).map( x => Seq(x))
runTest( agg, rows, expectedPercentiles)
// Test with row with frequency. Second and third columns are frequency in Int and Long
val countForFrequencyTest = 1000
// Each row is (x, x, x.toLong): the value x is given a frequency equal to itself.
val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
// Frequency taken from column 1 (Int).
val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
// Frequency taken from column 2 (Long); the same percentiles are expected as for the Int column.
val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
// Run test with Flatten data
// Each value `current` is repeated `current` times as a single-column row and fed to the
// aggregate without a frequency argument; the same percentiles are expected.
val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
(1 to current).map( y => current )).map( Seq(_))
runTest(agg, flattenRows, expectedPercentilesWithFrquency)
}
private def runTest(agg: Percentile,
rows : Seq[Seq[Any]],
expectedPercentiles : Seq[Double]) {
assert(agg.nullable)
val group1 = (0 until data.length / 2)
val group1 = (0 until rows.length / 2)
val group1Buffer = agg.createAggregationBuffer()
group1.foreach { index =>
val input = InternalRow(data(index))
val input = InternalRow(rows(index): _*)
agg.update(group1Buffer, input)
}
val group2 = (data.length / 2 until data.length)
val group2 = (rows.length / 2 until rows.length)
val group2Buffer = agg.createAggregationBuffer()
group2.foreach { index =>
val input = InternalRow(data(index))
val input = InternalRow(rows(index): _*)
agg.update(group2Buffer, input)
}
......@@ -116,40 +142,6 @@ class PercentileSuite extends SparkFunSuite {
assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile)
}
test("call from sql query") {
// Checks the SQL text produced for a Percentile built with the two-argument constructor
// (column + percentage), for a single percentage and for an array of percentages, through
// the plain `sql` member and through `sql(isDistinct = ...)`.
// NOTE(review): the enclosing hunk shrinks from 40 to 6 lines, so this test appears to be
// deleted by the commit — confirm that SQL rendering is still covered elsewhere.
// sql, single percentile
assertEqual(
s"percentile(`a`, 0.5D)",
new Percentile("a".attr, Literal(0.5)).sql: String)
// sql, array of percentile
assertEqual(
s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String)
// sql(isDistinct = false), single percentile
assertEqual(
s"percentile(`a`, 0.5D)",
new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false))
// sql(isDistinct = false), array of percentile
assertEqual(
s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
.sql(isDistinct = false))
// sql(isDistinct = true), single percentile
assertEqual(
s"percentile(DISTINCT `a`, 0.5D)",
new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true))
// sql(isDistinct = true), array of percentile
assertEqual(
s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))",
new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
.sql(isDistinct = true))
}
test("fail analysis if childExpression is invalid") {
val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
val percentage = Literal(0.5)
......@@ -160,6 +152,15 @@ class PercentileSuite extends SparkFunSuite {
assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
}
val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
for ( dataType <- validDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
}
val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType,
CalendarIntervalType, NullType)
......@@ -170,6 +171,30 @@ class PercentileSuite extends SparkFunSuite {
TypeCheckFailure(s"argument 1 requires numeric type, however, " +
s"'`a`' is of ${dataType.simpleString} type."))
}
val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
StringType, DateType, TimestampType,
CalendarIntervalType, NullType)
for( dataType <- invalidDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(),
TypeCheckFailure(s"argument 1 requires numeric type, however, " +
s"'`a`' is of ${dataType.simpleString} type."))
}
for( dataType <- validDataTypes;
frequencyType <- invalidFrequencyDataTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
val percentile = new Percentile(child, percentage, frq)
assertEqual(percentile.checkInputDataTypes(),
TypeCheckFailure(s"argument 3 requires integral type, however, " +
s"'`frq`' is of ${frequencyType.simpleString} type."))
}
}
test("fails analysis if percentage(s) are invalid") {
......@@ -217,19 +242,59 @@ class PercentileSuite extends SparkFunSuite {
}
test("null handling") {
  // The aggregated value is always read from column 0 and cast to double.
  val valueExpr = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)

  // ---- Part 1: percentile without a frequency column ----
  val plain = new Percentile(valueExpr, Literal(0.5))
  val plainState = new GenericInternalRow(new Array[Any](1))
  plain.initialize(plainState)

  // Nothing has been aggregated yet, so the result is null.
  assert(plain.eval(plainState) == null)

  // A row whose value is null leaves the result null.
  plain.update(plainState, InternalRow(null))
  assert(plain.eval(plainState) == null)

  // The first non-null value produces a non-null result.
  plain.update(plainState, InternalRow(0))
  assert(plain.eval(plainState) != null)

  // ---- Part 2: percentile with a frequency column (column 1) ----
  val freqExpr = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
  val weighted = new Percentile(valueExpr, Literal(0.5), freqExpr)
  val weightedState = new GenericInternalRow(new Array[Any](2))
  weighted.initialize(weightedState)

  // Nothing has been aggregated yet, so the result is null.
  assert(weighted.eval(weightedState) == null)

  // Both value and frequency are null: result stays null.
  weighted.update(weightedState, InternalRow(null, null))
  assert(weighted.eval(weightedState) == null)

  // Non-null value but null frequency: result stays null.
  weighted.update(weightedState, InternalRow(0, null))
  assert(weighted.eval(weightedState) == null)

  // Non-null value with a frequency of zero: result stays null.
  weighted.update(weightedState, InternalRow(1, 0))
  assert(weighted.eval(weightedState) == null)

  // Non-null value with a positive frequency: result becomes non-null.
  weighted.update(weightedState, InternalRow(0, 1))
  assert(weighted.eval(weightedState) != null)
}
test("negatives frequency column handling") {
  // Column 0 supplies the value (cast to double); column 1 supplies its frequency.
  val valueExpr = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
  val frequencyExpr = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
  val percentile = new Percentile(valueExpr, Literal(0.5), frequencyExpr)

  val aggBuffer = new GenericInternalRow(new Array[Any](2))
  percentile.initialize(aggBuffer)

  // Aggregating a row with a negative frequency is expected to raise a SparkException.
  val error = intercept[SparkException] {
    percentile.update(aggBuffer, InternalRow(1, -5))
    percentile.eval(aggBuffer)
  }

  // Only the message prefix is checked; the remainder is not asserted here.
  assert(error.getMessage.startsWith("Negative values found in "))
}
private def compareEquals(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment